diff --git a/.clang-format b/.clang-format index 5f533acfb..b715e3c98 100644 --- a/.clang-format +++ b/.clang-format @@ -16,10 +16,14 @@ SortIncludes: false IndentWidth: 4 TabWidth: 4 ObjCBlockIndentWidth: 4 -AlignAfterOpenBracket: DontAlign UseTab: Never PointerAlignment: Left SpaceAfterTemplateKeyword: false AlignEscapedNewlines: DontAlign AlwaysBreakTemplateDeclarations: Yes MaxEmptyLinesToKeep: 10 +AllowAllParametersOfDeclarationOnNextLine: false +AlignAfterOpenBracket: BlockIndent +BinPackArguments: false +BinPackParameters: false +PenaltyReturnTypeOnItsOwnLine: 10000 diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 55f9db392..c7fb782a6 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,5 @@ blank_issues_enabled: false contact_links: - name: Questions - url: https://github.com/Roblox/luau/discussions + url: https://github.com/luau-lang/luau/discussions about: Please use GitHub Discussions if you have questions or need help. diff --git a/.github/codecov.yml b/.github/codecov.yml index 69cb76019..7e0dee174 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -1 +1,7 @@ comment: false +coverage: + status: + patch: false + project: + default: + informational: true diff --git a/.github/workflows/benchmark-dev.yml b/.github/workflows/benchmark-dev.yml deleted file mode 100644 index b6115acc5..000000000 --- a/.github/workflows/benchmark-dev.yml +++ /dev/null @@ -1,185 +0,0 @@ -name: benchmark-dev - -on: - push: - branches: - - master - - paths-ignore: - - "docs/**" - - "papers/**" - - "rfcs/**" - - "*.md" - -jobs: - windows: - name: windows-${{matrix.arch}} - strategy: - fail-fast: false - matrix: - os: [windows-latest] - arch: [Win32, x64] - bench: - - { - script: "run-benchmarks", - timeout: 12, - title: "Luau Benchmarks", - } - benchResultsRepo: - - { name: "luau-lang/benchmark-data", branch: "main" } - - runs-on: ${{ matrix.os }} - steps: - - name: Checkout Luau repository - uses: actions/checkout@v3 - - - name: Build Luau - shell: bash # necessary for fail-fast - run: | - mkdir build && cd build - cmake .. -DCMAKE_BUILD_TYPE=Release - cmake --build . --target Luau.Repl.CLI --config Release - cmake --build . --target Luau.Analyze.CLI --config Release - - - name: Move build files to root - run: | - move build/Release/* . - - - uses: actions/setup-python@v3 - with: - python-version: "3.9" - architecture: "x64" - - - name: Install python dependencies - run: | - python -m pip install requests - python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose - - - name: Run benchmark - run: | - python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt - - - name: Push benchmark results - id: pushBenchmarkAttempt1 - continue-on-error: true - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 2) - id: pushBenchmarkAttempt2 - continue-on-error: true - if: steps.pushBenchmarkAttempt1.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 3) - id: pushBenchmarkAttempt3 - continue-on-error: true - if: steps.pushBenchmarkAttempt2.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - unix: - name: ${{matrix.os}} - strategy: - fail-fast: false - matrix: - os: [ubuntu-20.04, macos-latest] - bench: - - { - script: "run-benchmarks", - timeout: 12, - title: "Luau Benchmarks", - } - benchResultsRepo: - - { name: "luau-lang/benchmark-data", branch: "main" } - - runs-on: ${{ matrix.os }} - steps: - - name: Checkout Luau repository - uses: actions/checkout@v3 - - - name: Build Luau - run: make config=release luau luau-analyze - - - uses: actions/setup-python@v3 - with: - python-version: "3.9" - architecture: "x64" - - - name: Install python dependencies - run: | - python -m pip install requests - python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose - - - name: Run benchmark - run: | - python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt - - - name: Push benchmark results - id: pushBenchmarkAttempt1 - continue-on-error: true - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: ${{ matrix.bench.title }} - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 2) - id: pushBenchmarkAttempt2 - continue-on-error: true - if: steps.pushBenchmarkAttempt1.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: ${{ matrix.bench.title }} - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 3) - id: pushBenchmarkAttempt3 - continue-on-error: true - if: steps.pushBenchmarkAttempt2.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: ${{ matrix.bench.title }} - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 7fb88e217..72a0c9ffe 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -25,6 +25,7 @@ jobs: - name: Install valgrind run: | + sudo apt-get update sudo apt-get install valgrind - name: Build Luau (gcc) @@ -41,7 +42,7 @@ jobs: - name: Build Luau (clang) run: | make config=profile clean - CXX=clang++ make config=profile luau luau-analyze + CXX=clang++ make config=profile luau luau-analyze luau-compile - name: Run benchmark (bench-gcc) run: | @@ -62,22 +63,26 @@ jobs: } valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt - name: Run benchmark (compile) run: | filter() { - awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau --compile"}' + awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau-compile"}' } - valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt - valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt - valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt - valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt - valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt - valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt - valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt - valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-t1-codegen | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/regex.lua 2>&1 | filter regex-O2-t1-codegen | tee -a compile-output.txt - name: Checkout benchmark results uses: actions/checkout@v3 @@ -122,7 +127,7 @@ jobs: - name: Store results (compile) uses: Roblox/rhysd-github-action-benchmark@v-luau with: - name: luau --compile + name: luau-compile tool: "benchmarkluau" output-file-path: ./compile-output.txt external-data-json-path: ./gh-pages/compile.json diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b084a4f53..7a2b5f105 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,11 +20,14 @@ jobs: unix: strategy: matrix: - os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}] + os: [{name: ubuntu, version: ubuntu-latest}, {name: macos, version: macos-latest}, {name: macos-arm, version: macos-14}] name: ${{matrix.os.name}} runs-on: ${{matrix.os.version}} steps: - uses: actions/checkout@v1 + - name: work around ASLR+ASAN compatibility + run: sudo sysctl -w vm.mmap_rnd_bits=28 + if: matrix.os.name == 'ubuntu' - name: make tests run: | make -j2 config=sanitize werror=1 native=1 luau-tests @@ -42,9 +45,10 @@ jobs: ./luau-tests -ts=Conformance --codegen -O2 --fflags=true - name: make cli run: | - make -j2 config=sanitize werror=1 luau luau-analyze # match config with tests to improve build time + make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time ./luau tests/conformance/assert.lua ./luau-analyze tests/conformance/assert.lua + ./luau-compile tests/conformance/assert.lua windows: runs-on: windows-latest @@ -76,12 +80,13 @@ jobs: - name: cmake cli shell: bash # necessary for fail-fast run: | - cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Debug # match config with tests to improve build time + cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time Debug/luau tests/conformance/assert.lua Debug/luau-analyze tests/conformance/assert.lua + Debug/luau-compile tests/conformance/assert.lua coverage: - runs-on: ubuntu-20.04 + runs-on: ubuntu-20.04 # needed for clang++-10 to avoid gcov compatibility issues steps: - uses: actions/checkout@v2 - name: install @@ -97,7 +102,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} web: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - uses: actions/checkout@v2 diff --git a/.github/workflows/new-release.yml b/.github/workflows/new-release.yml index 5fe7f7920..62c83a180 100644 --- a/.github/workflows/new-release.yml +++ b/.github/workflows/new-release.yml @@ -12,7 +12,7 @@ permissions: jobs: create-release: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest outputs: upload_url: ${{ steps.create_release.outputs.upload_url }} steps: @@ -29,7 +29,7 @@ jobs: build: needs: ["create-release"] strategy: - matrix: + matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] name: ${{matrix.os.name}} runs-on: ${{matrix.os.version}} @@ -38,7 +38,7 @@ jobs: - name: configure run: cmake . -DCMAKE_BUILD_TYPE=Release - name: build - run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2 + run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI Luau.Ast.CLI --config Release -j 2 - name: pack if: matrix.os.name != 'windows' run: zip luau-${{matrix.os.name}}.zip luau* @@ -56,7 +56,7 @@ jobs: web: needs: ["create-release"] - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - uses: actions/checkout@v2 diff --git a/.github/workflows/push-results/action.yml b/.github/workflows/push-results/action.yml deleted file mode 100644 index b5ffebee2..000000000 --- a/.github/workflows/push-results/action.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: Checkout & push results -description: Checkout a given repo and push results to GitHub -inputs: - repository: - required: true - type: string - description: The benchmark results repository to check out - branch: - required: true - type: string - description: The benchmark results repository's branch to check out - token: - required: true - type: string - description: The GitHub token to use for pushing results - path: - required: true - type: string - description: The path to check out the results repository to - bench_name: - required: true - type: string - bench_tool: - required: true - type: string - bench_output_file_path: - required: true - type: string - bench_external_data_json_path: - required: true - type: string - -runs: - using: "composite" - steps: - - name: Checkout repository - uses: actions/checkout@v3 - with: - repository: ${{ inputs.repository }} - ref: ${{ inputs.branch }} - token: ${{ inputs.token }} - path: ${{ inputs.path }} - - - name: Store results - uses: Roblox/rhysd-github-action-benchmark@v-luau - with: - name: ${{ inputs.bench_name }} - tool: ${{ inputs.bench_tool }} - gh-pages-branch: ${{ inputs.branch }} - output-file-path: ${{ inputs.bench_output_file_path }} - external-data-json-path: ${{ inputs.bench_external_data_json_path }} - - - name: Push benchmark results - shell: bash - run: | - echo "Pushing benchmark results..." - cd gh-pages - git config user.name github-actions - git config user.email github@users.noreply.github.com - git add *.json - git commit -m "Add benchmarks results for ${{ github.sha }}" - git push - cd .. diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 497483a79..5e18eb68f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,7 +13,7 @@ on: jobs: build: strategy: - matrix: + matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] name: ${{matrix.os.name}} runs-on: ${{matrix.os.version}} @@ -22,20 +22,21 @@ jobs: - name: configure run: cmake . -DCMAKE_BUILD_TYPE=Release - name: build - run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2 - - uses: actions/upload-artifact@v2 + run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Release -j 2 + - uses: actions/upload-artifact@v4 if: matrix.os.name != 'windows' with: name: luau-${{matrix.os.name}} path: luau* - - uses: actions/upload-artifact@v2 + overwrite: true + - uses: actions/upload-artifact@v4 if: matrix.os.name == 'windows' with: name: luau-${{matrix.os.name}} path: Release\luau*.exe - + overwrite: true web: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - uses: actions/checkout@v2 @@ -52,7 +53,8 @@ jobs: source emsdk/emsdk_env.sh emcmake cmake . -DLUAU_BUILD_WEB=ON -DCMAKE_BUILD_TYPE=Release make -j2 Luau.Web - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: name: Luau.Web.js path: Luau.Web.js + overwrite: true diff --git a/.gitignore b/.gitignore index deba5631a..8b124e79d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ /build/ /build[.-]*/ +/out +/cmake/ +/cmake[.-]*/ /coverage/ /.vs/ /.vscode/ @@ -10,6 +13,8 @@ /luau /luau-tests /luau-analyze +/luau-compile __pycache__ Makefile - +.cache +.clangd diff --git a/Analysis/include/Luau/AnyTypeSummary.h b/Analysis/include/Luau/AnyTypeSummary.h new file mode 100644 index 000000000..d99eea848 --- /dev/null +++ b/Analysis/include/Luau/AnyTypeSummary.h @@ -0,0 +1,148 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AstQuery.h" +#include "Luau/Config.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Scope.h" +#include "Luau/Variant.h" +#include "Luau/Normalize.h" +#include "Luau/TypePack.h" +#include "Luau/TypeArena.h" + +#include +#include +#include +#include + +namespace Luau +{ + +class AstStat; +class ParseError; +struct TypeError; +struct LintWarning; +struct GlobalTypes; +struct ModuleResolver; +struct ParseResult; +struct DcrLogger; + +struct TelemetryTypePair +{ + std::string annotatedType; + std::string inferredType; +}; + +struct AnyTypeSummary +{ + TypeArena arena; + + AstStatBlock* rootSrc = nullptr; + DenseHashSet seenTypeFamilyInstances{nullptr}; + + int recursionCount = 0; + + std::string root; + int strictCount = 0; + + DenseHashMap seen{nullptr}; + + AnyTypeSummary(); + + void traverse(const Module* module, AstStat* src, NotNull builtinTypes); + + std::pair checkForAnyCast(const Scope* scope, AstExprTypeAssertion* expr); + + bool containsAny(TypePackId typ); + bool containsAny(TypeId typ); + + bool isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes); + bool isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes); + + bool hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes); + bool hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes); + bool hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes); + + TypeId checkForFamilyInhabitance(const TypeId instance, Location location); + TypeId lookupType(const AstExpr* expr, const Module* module, NotNull builtinTypes); + TypePackId reconstructTypePack(const AstArray exprs, const Module* module, NotNull builtinTypes); + + DenseHashSet seenTypeFunctionInstances{nullptr}; + TypeId lookupAnnotation(AstType* annotation, const Module* module, NotNull builtintypes); + std::optional lookupPackAnnotation(AstTypePack* annotation, const Module* module); + TypeId checkForTypeFunctionInhabitance(const TypeId instance, const Location location); + + enum Pattern : uint64_t + { + Casts, + FuncArg, + FuncRet, + FuncApp, + VarAnnot, + VarAny, + TableProp, + Alias, + Assign, + TypePk + }; + + struct TypeInfo + { + Pattern code; + std::string node; + TelemetryTypePair type; + + explicit TypeInfo(Pattern code, std::string node, TelemetryTypePair type); + }; + + struct FindReturnAncestry final : public AstVisitor + { + AstNode* currNode{nullptr}; + AstNode* stat{nullptr}; + Position rootEnd; + bool found = false; + + explicit FindReturnAncestry(AstNode* stat, Position rootEnd); + + bool visit(AstType* node) override; + bool visit(AstNode* node) override; + bool visit(AstStatFunction* node) override; + bool visit(AstStatLocalFunction* node) override; + }; + + std::vector typeInfo; + + /** + * Fabricates a scope that is a child of another scope. + * @param node the lexical node that the scope belongs to. + * @param parent the parent scope of the new scope. Must not be null. + */ + const Scope* childScope(const AstNode* node, const Scope* parent); + + std::optional matchRequire(const AstExprCall& call); + AstNode* getNode(AstStatBlock* root, AstNode* node); + const Scope* findInnerMostScope(const Location location, const Module* module); + const AstNode* findAstAncestryAtLocation(const AstStatBlock* root, AstNode* node); + + void visit(const Scope* scope, AstStat* stat, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatError* error, const Module* module, NotNull builtinTypes); +}; + +} // namespace Luau \ No newline at end of file diff --git a/Analysis/include/Luau/Anyification.h b/Analysis/include/Luau/Anyification.h index 7b6f71716..4b9c8ee93 100644 --- a/Analysis/include/Luau/Anyification.h +++ b/Analysis/include/Luau/Anyification.h @@ -4,7 +4,7 @@ #include "Luau/NotNull.h" #include "Luau/Substitution.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include @@ -19,10 +19,22 @@ using ScopePtr = std::shared_ptr; // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, - TypePackId anyTypePack); - Anyification(TypeArena* arena, const ScopePtr& scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, - TypePackId anyTypePack); + Anyification( + TypeArena* arena, + NotNull scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack + ); + Anyification( + TypeArena* arena, + const ScopePtr& scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack + ); NotNull scope; NotNull builtinTypes; InternalErrorReporter* iceHandler; @@ -39,4 +51,4 @@ struct Anyification : Substitution bool ignoreChildren(TypePackId ty) override; }; -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/Analysis/include/Luau/ApplyTypeFunction.h b/Analysis/include/Luau/ApplyTypeFunction.h index 3f5f47fd4..71430b28f 100644 --- a/Analysis/include/Luau/ApplyTypeFunction.h +++ b/Analysis/include/Luau/ApplyTypeFunction.h @@ -3,7 +3,7 @@ #include "Luau/Substitution.h" #include "Luau/TxnLog.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" namespace Luau { diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index aa7ef8d3e..633d6faf5 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -3,6 +3,7 @@ #include "Luau/Ast.h" #include "Luau/Documentation.h" +#include "Luau/TypeFwd.h" #include @@ -13,9 +14,6 @@ struct Binding; struct SourceModule; struct Module; -struct Type; -using TypeId = const Type*; - using ScopePtr = std::shared_ptr; struct ExprOrLocal @@ -63,9 +61,28 @@ struct ExprOrLocal AstLocal* local = nullptr; }; +struct FindFullAncestry final : public AstVisitor +{ + std::vector nodes; + Position pos; + Position documentEnd; + bool includeTypes = false; + + explicit FindFullAncestry(Position pos, Position documentEnd, bool includeTypes = false); + + bool visit(AstType* type) override; + + bool visit(AstStatFunction* node) override; + + bool visit(AstNode* node) override; +}; + std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); +std::vector findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos); std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes = false); +std::vector findAstAncestryOfPosition(AstStatBlock* root, Position pos, bool includeTypes = false); AstNode* findNodeAtPosition(const SourceModule& source, Position pos); +AstNode* findNodeAtPosition(AstStatBlock* root, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos); ScopePtr findScopeAtPosition(const Module& module, Position pos); std::optional findBindingAtPosition(const Module& module, const SourceModule& source, Position pos); diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 618325777..96bac9e4b 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -38,6 +38,8 @@ enum class AutocompleteEntryKind String, Type, Module, + GeneratedFunction, + RequirePath, }; enum class ParenthesesRecommendation @@ -70,6 +72,10 @@ struct AutocompleteEntry std::optional documentationSymbol = std::nullopt; Tags tags; ParenthesesRecommendation parens = ParenthesesRecommendation::None; + std::optional insertText; + + // Only meaningful if kind is Property. + bool indexedWithSelf = false; }; using AutocompleteEntryMap = std::unordered_map; @@ -94,4 +100,6 @@ using StringCompletionCallback = AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); +constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; + } // namespace Luau diff --git a/Analysis/include/Luau/Breadcrumb.h b/Analysis/include/Luau/Breadcrumb.h deleted file mode 100644 index 59b293a0b..000000000 --- a/Analysis/include/Luau/Breadcrumb.h +++ /dev/null @@ -1,75 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/Def.h" -#include "Luau/NotNull.h" -#include "Luau/Variant.h" - -#include -#include - -namespace Luau -{ - -using NullableBreadcrumbId = const struct Breadcrumb*; -using BreadcrumbId = NotNull; - -struct FieldMetadata -{ - std::string prop; -}; - -struct SubscriptMetadata -{ - BreadcrumbId key; -}; - -using Metadata = Variant; - -struct Breadcrumb -{ - NullableBreadcrumbId previous; - DefId def; - std::optional metadata; - std::vector children; -}; - -inline Breadcrumb* asMutable(NullableBreadcrumbId breadcrumb) -{ - LUAU_ASSERT(breadcrumb); - return const_cast(breadcrumb); -} - -template -const T* getMetadata(NullableBreadcrumbId breadcrumb) -{ - if (!breadcrumb || !breadcrumb->metadata) - return nullptr; - - return get_if(&*breadcrumb->metadata); -} - -struct BreadcrumbArena -{ - TypedAllocator allocator; - - template - BreadcrumbId add(NullableBreadcrumbId previous, DefId def, Args&&... args) - { - Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, std::forward(args)...}); - if (previous) - asMutable(previous)->children.push_back(NotNull{bc}); - return NotNull{bc}; - } - - template - BreadcrumbId emplace(NullableBreadcrumbId previous, DefId def, Args&&... args) - { - Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, Metadata{T{std::forward(args)...}}}); - if (previous) - asMutable(previous)->children.push_back(NotNull{bc}); - return NotNull{bc}; - } -}; - -} // namespace Luau diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 0604b40e2..db2f67127 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -9,60 +9,89 @@ namespace Luau { +static constexpr char kRequireTagName[] = "require"; + struct Frontend; +struct GlobalTypes; struct TypeChecker; struct TypeArena; +struct Subtyping; -void registerBuiltinTypes(Frontend& frontend); - -void registerBuiltinGlobals(TypeChecker& typeChecker); -void registerBuiltinGlobals(Frontend& frontend); - +void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); /** Build an optional 't' */ -TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t); -TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t); +TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t); /** Small utility function for building up type definitions from C++. */ TypeId makeFunction( // Monomorphic - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked = false +); TypeId makeFunction( // Polymorphic - TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, - std::initializer_list paramTypes, std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked = false +); TypeId makeFunction( // Monomorphic - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked = false +); TypeId makeFunction( // Polymorphic - TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, - std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked = false +); void attachMagicFunction(TypeId ty, MagicFunction fn); void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn); - +void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); std::string getBuiltinDefinitionSource(); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding); -void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding); -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding); -std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name); -Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name); -TypeId getGlobalBinding(Frontend& frontend, const std::string& name); -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name); +void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding); +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, Binding binding); +std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::string& name); +Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name); +TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name); + + +/** A number of built-in functions are magical enough that we need to match on them specifically by + * name when they are called. These are listed here to be used whenever necessary, instead of duplicating this logic repeatedly. + */ + +bool matchSetMetatable(const AstExprCall& call); +bool matchTableFreeze(const AstExprCall& call); +bool matchAssert(const AstExprCall& call); + +// Returns `true` if the function should introduce typestate for its first argument. +bool shouldTypestateForFirstArgument(const AstExprCall& call); } // namespace Luau diff --git a/Analysis/include/Luau/Cancellation.h b/Analysis/include/Luau/Cancellation.h new file mode 100644 index 000000000..441318631 --- /dev/null +++ b/Analysis/include/Luau/Cancellation.h @@ -0,0 +1,24 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +struct FrontendCancellationToken +{ + void cancel() + { + cancelled.store(true); + } + + bool requested() + { + return cancelled.load(); + } + + std::atomic cancelled; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 51f1e7a67..b0c8fd17c 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -16,17 +16,23 @@ using SeenTypePacks = std::unordered_map; struct CloneState { + NotNull builtinTypes; + SeenTypes seenTypes; SeenTypePacks seenTypePacks; - - int recursionCount = 0; }; +/** `shallowClone` will make a copy of only the _top level_ constructor of the type, + * while `clone` will make a deep copy of the entire type and its every component. + * + * Be mindful about which behavior you actually _want_. + */ + +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState); +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState); + TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); -TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false); -TypeId shallowClone(TypeId ty, NotNull dest); - } // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 2223c29e0..612537328 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -4,8 +4,8 @@ #include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/DenseHash.h" #include "Luau/NotNull.h" -#include "Luau/Type.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" #include #include @@ -14,13 +14,15 @@ namespace Luau { +enum class ValueContext; struct Scope; -struct Type; -using TypeId = const Type*; - -struct TypePackVar; -using TypePackId = const TypePackVar*; +// if resultType is a freeType, assignmentType <: freeType <: resultType bounds +struct EqualityConstraint +{ + TypeId resultType; + TypeId assignmentType; +}; // subType <: superType struct SubtypeConstraint @@ -34,6 +36,11 @@ struct PackSubtypeConstraint { TypePackId subPack; TypePackId superPack; + + // HACK!! TODO clip. + // We need to know which of `PackSubtypeConstraint` are emitted from `AstStatReturn` vs any others. + // Then we force these specific `PackSubtypeConstraint` to only dispatch in the order of the `return`s. + bool returns = false; }; // generalizedType ~ gen sourceType @@ -41,46 +48,19 @@ struct GeneralizationConstraint { TypeId generalizedType; TypeId sourceType; -}; - -// subType ~ inst superType -struct InstantiationConstraint -{ - TypeId subType; - TypeId superType; -}; -struct UnaryConstraint -{ - AstExprUnary::Op op; - TypeId operandType; - TypeId resultType; + std::vector interiorTypes; }; -// let L : leftType -// let R : rightType -// in -// L op R : resultType -struct BinaryConstraint -{ - AstExprBinary::Op op; - TypeId leftType; - TypeId rightType; - TypeId resultType; - - // When we dispatch this constraint, we update the key at this map to record - // the overload that we selected. - const AstNode* astFragment; - DenseHashMap* astOriginalCallTypes; - DenseHashMap* astOverloadResolvedTypes; -}; - -// iteratee is iterable -// iterators is the iteration types. +// variables ~ iterate iterator +// Unpack the iterator, figure out what types it iterates over, and bind those types to variables. struct IterableConstraint { TypePackId iterator; - TypePackId variables; + std::vector variables; + + const AstNode* nextAstFragment; + DenseHashMap* astForInNextTypes; }; // name(namedType) = name @@ -105,22 +85,52 @@ struct FunctionCallConstraint TypeId fn; TypePackId argsPack; TypePackId result; - class AstExprCall* callSite; + class AstExprCall* callSite = nullptr; std::vector> discriminantTypes; + + // When we dispatch this constraint, we update the key at this map to record + // the overload that we selected. + DenseHashMap* astOverloadResolvedTypes = nullptr; +}; + +// function_check fn argsPack +// +// If fn is a function type and argsPack is a partially solved +// pack of arguments to be supplied to the function, propagate the argument +// types of fn into the types of argsPack. This is used to implement +// bidirectional inference of lambda arguments. +struct FunctionCheckConstraint +{ + TypeId fn; + TypePackId argsPack; + + class AstExprCall* callSite = nullptr; + NotNull> astTypes; + NotNull> astExpectedTypes; }; -// result ~ prim ExpectedType SomeSingletonType MultitonType +// prim FreeType ExpectedType PrimitiveType // -// If ExpectedType is potentially a singleton (an actual singleton or a union -// that contains a singleton), then result ~ SomeSingletonType +// FreeType is bounded below by the singleton type and above by PrimitiveType +// initially. When this constraint is resolved, it will check that the bounds +// of the free type are well-formed by subtyping. // -// else result ~ MultitonType +// If they are not well-formed, then FreeType is replaced by its lower bound +// +// If they are well-formed and ExpectedType is potentially a singleton (an +// actual singleton or a union that contains a singleton), +// then FreeType is replaced by its lower bound +// +// else FreeType is replaced by PrimitiveType struct PrimitiveTypeConstraint { - TypeId resultType; - TypeId expectedType; - TypeId singletonType; - TypeId multitonType; + TypeId freeType; + + // potentially gets used to force the lower bound? + std::optional expectedType; + + // the primitive type to check against + TypeId primitiveType; }; // result ~ hasProp type "prop_name" @@ -139,63 +149,131 @@ struct HasPropConstraint TypeId resultType; TypeId subjectType; std::string prop; + ValueContext context; + + // We want to track if this `HasPropConstraint` comes from a conditional. + // If it does, we're going to change the behavior of property look-up a bit. + // In particular, we're going to return `unknownType` for property lookups + // on `table` or inexact table types where the property is not present. + // + // This allows us to refine table types to have additional properties + // without reporting errors in typechecking on the property tests. + bool inConditional = false; + + // HACK: We presently need types like true|false or string|"hello" when + // deciding whether a particular literal expression should have a singleton + // type. This boolean is set to true when extracting the property type of a + // value that may be a union of tables. + // + // For example, in the following code fragment, we want the lookup of the + // success property to yield true|false when extracting an expectedType in + // this expression: + // + // type Result = {success:true, result: T} | {success:false, error: E} + // + // local r: Result = {success=true, result=9} + // + // If we naively simplify the expectedType to boolean, we will erroneously + // compute the type boolean for the success property of the table literal. + // This causes type checking to fail. + bool suppressSimplification = false; }; -// result ~ setProp subjectType ["prop", "prop2", ...] propType -// -// If the subject is a table or table-like thing that already has the named -// property chain, we unify propType with that existing property type. +// resultType ~ hasIndexer subjectType indexType // -// If the subject is a free table, we augment it in place. +// If the subject type is a table or table-like thing that supports indexing, +// populate the type result with the result type of such an index operation. // -// If the subject is an unsealed table, result is an augmented table that -// includes that new prop. -struct SetPropConstraint +// If the subject is not indexable, resultType is bound to errorType. +struct HasIndexerConstraint { TypeId resultType; TypeId subjectType; - std::vector path; - TypeId propType; + TypeId indexType; }; -// result ~ setIndexer subjectType indexType propType +// assignProp lhsType propName rhsType // -// If the subject is a table or table-like thing that already has an indexer, -// unify its indexType and propType with those from this constraint. -// -// If the table is a free or unsealed table, we augment it with a new indexer. -struct SetIndexerConstraint +// Assign a value of type rhsType into the named property of lhsType. + +struct AssignPropConstraint { - TypeId resultType; - TypeId subjectType; - TypeId indexType; + TypeId lhsType; + std::string propName; + TypeId rhsType; + + /// If a new property is to be inserted into a table type, it will be + /// ascribed this location. + std::optional propLocation; + + /// The canonical write type of the property. It is _solely_ used to + /// populate astTypes during constraint resolution. Nothing should ever + /// block on it. TypeId propType; + + // When we generate constraints, we increment the remaining prop count on + // the table if we are able. This flag informs the solver as to whether or + // not it should in turn decrement the prop count when this constraint is + // dispatched. + bool decrementPropCount = false; }; -// if negation: -// result ~ if isSingleton D then ~D else unknown where D = discriminantType -// if not negation: -// result ~ if isSingleton D then D else unknown where D = discriminantType -struct SingletonOrTopTypeConstraint +struct AssignIndexConstraint { - TypeId resultType; - TypeId discriminantType; - bool negated; + TypeId lhsType; + TypeId indexType; + TypeId rhsType; + + /// The canonical write type of the property. It is _solely_ used to + /// populate astTypes during constraint resolution. Nothing should ever + /// block on it. + TypeId propType; }; -// resultType ~ unpack sourceTypePack +// resultTypes ~ unpack sourceTypePack // // Similar to PackSubtypeConstraint, but with one important difference: If the // sourcePack is blocked, this constraint blocks. struct UnpackConstraint { - TypePackId resultPack; + std::vector resultPack; TypePackId sourcePack; }; -using ConstraintV = Variant; +// ty ~ reduce ty +// +// Try to reduce ty, if it is a TypeFunctionInstanceType. Otherwise, do nothing. +struct ReduceConstraint +{ + TypeId ty; +}; + +// tp ~ reduce tp +// +// Analogous to ReduceConstraint, but for type packs. +struct ReducePackConstraint +{ + TypePackId tp; +}; + +using ConstraintV = Variant< + SubtypeConstraint, + PackSubtypeConstraint, + GeneralizationConstraint, + IterableConstraint, + NameConstraint, + TypeAliasExpansionConstraint, + FunctionCallConstraint, + FunctionCheckConstraint, + PrimitiveTypeConstraint, + HasPropConstraint, + HasIndexerConstraint, + AssignPropConstraint, + AssignIndexConstraint, + UnpackConstraint, + ReduceConstraint, + ReducePackConstraint, + EqualityConstraint>; struct Constraint { @@ -209,10 +287,14 @@ struct Constraint ConstraintV c; std::vector> dependencies; + + DenseHashSet getMaybeMutatedFreeTypes() const; }; using ConstraintPtr = std::unique_ptr; +bool isReferenceCountedType(const TypeId typ); + inline Constraint& asMutable(const Constraint& c) { return const_cast(c); diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGenerator.h similarity index 55% rename from Analysis/include/Luau/ConstraintGraphBuilder.h rename to Analysis/include/Luau/ConstraintGenerator.h index e79c4c91e..435c62fb6 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -2,15 +2,20 @@ #pragma once #include "Luau/Ast.h" -#include "Luau/Refinement.h" #include "Luau/Constraint.h" +#include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" +#include "Luau/InsertionOrderedMap.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" +#include "Luau/Normalize.h" #include "Luau/NotNull.h" +#include "Luau/Refinement.h" #include "Luau/Symbol.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeUtils.h" #include "Luau/Variant.h" +#include "Luau/Normalize.h" #include #include @@ -23,6 +28,7 @@ struct Scope; using ScopePtr = std::shared_ptr; struct DcrLogger; +struct TypeFunctionRuntime; struct Inference { @@ -52,21 +58,37 @@ struct InferencePack } }; -struct ConstraintGraphBuilder +struct ConstraintGenerator { // A list of all the scopes in the module. This vector holds ownership of the // scope pointers; the scopes themselves borrow pointers to other scopes to // define the scope hierarchy. std::vector> scopes; - ModuleName moduleName; ModulePtr module; NotNull builtinTypes; const NotNull arena; // The root scope of the module we're generating constraints for. - // This is null when the CGB is initially constructed. + // This is null when the CG is initially constructed. Scope* rootScope; + TypeContext typeContext = TypeContext::Default; + + struct InferredBinding + { + Scope* scope; + Location location; + TypeIds types; + }; + + // Some locals have multiple type states. We wish for Scope::bindings to + // map each local name onto the union of every type that the local can have + // over its lifetime, so we use this map to accumulate the set of types it + // might have. + // + // See the functions recordInferredBinding and fillInInferredBindings. + DenseHashMap inferredBindings{{}}; + // Constraints that go straight to the solver. std::vector constraints; @@ -85,17 +107,49 @@ struct ConstraintGraphBuilder // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. std::vector errors; + // Needed to be able to enable error-suppression preservation for immediate refinements. + NotNull normalizer; + // Needed to register all available type functions for execution at later stages. + NotNull typeFunctionRuntime; // Needed to resolve modules to make 'require' import types properly. NotNull moduleResolver; // Occasionally constraint generation needs to produce an ICE. const NotNull ice; ScopePtr globalScope; + + std::function prepareModuleScope; + std::vector requireCycles; + + DenseHashMap localTypes{nullptr}; + DcrLogger* logger; - ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, - NotNull dfg); + ConstraintGenerator( + ModulePtr module, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull moduleResolver, + NotNull builtinTypes, + NotNull ice, + const ScopePtr& globalScope, + std::function prepareModuleScope, + DcrLogger* logger, + NotNull dfg, + std::vector requireCycles + ); + + /** + * The entry point to the ConstraintGenerator. This will construct a set + * of scopes, constraints, and free types that can be solved later. + * @param block the root block to generate constraints for. + */ + void visitModuleRoot(AstStatBlock* block); + + void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block); + +private: + std::vector> interiorTypes; /** * Fabricates a new free type belonging to a given scope. @@ -109,6 +163,18 @@ struct ConstraintGraphBuilder */ TypePackId freshTypePack(const ScopePtr& scope); + /** + * Allocate a new TypePack with the given head and tail. + * + * Avoids allocating 0-length type packs: + * + * If the head is non-empty, allocate and return a type pack with the given + * head and tail. + * If the head is empty and tail is non-empty, return *tail. + * If both the head and tail are empty, return an empty type pack. + */ + TypePackId addTypePack(std::vector head, std::optional tail); + /** * Fabricates a scope that is a child of another scope. * @param node the lexical node that the scope belongs to. @@ -116,6 +182,8 @@ struct ConstraintGraphBuilder */ ScopePtr childScope(AstNode* node, const ScopePtr& parent); + std::optional lookup(const ScopePtr& scope, Location location, DefId def, bool prototype = true); + /** * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. @@ -132,38 +200,67 @@ struct ConstraintGraphBuilder */ NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); + struct RefinementPartition + { + // Types that we want to intersect against the type of the expression. + std::vector discriminantTypes; + + // Sometimes the type we're discriminating against is implicitly nil. + bool shouldAppendNilType = false; + }; + + using RefinementContext = InsertionOrderedMap; + void unionRefinements( + const ScopePtr& scope, + Location location, + const RefinementContext& lhs, + const RefinementContext& rhs, + RefinementContext& dest, + std::vector* constraints + ); + void computeRefinement( + const ScopePtr& scope, + Location location, + RefinementId refinement, + RefinementContext* refis, + bool sense, + bool eq, + std::vector* constraints + ); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); - /** - * The entry point to the ConstraintGraphBuilder. This will construct a set - * of scopes, constraints, and free types that can be solved later. - * @param block the root block to generate constraints for. - */ - void visit(AstStatBlock* block); - - void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); - - void visit(const ScopePtr& scope, AstStat* stat); - void visit(const ScopePtr& scope, AstStatBlock* block); - void visit(const ScopePtr& scope, AstStatLocal* local); - void visit(const ScopePtr& scope, AstStatFor* for_); - void visit(const ScopePtr& scope, AstStatForIn* forIn); - void visit(const ScopePtr& scope, AstStatWhile* while_); - void visit(const ScopePtr& scope, AstStatRepeat* repeat); - void visit(const ScopePtr& scope, AstStatLocalFunction* function); - void visit(const ScopePtr& scope, AstStatFunction* function); - void visit(const ScopePtr& scope, AstStatReturn* ret); - void visit(const ScopePtr& scope, AstStatAssign* assign); - void visit(const ScopePtr& scope, AstStatCompoundAssign* assign); - void visit(const ScopePtr& scope, AstStatIf* ifStatement); - void visit(const ScopePtr& scope, AstStatTypeAlias* alias); - void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); - void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); - void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); - void visit(const ScopePtr& scope, AstStatError* error); + LUAU_NOINLINE void checkAliases(const ScopePtr& scope, AstStatBlock* block); + + ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); + ControlFlow visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block); + + ControlFlow visit(const ScopePtr& scope, AstStat* stat); + ControlFlow visit(const ScopePtr& scope, AstStatBlock* block); + ControlFlow visit(const ScopePtr& scope, AstStatLocal* local); + ControlFlow visit(const ScopePtr& scope, AstStatFor* for_); + ControlFlow visit(const ScopePtr& scope, AstStatForIn* forIn); + ControlFlow visit(const ScopePtr& scope, AstStatWhile* while_); + ControlFlow visit(const ScopePtr& scope, AstStatRepeat* repeat); + ControlFlow visit(const ScopePtr& scope, AstStatLocalFunction* function); + ControlFlow visit(const ScopePtr& scope, AstStatFunction* function); + ControlFlow visit(const ScopePtr& scope, AstStatReturn* ret); + ControlFlow visit(const ScopePtr& scope, AstStatAssign* assign); + ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign); + ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement); + ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias); + ControlFlow visit(const ScopePtr& scope, AstStatTypeFunction* function); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); + ControlFlow visit(const ScopePtr& scope, AstStatError* error); InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes = {}); - InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes = {}); + InferencePack checkPack( + const ScopePtr& scope, + AstExpr* expr, + const std::vector>& expectedTypes = {}, + bool generalize = true + ); InferencePack checkPack(const ScopePtr& scope, AstExprCall* call); @@ -173,16 +270,25 @@ struct ConstraintGraphBuilder * @param expr the expression to check. * @param expectedType the type of the expression that is expected from its * surrounding context. Used to implement bidirectional type checking. + * @param generalize If true, generalize any lambdas that are encountered. * @return the type of the expression. */ - Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); + Inference check( + const ScopePtr& scope, + AstExpr* expr, + std::optional expectedType = {}, + bool forceSingleton = false, + bool generalize = true + ); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprGlobal* global); + Inference checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType, bool generalize); Inference check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); @@ -191,9 +297,11 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - std::vector checkLValues(const ScopePtr& scope, AstArray exprs); - - TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); + void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType); struct FunctionSignature { @@ -208,7 +316,12 @@ struct ConstraintGraphBuilder ScopePtr bodyScope; }; - FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}); + FunctionSignature checkFunctionSignature( + const ScopePtr& parent, + AstExprFunction* fn, + std::optional expectedType = {}, + std::optional originalName = {} + ); /** * Checks the body of a function expression. @@ -217,6 +330,11 @@ struct ConstraintGraphBuilder */ void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn); + // Specializations of 'resolveType' below + TypeId resolveReferenceType(const ScopePtr& scope, AstType* ty, AstTypeReference* ref, bool inTypeArguments, bool replaceErrorWithFresh); + TypeId resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh); + TypeId resolveFunctionType(const ScopePtr& scope, AstType* ty, AstTypeFunction* fn, bool inTypeArguments, bool replaceErrorWithFresh); + /** * Resolves a type from its AST annotation. * @param scope the scope that the type annotation appears within. @@ -255,7 +373,11 @@ struct ConstraintGraphBuilder * privateTypeBindings map. **/ std::vector> createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache = false, bool addTypes = true); + const ScopePtr& scope, + AstArray generics, + bool useCache = false, + bool addTypes = true + ); /** * Creates generic type packs given a list of AST definitions, resolving @@ -268,26 +390,51 @@ struct ConstraintGraphBuilder * privateTypePackBindings map. **/ std::vector> createGenericPacks( - const ScopePtr& scope, AstArray packs, bool useCache = false, bool addTypes = true); + const ScopePtr& scope, + AstArray packs, + bool useCache = false, + bool addTypes = true + ); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); + // make a union type function of these two types + TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + // make an intersect type function of these two types + TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + /** Scan the program for global definitions. * - * ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for + * ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an * initial scan of the AST and note what globals are defined. */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); + bool recordPropertyAssignment(TypeId ty); + + // Record the fact that a particular local has a particular type in at least + // one of its states. + void recordInferredBinding(AstLocal* local, TypeId ty); + + void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block); + /** Given a function type annotation, return a vector describing the expected types of the calls to the function * For example, calling a function with annotation ((number) -> string & ((string) -> number)) * yields a vector of size 1, with value: [number | string] */ std::vector> getExpectedCallTypesForFunctionOverloads(const TypeId fnType); + + TypeId createTypeFunctionInstance( + const TypeFunction& function, + std::vector typeArguments, + std::vector packArguments, + const ScopePtr& scope, + Location location + ); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 4fd7d0d10..c9336c1d0 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -3,21 +3,32 @@ #pragma once #include "Luau/Constraint.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/Location.h" #include "Luau/Module.h" #include "Luau/Normalize.h" +#include "Luau/Substitution.h" #include "Luau/ToString.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFwd.h" #include "Luau/Variant.h" +#include #include namespace Luau { +enum class ValueContext; + struct DcrLogger; +class AstExpr; + // TypeId, TypePackId, or Constraint*. It is impossible to know which, but we // never dereference this pointer. using BlockedConstraintId = Variant; @@ -49,16 +60,19 @@ struct HashInstantiationSignature struct ConstraintSolver { - TypeArena* arena; + NotNull arena; NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; - NotNull reducer; + NotNull typeFunctionRuntime; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; ModuleName currentModuleName; + // The dataflow graph of the program, used in constraint generation and for magic functions. + NotNull dfg; + // Constraints that the solver has generated, rather than sourcing from the // scope tree. std::vector> solverConstraints; @@ -72,9 +86,23 @@ struct ConstraintSolver // anything. std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map>, HashBlockedConstraintId> blocked; + std::unordered_map, HashBlockedConstraintId> blocked; // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; + // Breadcrumbs for where a free type's upper bound was expanded. We use + // these to provide more helpful error messages when a free type is solved + // as never unexpectedly. + DenseHashMap>> upperBoundContributors{nullptr}; + + // A mapping from free types to the number of unresolved constraints that mention them. + DenseHashMap unresolvedConstraints{{}}; + + // Irreducible/uninhabited type functions or type pack functions. + DenseHashSet uninhabitedTypeFunctions{{}}; + + // The set of types that will definitely be unchanged by generalization. + DenseHashSet generalizedTypes_{nullptr}; + const NotNull> generalizedTypes{&generalizedTypes_}; // Recorded errors that take place within the solver. ErrorVec errors; @@ -83,10 +111,22 @@ struct ConstraintSolver std::vector requireCycles; DcrLogger* logger; - - explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, - DcrLogger* logger); + TypeCheckLimits limits; + + DenseHashMap typeFunctionsToFinalize{nullptr}; + + explicit ConstraintSolver( + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull rootScope, + std::vector> constraints, + ModuleName moduleName, + NotNull moduleResolver, + std::vector requireCycles, + DcrLogger* logger, + NotNull dfg, + TypeCheckLimits limits + ); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -97,43 +137,109 @@ struct ConstraintSolver **/ void run(); + + /** + * Attempts to perform one final reduction on type functions after every constraint has been completed + * + **/ + void finalizeTypeFunctions(); + bool isDone(); - void finalizeModule(); +private: + /** + * Bind a type variable to another type. + * + * A constraint is required and will validate that blockedTy is owned by this + * constraint. This prevents one constraint from interfering with another's + * blocked types. + * + * Bind will also unblock the type variable for you. + */ + void bind(NotNull constraint, TypeId ty, TypeId boundTo); + void bind(NotNull constraint, TypePackId tp, TypePackId boundTo); + + template + void emplace(NotNull constraint, TypeId ty, Args&&... args); + + template + void emplace(NotNull constraint, TypePackId tp, Args&&... args); +public: /** Attempt to dispatch a constraint. Returns true if it was successful. If * tryDispatch() returns false, the constraint remains in the unsolved set * and will be retried later. */ bool tryDispatch(NotNull c, bool force); - bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint); bool tryDispatch(const IterableConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + bool tryDispatch(const FunctionCheckConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); + + + bool tryDispatchHasIndexer( + int& recursionDepth, + NotNull constraint, + TypeId subjectType, + TypeId indexType, + TypeId resultType, + Set& seen + ); + bool tryDispatch(const HasIndexerConstraint& c, NotNull constraint); + + bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); + bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); + bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const EqualityConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do - bool tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint); + + std::pair, std::optional> lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional = false, + bool suppressSimplification = false + ); + std::pair, std::optional> lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional, + bool suppressSimplification, + DenseHashSet& seen + ); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); + /** + * Generate constraints to unpack the types of srcTypes and assign each + * value to the corresponding BlockedType in destTypes. + * + * This function also overwrites the owners of each BlockedType. This is + * okay because this function is only used to decompose IterableConstraint + * into an UnpackConstraint. + * + * @param destTypes A vector of types comprised of BlockedTypes. + * @param srcTypes A TypePack that represents rvalues to be assigned. + * @returns The underlying UnpackConstraint. There's a bit of code in + * iteration that needs to pass blocks on to this constraint. + */ + NotNull unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint); void block(NotNull target, NotNull constraint); /** @@ -143,53 +249,55 @@ struct ConstraintSolver bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); - // Traverse the type. If any blocked or pending types are found, block - // the constraint on them. + // Block on every target + template + bool block(const T& targets, NotNull constraint) + { + for (TypeId target : targets) + block(target, constraint); + + return false; + } + + /** + * For all constraints that are blocked on one constraint, make them block + * on a new constraint. + * @param source the constraint to copy blocks from. + * @param addition the constraint that other constraints should now block on. + */ + void inheritBlocks(NotNull source, NotNull addition); + + // Traverse the type. If any pending types are found, block the constraint + // on them. // // Returns false if a type blocks the constraint. // // FIXME: This use of a boolean for the return result is an appalling // interface. - bool recursiveBlock(TypeId target, NotNull constraint); - bool recursiveBlock(TypePackId target, NotNull constraint); + bool blockOnPendingTypes(TypeId target, NotNull constraint); + bool blockOnPendingTypes(TypePackId target, NotNull constraint); void unblock(NotNull progressed); - void unblock(TypeId progressed); - void unblock(TypePackId progressed); - void unblock(const std::vector& types); - void unblock(const std::vector& packs); + void unblock(TypeId progressed, Location location); + void unblock(TypePackId progressed, Location location); + void unblock(const std::vector& types, Location location); + void unblock(const std::vector& packs, Location location); /** * @returns true if the TypeId is in a blocked state. */ - bool isBlocked(TypeId ty); + bool isBlocked(TypeId ty) const; /** * @returns true if the TypePackId is in a blocked state. */ - bool isBlocked(TypePackId tp); + bool isBlocked(TypePackId tp) const; /** * Returns whether the constraint is blocked on anything. * @param constraint the constraint to check. */ - bool isBlocked(NotNull constraint); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result. - * @param subType the sub-type to unify. - * @param superType the super-type to unify. - */ - void unify(TypeId subType, TypeId superType, NotNull scope); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result. - * @param subPack the sub-type pack to unify. - * @param superPack the super-type pack to unify. - */ - void unify(TypePackId subPack, TypePackId superPack, NotNull scope); + bool isBlocked(NotNull constraint) const; /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. @@ -210,7 +318,48 @@ struct ConstraintSolver void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeError e); -private: + /** + * Shifts the count of references from `source` to `target`. This should be paired + * with any instance of binding a free type in order to maintain accurate refcounts. + * If `target` is not a free type, this is a noop. + * @param source the free type which is being bound + * @param target the type which the free type is being bound to + */ + void shiftReferences(TypeId source, TypeId target); + + /** + * Generalizes the given free type if the reference counting allows it. + * @param the scope to generalize in + * @param type the free type we want to generalize + * @returns a non-free type that generalizes the argument, or `std::nullopt` if one + * does not exist + */ + std::optional generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables = false); + + /** + * Checks the existing set of constraints to see if there exist any that contain + * the provided free type, indicating that it is not yet ready to be replaced by + * one of its bounds. + * @param ty the free type that to check for related constraints + * @returns whether or not it is unsafe to replace the free type by one of its bounds + */ + bool hasUnresolvedConstraints(TypeId ty); + + /** Attempts to unify subTy with superTy. If doing so would require unifying + * BlockedTypes, fail and block the constraint on those BlockedTypes. + * + * Note: TID can only be TypeId or TypePackId. + * + * If unification fails, replace all free types with errorType. + * + * If unification succeeds, unblock every type changed by the unification. + * + * @returns true if the unification succeeded. False if the unification was + * too complex. + */ + template + bool unify(NotNull constraint, TID subTy, TID superTy); + /** * Marks a constraint as being blocked on a type or type pack. The constraint * solver will not attempt to dispatch blocked constraints until their @@ -218,7 +367,7 @@ struct ConstraintSolver * @param target the type or type pack pointer that the constraint is blocked on. * @param constraint the constraint to block. **/ - void block_(BlockedConstraintId target, NotNull constraint); + bool block_(BlockedConstraintId target, NotNull constraint); /** * Informs the solver that progress has been made on a type or type pack. The @@ -228,10 +377,20 @@ struct ConstraintSolver **/ void unblock_(BlockedConstraintId progressed); + /** + * Reproduces any constraints necessary for new types that are copied when applying a substitution. + * At the time of writing, this pertains only to type functions. + * @param subst the substitution that was applied + **/ + void reproduceConstraints(NotNull scope, const Location& location, const Substitution& subst); + TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; - TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); + + void throwTimeLimitError() const; + void throwUserCancelError() const; ToStringOptions opts; }; diff --git a/Analysis/include/Luau/ControlFlow.h b/Analysis/include/Luau/ControlFlow.h new file mode 100644 index 000000000..82c0403ce --- /dev/null +++ b/Analysis/include/Luau/ControlFlow.h @@ -0,0 +1,36 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +struct Scope; +using ScopePtr = std::shared_ptr; + +enum class ControlFlow +{ + None = 0b00001, + Returns = 0b00010, + Throws = 0b00100, + Breaks = 0b01000, + Continues = 0b10000, +}; + +inline ControlFlow operator&(ControlFlow a, ControlFlow b) +{ + return ControlFlow(int(a) & int(b)); +} + +inline ControlFlow operator|(ControlFlow a, ControlFlow b) +{ + return ControlFlow(int(a) | int(b)); +} + +inline bool matches(ControlFlow a, ControlFlow b) +{ + return (a & b) != ControlFlow(0); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index ce4ecb04c..662e50aa1 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -3,29 +3,49 @@ // Do not include LValue. It should never be used here. #include "Luau/Ast.h" -#include "Luau/Breadcrumb.h" +#include "Luau/ControlFlow.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" #include "Luau/Symbol.h" +#include "Luau/TypedAllocator.h" #include namespace Luau { +struct RefinementKey +{ + const RefinementKey* parent = nullptr; + DefId def; + std::optional propName; +}; + +struct RefinementKeyArena +{ + TypedAllocator allocator; + + const RefinementKey* leaf(DefId def); + const RefinementKey* node(const RefinementKey* parent, DefId def, const std::string& propName); +}; + struct DataFlowGraph { DataFlowGraph(DataFlowGraph&&) = default; DataFlowGraph& operator=(DataFlowGraph&&) = default; - NullableBreadcrumbId getBreadcrumb(const AstExpr* expr) const; + DefId getDef(const AstExpr* expr) const; + // Look up the definition optionally, knowing it may not be present. + std::optional getDefOptional(const AstExpr* expr) const; + // Look up for the rvalue def for a compound assignment. + std::optional getRValueDefForCompoundAssign(const AstExpr* expr) const; + + DefId getDef(const AstLocal* local) const; - BreadcrumbId getBreadcrumb(const AstLocal* local) const; - BreadcrumbId getBreadcrumb(const AstExprLocal* local) const; - BreadcrumbId getBreadcrumb(const AstExprGlobal* global) const; + DefId getDef(const AstStatDeclareGlobal* global) const; + DefId getDef(const AstStatDeclareFunction* func) const; - BreadcrumbId getBreadcrumb(const AstStatDeclareGlobal* global) const; - BreadcrumbId getBreadcrumb(const AstStatDeclareFunction* func) const; + const RefinementKey* getRefinementKey(const AstExpr* expr) const; private: DataFlowGraph() = default; @@ -33,37 +53,98 @@ struct DataFlowGraph DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete; - DefArena defs; - BreadcrumbArena breadcrumbs; + DefArena defArena; + RefinementKeyArena keyArena; - DenseHashMap astBreadcrumbs{nullptr}; + DenseHashMap astDefs{nullptr}; // Sometimes we don't have the AstExprLocal* but we have AstLocal*, and sometimes we need to extract that DefId. - DenseHashMap localBreadcrumbs{nullptr}; + DenseHashMap localDefs{nullptr}; // There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place. // All keys in this maps are really only statements that ambiently declares a symbol. - DenseHashMap declaredBreadcrumbs{nullptr}; + DenseHashMap declaredDefs{nullptr}; + // Compound assignments are in a weird situation where the local being assigned to is also being used at its + // previous type implicitly in an rvalue position. This map provides the previous binding. + DenseHashMap compoundAssignDefs{nullptr}; + + DenseHashMap astRefinementKeys{nullptr}; friend struct DataFlowGraphBuilder; }; struct DfgScope { + enum ScopeType + { + Linear, + Loop, + Function, + }; + DfgScope* parent; - DenseHashMap bindings{Symbol{}}; - DenseHashMap> props{nullptr}; + ScopeType scopeType; + Location location; - NullableBreadcrumbId lookup(Symbol symbol) const; - NullableBreadcrumbId lookup(DefId def, const std::string& key) const; + using Bindings = DenseHashMap; + using Props = DenseHashMap>; + + Bindings bindings{Symbol{}}; + Props props{nullptr}; + + std::optional lookup(Symbol symbol) const; + std::optional lookup(DefId def, const std::string& key) const; + + void inherit(const DfgScope* childScope); + + bool canUpdateDefinition(Symbol symbol) const; + bool canUpdateDefinition(DefId def, const std::string& key) const; }; -// Currently unsound. We do not presently track the control flow of the program. -// Additionally, we do not presently track assignments. +struct DataFlowResult +{ + DefId def; + const RefinementKey* parent = nullptr; +}; + +using ScopeStack = std::vector; + struct DataFlowGraphBuilder { static DataFlowGraph build(AstStatBlock* root, NotNull handle); + /** + * This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph + * here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as + * long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally + * typecheck small fragments of code. + * @param block - pointer to the ast to build the dfg for + * @param handle - for raising internal errors while building the dfg + */ + static std::pair, std::vector>> buildShared( + AstStatBlock* block, + NotNull handle + ); + + /** + * Takes a stale graph along with a list of scopes, a small fragment of the ast, and a cursor position + * and constructs the DataFlowGraph for just that fragment. This method will fabricate defs in the final + * DFG for things that have been referenced and exist in the stale dfg. + * For example, the fragment local z = x + y will populate defs for x and y from the stale graph. + * @param staleGraph - the old DFG + * @param scopes - the old DfgScopes in the graph + * @param fragment - the Ast Fragment to re-build the root for + * @param cursorPos - the current location of the cursor - used to determine which scope we are currently in + * @param handle - for internal compiler errors + */ + static DataFlowGraph updateGraph( + const DataFlowGraph& staleGraph, + const std::vector>& scopes, + AstStatBlock* fragment, + const Position& cursorPos, + NotNull handle + ); + private: DataFlowGraphBuilder() = default; @@ -71,80 +152,104 @@ struct DataFlowGraphBuilder DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraph graph; - NotNull defs{&graph.defs}; - NotNull breadcrumbs{&graph.breadcrumbs}; + NotNull defArena{&graph.defArena}; + NotNull keyArena{&graph.keyArena}; struct InternalErrorReporter* handle = nullptr; - DfgScope* moduleScope = nullptr; + /// The arena owning all of the scope allocations for the dataflow graph being built. std::vector> scopes; - DfgScope* childScope(DfgScope* scope); - - void visit(DfgScope* scope, AstStatBlock* b); - void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); - - void visit(DfgScope* scope, AstStat* s); - void visit(DfgScope* scope, AstStatIf* i); - void visit(DfgScope* scope, AstStatWhile* w); - void visit(DfgScope* scope, AstStatRepeat* r); - void visit(DfgScope* scope, AstStatBreak* b); - void visit(DfgScope* scope, AstStatContinue* c); - void visit(DfgScope* scope, AstStatReturn* r); - void visit(DfgScope* scope, AstStatExpr* e); - void visit(DfgScope* scope, AstStatLocal* l); - void visit(DfgScope* scope, AstStatFor* f); - void visit(DfgScope* scope, AstStatForIn* f); - void visit(DfgScope* scope, AstStatAssign* a); - void visit(DfgScope* scope, AstStatCompoundAssign* c); - void visit(DfgScope* scope, AstStatFunction* f); - void visit(DfgScope* scope, AstStatLocalFunction* l); - void visit(DfgScope* scope, AstStatTypeAlias* t); - void visit(DfgScope* scope, AstStatDeclareGlobal* d); - void visit(DfgScope* scope, AstStatDeclareFunction* d); - void visit(DfgScope* scope, AstStatDeclareClass* d); - void visit(DfgScope* scope, AstStatError* error); - - BreadcrumbId visitExpr(DfgScope* scope, AstExpr* e); - BreadcrumbId visitExpr(DfgScope* scope, AstExprLocal* l); - BreadcrumbId visitExpr(DfgScope* scope, AstExprGlobal* g); - BreadcrumbId visitExpr(DfgScope* scope, AstExprCall* c); - BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexName* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexExpr* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprFunction* f); - BreadcrumbId visitExpr(DfgScope* scope, AstExprTable* t); - BreadcrumbId visitExpr(DfgScope* scope, AstExprUnary* u); - BreadcrumbId visitExpr(DfgScope* scope, AstExprBinary* b); - BreadcrumbId visitExpr(DfgScope* scope, AstExprTypeAssertion* t); - BreadcrumbId visitExpr(DfgScope* scope, AstExprIfElse* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprInterpString* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprError* error); - - void visitLValue(DfgScope* scope, AstExpr* e); - void visitLValue(DfgScope* scope, AstExprLocal* l); - void visitLValue(DfgScope* scope, AstExprGlobal* g); - void visitLValue(DfgScope* scope, AstExprIndexName* i); - void visitLValue(DfgScope* scope, AstExprIndexExpr* i); - void visitLValue(DfgScope* scope, AstExprError* e); - - void visitType(DfgScope* scope, AstType* t); - void visitType(DfgScope* scope, AstTypeReference* r); - void visitType(DfgScope* scope, AstTypeTable* t); - void visitType(DfgScope* scope, AstTypeFunction* f); - void visitType(DfgScope* scope, AstTypeTypeof* t); - void visitType(DfgScope* scope, AstTypeUnion* u); - void visitType(DfgScope* scope, AstTypeIntersection* i); - void visitType(DfgScope* scope, AstTypeError* error); - - void visitTypePack(DfgScope* scope, AstTypePack* p); - void visitTypePack(DfgScope* scope, AstTypePackExplicit* e); - void visitTypePack(DfgScope* scope, AstTypePackVariadic* v); - void visitTypePack(DfgScope* scope, AstTypePackGeneric* g); - - void visitTypeList(DfgScope* scope, AstTypeList l); - - void visitGenerics(DfgScope* scope, AstArray g); - void visitGenericPacks(DfgScope* scope, AstArray g); + /// A stack of scopes used by the visitor to see where we are. + ScopeStack scopeStack; + + DfgScope* currentScope(); + + struct FunctionCapture + { + std::vector captureDefs; + std::vector allVersions; + size_t versionOffset = 0; + }; + + DenseHashMap captures{Symbol{}}; + void resolveCaptures(); + + DfgScope* makeChildScope(Location loc, DfgScope::ScopeType scopeType = DfgScope::Linear); + + void join(DfgScope* p, DfgScope* a, DfgScope* b); + void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b); + void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b); + + DefId lookup(Symbol symbol); + DefId lookup(DefId def, const std::string& key); + + ControlFlow visit(AstStatBlock* b); + ControlFlow visitBlockWithoutChildScope(AstStatBlock* b); + + ControlFlow visit(AstStat* s); + ControlFlow visit(AstStatIf* i); + ControlFlow visit(AstStatWhile* w); + ControlFlow visit(AstStatRepeat* r); + ControlFlow visit(AstStatBreak* b); + ControlFlow visit(AstStatContinue* c); + ControlFlow visit(AstStatReturn* r); + ControlFlow visit(AstStatExpr* e); + ControlFlow visit(AstStatLocal* l); + ControlFlow visit(AstStatFor* f); + ControlFlow visit(AstStatForIn* f); + ControlFlow visit(AstStatAssign* a); + ControlFlow visit(AstStatCompoundAssign* c); + ControlFlow visit(AstStatFunction* f); + ControlFlow visit(AstStatLocalFunction* l); + ControlFlow visit(AstStatTypeAlias* t); + ControlFlow visit(AstStatTypeFunction* f); + ControlFlow visit(AstStatDeclareGlobal* d); + ControlFlow visit(AstStatDeclareFunction* d); + ControlFlow visit(AstStatDeclareClass* d); + ControlFlow visit(AstStatError* error); + + DataFlowResult visitExpr(AstExpr* e); + DataFlowResult visitExpr(AstExprGroup* group); + DataFlowResult visitExpr(AstExprLocal* l); + DataFlowResult visitExpr(AstExprGlobal* g); + DataFlowResult visitExpr(AstExprCall* c); + DataFlowResult visitExpr(AstExprIndexName* i); + DataFlowResult visitExpr(AstExprIndexExpr* i); + DataFlowResult visitExpr(AstExprFunction* f); + DataFlowResult visitExpr(AstExprTable* t); + DataFlowResult visitExpr(AstExprUnary* u); + DataFlowResult visitExpr(AstExprBinary* b); + DataFlowResult visitExpr(AstExprTypeAssertion* t); + DataFlowResult visitExpr(AstExprIfElse* i); + DataFlowResult visitExpr(AstExprInterpString* i); + DataFlowResult visitExpr(AstExprError* error); + + void visitLValue(AstExpr* e, DefId incomingDef); + DefId visitLValue(AstExprLocal* l, DefId incomingDef); + DefId visitLValue(AstExprGlobal* g, DefId incomingDef); + DefId visitLValue(AstExprIndexName* i, DefId incomingDef); + DefId visitLValue(AstExprIndexExpr* i, DefId incomingDef); + DefId visitLValue(AstExprError* e, DefId incomingDef); + + void visitType(AstType* t); + void visitType(AstTypeReference* r); + void visitType(AstTypeTable* t); + void visitType(AstTypeFunction* f); + void visitType(AstTypeTypeof* t); + void visitType(AstTypeUnion* u); + void visitType(AstTypeIntersection* i); + void visitType(AstTypeError* error); + + void visitTypePack(AstTypePack* p); + void visitTypePack(AstTypePackExplicit* e); + void visitTypePack(AstTypePackVariadic* v); + void visitTypePack(AstTypePackGeneric* g); + + void visitTypeList(AstTypeList l); + + void visitGenerics(AstArray g); + void visitGenericPacks(AstArray g); }; } // namespace Luau diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h index 1e170d5bb..d650d9e06 100644 --- a/Analysis/include/Luau/DcrLogger.h +++ b/Analysis/include/Luau/DcrLogger.h @@ -126,7 +126,11 @@ struct DcrLogger void captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); StepSnapshot prepareStepSnapshot( - const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints); + const Scope* rootScope, + NotNull current, + bool force, + const std::vector>& unsolvedConstraints + ); void commitStepSnapshot(StepSnapshot snapshot); void captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h index 10d81367e..9627f9988 100644 --- a/Analysis/include/Luau/Def.h +++ b/Analysis/include/Luau/Def.h @@ -23,6 +23,7 @@ using DefId = NotNull; */ struct Cell { + bool subscripted = false; }; /** @@ -71,13 +72,16 @@ const T* get(DefId def) return get_if(&def->v); } +bool containsSubscriptedDefinition(DefId def); +void collectOperands(DefId def, std::vector* operands); + struct DefArena { TypedAllocator allocator; - DefId freshCell(); - // TODO: implement once we have cases where we need to merge in definitions - // DefId phi(const std::vector& defs); + DefId freshCell(bool subscripted = false); + DefId phi(DefId a, DefId b); + DefId phi(const std::vector& defs); }; } // namespace Luau diff --git a/Analysis/include/Luau/Differ.h b/Analysis/include/Luau/Differ.h new file mode 100644 index 000000000..d9b78939b --- /dev/null +++ b/Analysis/include/Luau/Differ.h @@ -0,0 +1,208 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/TypeFwd.h" +#include "Luau/UnifierSharedState.h" + +#include +#include +#include + +namespace Luau +{ +struct DiffPathNode +{ + // TODO: consider using Variants to simplify toString implementation + enum Kind + { + TableProperty, + FunctionArgument, + FunctionReturn, + Union, + Intersection, + Negation, + }; + Kind kind; + // non-null when TableProperty + std::optional tableProperty; + // non-null when FunctionArgument (unless variadic arg), FunctionReturn (unless variadic arg), Union, or Intersection (i.e. anonymous fields) + std::optional index; + + /** + * Do not use for leaf nodes + */ + DiffPathNode(Kind kind) + : kind(kind) + { + } + + DiffPathNode(Kind kind, std::optional tableProperty, std::optional index) + : kind(kind) + , tableProperty(tableProperty) + , index(index) + { + } + + std::string toString() const; + + static DiffPathNode constructWithTableProperty(Name tableProperty); + + static DiffPathNode constructWithKindAndIndex(Kind kind, size_t index); + + static DiffPathNode constructWithKind(Kind kind); +}; + +struct DiffPathNodeLeaf +{ + std::optional ty; + std::optional tableProperty; + std::optional minLength; + bool isVariadic; + // TODO: Rename to anonymousIndex, for both union and Intersection + std::optional unionIndex; + DiffPathNodeLeaf( + std::optional ty, + std::optional tableProperty, + std::optional minLength, + bool isVariadic, + std::optional unionIndex + ) + : ty(ty) + , tableProperty(tableProperty) + , minLength(minLength) + , isVariadic(isVariadic) + , unionIndex(unionIndex) + { + } + + static DiffPathNodeLeaf detailsNormal(TypeId ty); + + static DiffPathNodeLeaf detailsTableProperty(TypeId ty, Name tableProperty); + + static DiffPathNodeLeaf detailsUnionIndex(TypeId ty, size_t index); + + static DiffPathNodeLeaf detailsLength(int minLength, bool isVariadic); + + static DiffPathNodeLeaf nullopts(); +}; + +struct DiffPath +{ + std::vector path; + + std::string toString(bool prependDot) const; +}; +struct DiffError +{ + enum Kind + { + Normal, + MissingTableProperty, + MissingUnionMember, + MissingIntersectionMember, + IncompatibleGeneric, + LengthMismatchInFnArgs, + LengthMismatchInFnRets, + }; + Kind kind; + + DiffPath diffPath; + DiffPathNodeLeaf left; + DiffPathNodeLeaf right; + + std::string leftRootName; + std::string rightRootName; + + DiffError(Kind kind, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName) + : kind(kind) + , left(left) + , right(right) + , leftRootName(leftRootName) + , rightRootName(rightRootName) + { + checkValidInitialization(left, right); + } + DiffError(Kind kind, DiffPath diffPath, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName) + : kind(kind) + , diffPath(diffPath) + , left(left) + , right(right) + , leftRootName(leftRootName) + , rightRootName(rightRootName) + { + checkValidInitialization(left, right); + } + + std::string toString(bool multiLine = false) const; + +private: + std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const; + void checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right); + void checkNonMissingPropertyLeavesHaveNulloptTableProperty() const; +}; + +struct DifferResult +{ + std::optional diffError; + + DifferResult() {} + DifferResult(DiffError diffError) + : diffError(diffError) + { + } + + void wrapDiffPath(DiffPathNode node); +}; +struct DifferEnvironment +{ + TypeId rootLeft; + TypeId rootRight; + std::optional externalSymbolLeft; + std::optional externalSymbolRight; + DenseHashMap genericMatchedPairs; + DenseHashMap genericTpMatchedPairs; + + DifferEnvironment( + TypeId rootLeft, + TypeId rootRight, + std::optional externalSymbolLeft, + std::optional externalSymbolRight + ) + : rootLeft(rootLeft) + , rootRight(rootRight) + , externalSymbolLeft(externalSymbolLeft) + , externalSymbolRight(externalSymbolRight) + , genericMatchedPairs(nullptr) + , genericTpMatchedPairs(nullptr) + { + } + + bool isProvenEqual(TypeId left, TypeId right) const; + bool isAssumedEqual(TypeId left, TypeId right) const; + void recordProvenEqual(TypeId left, TypeId right); + void pushVisiting(TypeId left, TypeId right); + void popVisiting(); + std::vector>::const_reverse_iterator visitingBegin() const; + std::vector>::const_reverse_iterator visitingEnd() const; + std::string getDevFixFriendlyNameLeft() const; + std::string getDevFixFriendlyNameRight() const; + +private: + // TODO: consider using DenseHashSet + std::unordered_set, TypeIdPairHash> provenEqual; + // Ancestors of current types + std::unordered_set, TypeIdPairHash> visiting; + std::vector> visitingStack; +}; +DifferResult diff(TypeId ty1, TypeId ty2); +DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2); + +/** + * True if ty is a "simple" type, i.e. cannot contain types. + * string, number, boolean are simple types. + * function and table are not simple types. + */ +bool isSimple(TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 69d4cca3c..fe9d79248 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -2,8 +2,12 @@ #pragma once #include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/Type.h" #include "Luau/Variant.h" +#include "Luau/Ast.h" + +#include namespace Luau { @@ -190,6 +194,11 @@ struct InternalError bool operator==(const InternalError& rhs) const; }; +struct ConstraintSolvingIncompleteError +{ + bool operator==(const ConstraintSolvingIncompleteError& rhs) const; +}; + struct CannotCallNonFunction { TypeId ty; @@ -318,6 +327,7 @@ struct TypePackMismatch { TypePackId wantedTp; TypePackId givenTp; + std::string reason; bool operator==(const TypePackMismatch& rhs) const; }; @@ -329,12 +339,172 @@ struct DynamicPropertyLookupOnClassesUnsafe bool operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const; }; -using TypeErrorData = Variant; +struct UninhabitedTypeFunction +{ + TypeId ty; + + bool operator==(const UninhabitedTypeFunction& rhs) const; +}; + +struct ExplicitFunctionAnnotationRecommended +{ + std::vector> recommendedArgs; + TypeId recommendedReturn; + bool operator==(const ExplicitFunctionAnnotationRecommended& rhs) const; +}; + +struct UninhabitedTypePackFunction +{ + TypePackId tp; + + bool operator==(const UninhabitedTypePackFunction& rhs) const; +}; + +struct WhereClauseNeeded +{ + TypeId ty; + + bool operator==(const WhereClauseNeeded& rhs) const; +}; + +struct PackWhereClauseNeeded +{ + TypePackId tp; + + bool operator==(const PackWhereClauseNeeded& rhs) const; +}; + +struct CheckedFunctionCallError +{ + TypeId expected; + TypeId passed; + std::string checkedFunctionName; + // TODO: make this a vector + size_t argumentIndex; + bool operator==(const CheckedFunctionCallError& rhs) const; +}; + +struct NonStrictFunctionDefinitionError +{ + std::string functionName; + std::string argument; + TypeId argumentType; + bool operator==(const NonStrictFunctionDefinitionError& rhs) const; +}; + +struct PropertyAccessViolation +{ + TypeId table; + Name key; + + enum + { + CannotRead, + CannotWrite + } context; + + bool operator==(const PropertyAccessViolation& rhs) const; +}; + +struct CheckedFunctionIncorrectArgs +{ + std::string functionName; + size_t expected; + size_t actual; + bool operator==(const CheckedFunctionIncorrectArgs& rhs) const; +}; + +struct CannotAssignToNever +{ + // type of the rvalue being assigned + TypeId rhsType; + + // Originating type. + std::vector cause; + + enum class Reason + { + // when assigning to a property in a union of tables, the properties type + // is narrowed to the intersection of its type in each variant. + PropertyNarrowed, + }; + + Reason reason; + + bool operator==(const CannotAssignToNever& rhs) const; +}; + +struct UnexpectedTypeInSubtyping +{ + TypeId ty; + + bool operator==(const UnexpectedTypeInSubtyping& rhs) const; +}; + +struct UnexpectedTypePackInSubtyping +{ + TypePackId tp; + + bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; +}; + +struct UserDefinedTypeFunctionError +{ + std::string message; + + bool operator==(const UserDefinedTypeFunctionError& rhs) const; +}; + +using TypeErrorData = Variant< + TypeMismatch, + UnknownSymbol, + UnknownProperty, + NotATable, + CannotExtendTable, + OnlyTablesCanHaveMethods, + DuplicateTypeDefinition, + CountMismatch, + FunctionDoesNotTakeSelf, + FunctionRequiresSelf, + OccursCheckFailed, + UnknownRequire, + IncorrectGenericParameterCount, + SyntaxError, + CodeTooComplex, + UnificationTooComplex, + UnknownPropButFoundLikeProp, + GenericError, + InternalError, + ConstraintSolvingIncompleteError, + CannotCallNonFunction, + ExtraInformation, + DeprecatedApiUsed, + ModuleHasCyclicDependency, + IllegalRequire, + FunctionExitsWithoutReturning, + DuplicateGenericParameter, + CannotAssignToNever, + CannotInferBinaryOperation, + MissingProperties, + SwappedGenericTypeParameter, + OptionalValueAccess, + MissingUnionProperty, + TypesAreUnrelated, + NormalizationTooComplex, + TypePackMismatch, + DynamicPropertyLookupOnClassesUnsafe, + UninhabitedTypeFunction, + UninhabitedTypePackFunction, + WhereClauseNeeded, + PackWhereClauseNeeded, + CheckedFunctionCallError, + NonStrictFunctionDefinitionError, + PropertyAccessViolation, + CheckedFunctionIncorrectArgs, + UnexpectedTypeInSubtyping, + UnexpectedTypePackInSubtyping, + ExplicitFunctionAnnotationRecommended, + UserDefinedTypeFunctionError>; struct TypeErrorSummary { @@ -403,7 +573,7 @@ std::string toString(const TypeError& error, TypeErrorToStringOptions options); bool containsParseErrorName(const TypeError& error); // Copy any types named in the error into destArena. -void copyErrors(ErrorVec& errors, struct TypeArena& destArena); +void copyErrors(ErrorVec& errors, struct TypeArena& destArena, NotNull builtinTypes); // Internal Compiler Error struct InternalErrorReporter @@ -411,8 +581,8 @@ struct InternalErrorReporter std::function onInternalError; std::string moduleName; - [[noreturn]] void ice(const std::string& message, const Location& location); - [[noreturn]] void ice(const std::string& message); + [[noreturn]] void ice(const std::string& message, const Location& location) const; + [[noreturn]] void ice(const std::string& message) const; }; class InternalCompilerError : public std::exception diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 0fdcce161..2f17e5660 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -3,6 +3,7 @@ #include #include +#include namespace Luau { @@ -31,6 +32,9 @@ struct ModuleInfo bool optional = false; }; +using RequireSuggestion = std::string; +using RequireSuggestions = std::vector; + struct FileResolver { virtual ~FileResolver() {} @@ -51,6 +55,11 @@ struct FileResolver { return std::nullopt; } + + virtual std::optional getRequireSuggestions(const ModuleName& requirer, const std::optional& pathString) const + { + return std::nullopt; + } }; struct NullFileResolver : FileResolver diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h new file mode 100644 index 000000000..671cbb693 --- /dev/null +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -0,0 +1,61 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Parser.h" +#include "Luau/Autocomplete.h" +#include "Luau/DenseHash.h" +#include "Luau/Module.h" + +#include +#include + +namespace Luau +{ +struct FrontendOptions; + +struct FragmentAutocompleteAncestryResult +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; + std::vector ancestry; + AstStat* nearestStatement = nullptr; +}; + +struct FragmentParseResult +{ + std::string fragmentToParse; + AstStatBlock* root = nullptr; + std::vector ancestry; + std::unique_ptr alloc = std::make_unique(); +}; + +struct FragmentTypeCheckResult +{ + ModulePtr incrementalModule = nullptr; + Scope* freshScope = nullptr; +}; + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); + +FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos); + +FragmentTypeCheckResult typecheckFragment( + Frontend& frontend, + const ModuleName& moduleName, + const Position& cursorPos, + std::optional opts, + std::string_view src +); + +AutocompleteResult fragmentAutocomplete( + Frontend& frontend, + std::string_view src, + const ModuleName& moduleName, + Position& cursorPosition, + std::optional opts, + StringCompletionCallback callback +); + + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 7c5dc4a0d..49d7a36dd 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -2,13 +2,16 @@ #pragma once #include "Luau/Config.h" +#include "Luau/GlobalTypes.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" -#include "Luau/TypeInfer.h" +#include "Luau/TypeCheckLimits.h" #include "Luau/Variant.h" +#include "Luau/AnyTypeSummary.h" +#include #include #include #include @@ -21,39 +24,26 @@ class ParseError; struct Frontend; struct TypeError; struct LintWarning; +struct GlobalTypes; struct TypeChecker; struct FileResolver; struct ModuleResolver; struct ParseResult; struct HotComment; +struct BuildQueueItem; +struct FrontendCancellationToken; +struct AnyTypeSummary; struct LoadDefinitionFileResult { bool success; ParseResult parseResult; + SourceModule sourceModule; ModulePtr module; }; -LoadDefinitionFileResult loadDefinitionFile( - TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); - std::optional parseMode(const std::vector& hotcomments); -std::vector parsePathExpr(const AstExpr& pathExpr); - -// Exported only for convenient testing. -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& expr); - -/** Try to convert an AST fragment into a ModuleName. - * Returns std::nullopt if the expression cannot be resolved. This will most likely happen in cases where - * the import path involves some dynamic computation that we cannot see into at typechecking time. - * - * Unintuitively, weirdly-formulated modules (like game.Parent.Parent.Parent.Foo) will successfully produce a ModuleName - * as long as it falls within the permitted syntax. This is ok because we will fail to find the module and produce an - * error when we try during typechecking. - */ -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr); - struct SourceNode { bool hasDirtySourceModule() const @@ -67,7 +57,8 @@ struct SourceNode } ModuleName name; - std::unordered_set requireSet; + std::string humanReadableName; + DenseHashSet requireSet{{}}; std::vector> requireLocations; bool dirtySourceModule = true; bool dirtyModule = true; @@ -87,14 +78,33 @@ struct FrontendOptions // order to get more precise type information) bool forAutocomplete = false; + bool runLintChecks = false; + // If not empty, randomly shuffle the constraint set before attempting to // solve. Use this value to seed the random number generator. std::optional randomizeConstraintResolutionSeed; + + std::optional enabledLintWarnings; + + std::shared_ptr cancellationToken; + + // Time limit for typechecking a single module + std::optional moduleTimeLimitSec; + + // When true, some internal complexity limits will be scaled down for modules that miss the limit set by moduleTimeLimitSec + bool applyInternalLimitScaling = false; + + // An optional callback which is called for every *dirty* module was checked + // Is multi-threaded typechecking is used, this callback might be called from multiple threads and has to be thread-safe + std::function customModuleCheck; }; struct CheckResult { std::vector errors; + + LintResult lintResult; + std::vector timeoutHits; }; @@ -107,7 +117,13 @@ struct FrontendModuleResolver : ModuleResolver std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; + void setModule(const ModuleName& moduleName, ModulePtr module); + void clearModules(); + +private: Frontend* frontend; + + mutable std::mutex moduleMutex; std::unordered_map modules; }; @@ -129,10 +145,11 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); - CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + // Parse module graph and prepare SourceNode/SourceModule data, including required dependencies without running typechecking + void parse(const ModuleName& name); - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); + // Parse and typecheck module graph + CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); @@ -152,29 +169,70 @@ struct Frontend void clear(); ScopePtr addEnvironment(const std::string& environmentName); - ScopePtr getEnvironmentScope(const std::string& environmentName); + ScopePtr getEnvironmentScope(const std::string& environmentName) const; - void registerBuiltinDefinition(const std::string& name, std::function); + void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); - LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName); - - ScopePtr getGlobalScope(); + LoadDefinitionFileResult loadDefinitionFile( + GlobalTypes& globals, + ScopePtr targetScope, + std::string_view source, + const std::string& packageName, + bool captureComments, + bool typeCheckForAutocomplete = false + ); + + // Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult' + // If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete + void queueModuleCheck(const std::vector& names); + void queueModuleCheck(const ModuleName& name); + std::vector checkQueuedModules( + std::optional optionOverride = {}, + std::function task)> executeTask = {}, + std::function progress = {} + ); + + std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); private: - ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); + ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + std::vector requireCycles, + std::optional environmentScope, + bool forAutocomplete, + bool recordJsonLog, + TypeCheckLimits typeCheckLimits + ); std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete); + bool parseGraph( + std::vector& buildQueue, + const ModuleName& root, + bool forAutocomplete, + std::function canSkip = {} + ); + + void addBuildQueueItems( + std::vector& items, + std::vector& buildQueue, + bool cycleDetected, + DenseHashSet& seen, + const FrontendOptions& frontendOptions + ); + void checkBuildQueueItem(BuildQueueItem& item); + void checkBuildQueueItems(std::vector& items); + void recordItemResult(const BuildQueueItem& item); static LintResult classifyLints(const std::vector& warnings, const Config& config); - ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete); + ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const; std::unordered_map environments; - std::unordered_map> builtinDefinitions; + std::unordered_map> builtinDefinitions; BuiltinTypes builtinTypes_; @@ -182,31 +240,56 @@ struct Frontend const NotNull builtinTypes; FileResolver* fileResolver; + FrontendModuleResolver moduleResolver; FrontendModuleResolver moduleResolverForAutocomplete; - TypeChecker typeChecker; - TypeChecker typeCheckerForAutocomplete; + + GlobalTypes globals; + GlobalTypes globalsForAutocomplete; + ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; - TypeArena globalTypes; + std::function prepareModuleScope; + std::function writeJsonLog = {}; - std::unordered_map sourceNodes; - std::unordered_map sourceModules; + std::unordered_map> sourceNodes; + std::unordered_map> sourceModules; std::unordered_map requireTrace; Stats stats = {}; -private: - ScopePtr globalScope; + std::vector moduleQueue; }; -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options); - -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog); +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& globalScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits +); + +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& globalScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits, + bool recordJsonLog, + std::function writeJsonLog +); } // namespace Luau diff --git a/Analysis/include/Luau/Generalization.h b/Analysis/include/Luau/Generalization.h new file mode 100644 index 000000000..18d5b6782 --- /dev/null +++ b/Analysis/include/Luau/Generalization.h @@ -0,0 +1,19 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Scope.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +std::optional generalize( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull> bakedTypes, + TypeId ty, + /* avoid sealing tables*/ bool avoidSealingTables = false +); +} diff --git a/Analysis/include/Luau/GlobalTypes.h b/Analysis/include/Luau/GlobalTypes.h new file mode 100644 index 000000000..55a6d6c73 --- /dev/null +++ b/Analysis/include/Luau/GlobalTypes.h @@ -0,0 +1,25 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +struct GlobalTypes +{ + explicit GlobalTypes(NotNull builtinTypes); + + NotNull builtinTypes; // Global types are based on builtin types + + TypeArena globalTypes; + SourceModule globalNames; // names for symbols entered into globalScope + ScopePtr globalScope; // shared by all modules +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/InsertionOrderedMap.h b/Analysis/include/Luau/InsertionOrderedMap.h new file mode 100644 index 000000000..2937dcda2 --- /dev/null +++ b/Analysis/include/Luau/InsertionOrderedMap.h @@ -0,0 +1,134 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include +#include +#include +#include + +namespace Luau +{ + +template +struct InsertionOrderedMap +{ + static_assert(std::is_trivially_copyable_v, "key must be trivially copyable"); + +private: + using vec = std::vector>; + +public: + using iterator = typename vec::iterator; + using const_iterator = typename vec::const_iterator; + + void insert(K k, V v) + { + if (indices.count(k) != 0) + return; + + pairs.push_back(std::make_pair(k, std::move(v))); + indices[k] = pairs.size() - 1; + } + + void clear() + { + pairs.clear(); + indices.clear(); + } + + size_t size() const + { + LUAU_ASSERT(pairs.size() == indices.size()); + return pairs.size(); + } + + bool contains(const K& k) const + { + return indices.count(k) > 0; + } + + const V* get(const K& k) const + { + auto it = indices.find(k); + if (it == indices.end()) + return nullptr; + else + return &pairs.at(it->second).second; + } + + V* get(const K& k) + { + auto it = indices.find(k); + if (it == indices.end()) + return nullptr; + else + return &pairs.at(it->second).second; + } + + const_iterator begin() const + { + return pairs.begin(); + } + + const_iterator end() const + { + return pairs.end(); + } + + iterator begin() + { + return pairs.begin(); + } + + iterator end() + { + return pairs.end(); + } + + const_iterator find(K k) const + { + auto indicesIt = indices.find(k); + if (indicesIt == indices.end()) + return end(); + else + return begin() + indicesIt->second; + } + + iterator find(K k) + { + auto indicesIt = indices.find(k); + if (indicesIt == indices.end()) + return end(); + else + return begin() + indicesIt->second; + } + + void erase(iterator it) + { + if (it == pairs.end()) + return; + + K k = it->first; + auto indexIt = indices.find(k); + if (indexIt == indices.end()) + return; + + size_t removed = indexIt->second; + indices.erase(indexIt); + pairs.erase(it); + + for (auto& [_, index] : indices) + { + if (index > removed) + --index; + } + } + +private: + vec pairs; + std::unordered_map indices; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index c916f953b..73345f98c 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -1,22 +1,33 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/NotNull.h" #include "Luau/Substitution.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include "Luau/Unifiable.h" +#include "Luau/VisitType.h" namespace Luau { -struct TypeArena; struct TxnLog; +struct TypeArena; +struct TypeCheckLimits; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { - ReplaceGenerics(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope, const std::vector& generics, - const std::vector& genericPacks) + ReplaceGenerics( + const TxnLog* log, + TypeArena* arena, + NotNull builtinTypes, + TypeLevel level, + Scope* scope, + const std::vector& generics, + const std::vector& genericPacks + ) : Substitution(log, arena) + , builtinTypes(builtinTypes) , level(level) , scope(scope) , generics(generics) @@ -24,10 +35,23 @@ struct ReplaceGenerics : Substitution { } + void resetState( + const TxnLog* log, + TypeArena* arena, + NotNull builtinTypes, + TypeLevel level, + Scope* scope, + const std::vector& generics, + const std::vector& genericPacks + ); + + NotNull builtinTypes; + TypeLevel level; Scope* scope; std::vector generics; std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -36,17 +60,26 @@ struct ReplaceGenerics : Substitution }; // A substitution which replaces generic functions by monomorphic functions -struct Instantiation : Substitution +struct Instantiation final : Substitution { - Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope) + Instantiation(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) : Substitution(log, arena) + , builtinTypes(builtinTypes) , level(level) , scope(scope) + , reusableReplaceGenerics(log, arena, builtinTypes, level, scope, {}, {}) { } + void resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope); + + NotNull builtinTypes; + TypeLevel level; Scope* scope; + + ReplaceGenerics reusableReplaceGenerics; + bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -54,4 +87,79 @@ struct Instantiation : Substitution TypePackId clean(TypePackId tp) override; }; +// Used to find if a FunctionType requires generic type cleanup during instantiation +struct GenericTypeFinder : TypeOnceVisitor +{ + bool found = false; + + bool visit(TypeId ty) override + { + return !found; + } + + bool visit(TypePackId ty) override + { + return !found; + } + + bool visit(TypeId ty, const Luau::FunctionType& ftv) override + { + if (ftv.hasNoFreeOrGenericTypes) + return false; + + if (!ftv.generics.empty() || !ftv.genericPacks.empty()) + found = true; + + return !found; + } + + bool visit(TypeId ty, const Luau::TableType& ttv) override + { + if (ttv.state == Luau::TableState::Generic) + found = true; + + return !found; + } + + bool visit(TypeId ty, const Luau::GenericType&) override + { + found = true; + return false; + } + + bool visit(TypePackId ty, const Luau::GenericTypePack&) override + { + found = true; + return false; + } + + bool visit(TypeId ty, const Luau::ClassType&) override + { + // During function instantiation, classes are not traversed even if they have generics + return false; + } +}; + +/** Attempt to instantiate a type. Only used under local type inference. + * + * When given a generic function type, instantiate() will return a copy with the + * generics replaced by fresh types. Instantiation will return the same TypeId + * back if the function does not have any generics. + * + * All higher order generics are left as-is. For example, instantiation of + * ((Y) -> (X, Y)) -> (X, Y) is ((Y) -> ('x, Y)) -> ('x, Y) + * + * We substitute the generic X for the free 'x, but leave the generic Y alone. + * + * Instantiation fails only when processing the type causes internal recursion + * limits to be exceeded. + */ +std::optional instantiate( + NotNull builtinTypes, + NotNull arena, + NotNull limits, + NotNull scope, + TypeId ty +); + } // namespace Luau diff --git a/Analysis/include/Luau/Instantiation2.h b/Analysis/include/Luau/Instantiation2.h new file mode 100644 index 000000000..c9215fada --- /dev/null +++ b/Analysis/include/Luau/Instantiation2.h @@ -0,0 +1,90 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/Substitution.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeFwd.h" +#include "Luau/Unifiable.h" + +namespace Luau +{ + +struct TypeArena; +struct TypeCheckLimits; + +struct Replacer : Substitution +{ + DenseHashMap replacements; + DenseHashMap replacementPacks; + + Replacer(NotNull arena, DenseHashMap replacements, DenseHashMap replacementPacks) + : Substitution(TxnLog::empty(), arena) + , replacements(std::move(replacements)) + , replacementPacks(std::move(replacementPacks)) + { + } + + bool isDirty(TypeId ty) override + { + return replacements.find(ty) != nullptr; + } + + bool isDirty(TypePackId tp) override + { + return replacementPacks.find(tp) != nullptr; + } + + TypeId clean(TypeId ty) override + { + TypeId res = replacements[ty]; + LUAU_ASSERT(res); + dontTraverseInto(res); + return res; + } + + TypePackId clean(TypePackId tp) override + { + TypePackId res = replacementPacks[tp]; + LUAU_ASSERT(res); + dontTraverseInto(res); + return res; + } +}; + +// A substitution which replaces generic functions by monomorphic functions +struct Instantiation2 : Substitution +{ + // Mapping from generic types to free types to be used in instantiation. + DenseHashMap genericSubstitutions{nullptr}; + // Mapping from generic type packs to `TypePack`s of free types to be used in instantiation. + DenseHashMap genericPackSubstitutions{nullptr}; + + Instantiation2(TypeArena* arena, DenseHashMap genericSubstitutions, DenseHashMap genericPackSubstitutions) + : Substitution(TxnLog::empty(), arena) + , genericSubstitutions(std::move(genericSubstitutions)) + , genericPackSubstitutions(std::move(genericPackSubstitutions)) + { + } + + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypeId ty +); +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypePackId tp +); + +} // namespace Luau diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index 42b362bee..a16455dfa 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -5,6 +5,7 @@ #include "Luau/Location.h" #include "Luau/Type.h" #include "Luau/Ast.h" +#include "Luau/TypePath.h" #include @@ -48,4 +49,14 @@ std::ostream& operator<<(std::ostream& lhs, const TypePackVar& tv); std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted); +std::ostream& operator<<(std::ostream& lhs, TypeId ty); +std::ostream& operator<<(std::ostream& lhs, TypePackId tp); + +namespace TypePath +{ + +std::ostream& operator<<(std::ostream& lhs, const Path& path); + +}; // namespace TypePath + } // namespace Luau diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 9a8b863b3..e20d9901e 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -3,6 +3,7 @@ #include "Luau/Variant.h" #include "Luau/Symbol.h" +#include "Luau/TypeFwd.h" #include #include @@ -10,9 +11,6 @@ namespace Luau { -struct Type; -using TypeId = const Type*; - struct Field; // Deprecated. Do not use in new work. diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index 6bbc3d660..f911a6524 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/LinterConfig.h" #include "Luau/Location.h" #include @@ -15,88 +16,23 @@ class AstStat; class AstNameTable; struct TypeChecker; struct Module; -struct HotComment; using ScopePtr = std::shared_ptr; -struct LintWarning -{ - // Make sure any new lint codes are documented here: https://luau-lang.org/lint - // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints - enum Code - { - Code_Unknown = 0, - - Code_UnknownGlobal = 1, // superseded by type checker - Code_DeprecatedGlobal = 2, - Code_GlobalUsedAsLocal = 3, - Code_LocalShadow = 4, // disabled in Studio - Code_SameLineStatement = 5, // disabled in Studio - Code_MultiLineStatement = 6, - Code_LocalUnused = 7, // disabled in Studio - Code_FunctionUnused = 8, // disabled in Studio - Code_ImportUnused = 9, // disabled in Studio - Code_BuiltinGlobalWrite = 10, - Code_PlaceholderRead = 11, - Code_UnreachableCode = 12, - Code_UnknownType = 13, - Code_ForRange = 14, - Code_UnbalancedAssignment = 15, - Code_ImplicitReturn = 16, // disabled in Studio, superseded by type checker in strict mode - Code_DuplicateLocal = 17, - Code_FormatString = 18, - Code_TableLiteral = 19, - Code_UninitializedLocal = 20, - Code_DuplicateFunction = 21, - Code_DeprecatedApi = 22, - Code_TableOperations = 23, - Code_DuplicateCondition = 24, - Code_MisleadingAndOr = 25, - Code_CommentDirective = 26, - Code_IntegerParsing = 27, - Code_ComparisonPrecedence = 28, - - Code__Count - }; - - Code code; - Location location; - std::string text; - - static const char* getName(Code code); - static Code parseName(const char* name); - static uint64_t parseMask(const std::vector& hotcomments); -}; - struct LintResult { std::vector errors; std::vector warnings; }; -struct LintOptions -{ - uint64_t warningMask = 0; - - void enableWarning(LintWarning::Code code) - { - warningMask |= 1ull << code; - } - void disableWarning(LintWarning::Code code) - { - warningMask &= ~(1ull << code); - } - - bool isEnabled(LintWarning::Code code) const - { - return 0 != (warningMask & (1ull << code)); - } - - void setDefaults(); -}; - -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, - const std::vector& hotcomments, const LintOptions& options); +std::vector lint( + AstStat* root, + const AstNameTable& names, + const ScopePtr& env, + const Module* module, + const std::vector& hotcomments, + const LintOptions& options +); std::vector getDeprecatedGlobals(const AstNameTable& names); diff --git a/Analysis/include/Luau/Metamethods.h b/Analysis/include/Luau/Metamethods.h index 84b0092fb..747b7201c 100644 --- a/Analysis/include/Luau/Metamethods.h +++ b/Analysis/include/Luau/Metamethods.h @@ -19,6 +19,7 @@ static const std::unordered_map kBinaryOpMetamet {AstExprBinary::Op::Sub, "__sub"}, {AstExprBinary::Op::Mul, "__mul"}, {AstExprBinary::Op::Div, "__div"}, + {AstExprBinary::Op::FloorDiv, "__idiv"}, {AstExprBinary::Op::Pow, "__pow"}, {AstExprBinary::Op::Mod, "__mod"}, {AstExprBinary::Op::Concat, "__concat"}, diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 2faa0297f..82c189aa1 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -2,11 +2,14 @@ #pragma once #include "Luau/Error.h" +#include "Luau/Linter.h" #include "Luau/FileResolver.h" #include "Luau/ParseOptions.h" #include "Luau/ParseResult.h" #include "Luau/Scope.h" #include "Luau/TypeArena.h" +#include "Luau/AnyTypeSummary.h" +#include "Luau/DataFlowGraph.h" #include #include @@ -17,6 +20,7 @@ namespace Luau { struct Module; +struct AnyTypeSummary; using ScopePtr = std::shared_ptr; using ModulePtr = std::shared_ptr; @@ -27,7 +31,9 @@ class AstTypePack; /// Root of the AST of a parsed source file struct SourceModule { - ModuleName name; // DataModel path if possible. Filename if not. + ModuleName name; // Module identifier or a filename + std::string humanReadableName; + SourceCode::Type type = SourceCode::None; std::optional environmentName; bool cyclic = false; @@ -50,6 +56,7 @@ struct SourceModule }; bool isWithinComment(const SourceModule& sourceModule, Position pos); +bool isWithinComment(const ParseResult& result, Position pos); struct RequireCycle { @@ -61,9 +68,16 @@ struct Module { ~Module(); + ModuleName name; + std::string humanReadableName; + TypeArena interfaceTypes; TypeArena internalTypes; + // Summary of Ast Nodes that either contain + // user annotated anys or typechecker inferred anys + AnyTypeSummary ats{}; + // Scopes and AST types refer to parse data, so we need to keep that alive std::shared_ptr allocator; std::shared_ptr names; @@ -74,32 +88,60 @@ struct Module DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; + // For AST nodes that are function calls, this map provides the + // unspecialized type of the function that was called. If a function call + // resolves to a __call metamethod application, this map will point at that + // metamethod. + // + // This is useful for type checking and Signature Help. DenseHashMap astOriginalCallTypes{nullptr}; + + // The specialization of a function that was selected. If the function is + // generic, those generic type parameters will be replaced with the actual + // types that were passed. If the function is an overload, this map will + // point at the specific overloads that were selected. DenseHashMap astOverloadResolvedTypes{nullptr}; + // Only used with for...in loops. The computed type of the next() function + // is kept here for type checking. + DenseHashMap astForInNextTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; - DenseHashMap astOriginalResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; - // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. - DenseHashMap astScopes{nullptr}; + // The computed result type of a compound assignment. (eg foo += 1) + // + // Type checking uses this to check that the result of such an operation is + // actually compatible with the left-side operand. + DenseHashMap astCompoundAssignResultTypes{nullptr}; - std::unique_ptr reduction; + DenseHashMap>> upperBoundContributors{nullptr}; + + // Map AST nodes to the scope they create. Cannot be NotNull because + // we need a sentinel value for the map. + DenseHashMap astScopes{nullptr}; std::unordered_map declaredGlobals; ErrorVec errors; + LintResult lintResult; Mode mode; SourceCode::Type type; + double checkDurationSec = 0.0; bool timeout = false; + bool cancelled = false; TypePackId returnType = nullptr; std::unordered_map exportedTypeBindings; + // We also need to keep DFG data alive between runs + std::shared_ptr dataFlowGraph = nullptr; + std::vector> dfgScopes; bool hasModuleScope() const; ScopePtr getModuleScope() const; - // Once a module has been typechecked, we clone its public interface into a separate arena. - // This helps us to force Type ownership into a DAG rather than a DCG. + // Once a module has been typechecked, we clone its public interface into a + // separate arena. This helps us to force Type ownership into a DAG rather + // than a DCG. void clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice); }; diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h index d892ccd7f..59751793c 100644 --- a/Analysis/include/Luau/ModuleResolver.h +++ b/Analysis/include/Luau/ModuleResolver.h @@ -20,8 +20,6 @@ struct ModuleResolver virtual ~ModuleResolver() {} /** Compute a ModuleName from an AST fragment. This AST fragment is generally the argument to the require() function. - * - * You probably want to implement this with some variation of pathExprToModuleName. * * @returns The ModuleInfo if the expression is a syntactically legal path. * @returns std::nullopt if we are unable to determine whether or not the expression is a valid path. Type inference will diff --git a/Analysis/include/Luau/NonStrictTypeChecker.h b/Analysis/include/Luau/NonStrictTypeChecker.h new file mode 100644 index 000000000..6229a932c --- /dev/null +++ b/Analysis/include/Luau/NonStrictTypeChecker.h @@ -0,0 +1,28 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/DataFlowGraph.h" + +namespace Luau +{ + +struct BuiltinTypes; +struct TypeFunctionRuntime; +struct UnifierSharedState; +struct TypeCheckLimits; + +void checkNonStrict( + NotNull builtinTypes, + NotNull typeFunctionRuntime, + NotNull ice, + NotNull unifierState, + NotNull dfg, + NotNull limits, + const SourceModule& sourceModule, + Module* module +); + + +} // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 15dc7d4a1..97d13a600 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -2,10 +2,15 @@ #pragma once #include "Luau/NotNull.h" -#include "Luau/Type.h" +#include "Luau/Set.h" +#include "Luau/TypeFwd.h" #include "Luau/UnifierSharedState.h" +#include +#include #include +#include +#include namespace Luau { @@ -13,7 +18,6 @@ namespace Luau struct InternalErrorReporter; struct Module; struct Scope; -struct BuiltinTypes; using ModulePtr = std::shared_ptr; @@ -23,7 +27,7 @@ bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNu class TypeIds { private: - std::unordered_set types; + DenseHashMap types{nullptr}; std::vector order; std::size_t hash = 0; @@ -31,10 +35,15 @@ class TypeIds using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; - TypeIds(const TypeIds&) = default; - TypeIds(TypeIds&&) = default; TypeIds() = default; ~TypeIds() = default; + + TypeIds(std::initializer_list tys); + + TypeIds(const TypeIds&) = default; + TypeIds& operator=(const TypeIds&) = default; + + TypeIds(TypeIds&&) = default; TypeIds& operator=(TypeIds&&) = default; void insert(TypeId ty); @@ -48,6 +57,7 @@ class TypeIds const_iterator begin() const; const_iterator end() const; iterator erase(const_iterator it); + void erase(TypeId ty); size_t size() const; bool empty() const; @@ -62,6 +72,7 @@ class TypeIds bool operator==(const TypeIds& there) const; size_t getHash() const; + bool isNever() const; }; } // namespace Luau @@ -189,12 +200,8 @@ struct NormalizedClassType // this type may contain `error`. struct NormalizedFunctionType { - NormalizedFunctionType(); - bool isTop = false; - // TODO: Remove this wrapping optional when clipping - // FFlagLuauNegatedFunctionTypes. - std::optional parts; + TypeIds parts; void resetToNever(); void resetToTop(); @@ -203,18 +210,30 @@ struct NormalizedFunctionType }; // A normalized generic/free type is a union, where each option is of the form (X & T) where -// * X is either a free type or a generic +// * X is either a free type, a generic or a blocked type. // * T is a normalized type. struct NormalizedType; using NormalizedTyvars = std::unordered_map>; -bool isInhabited_DEPRECATED(const NormalizedType& norm); +// Operations provided by `Normalizer` can have ternary results: +// 1. The operation returned true. +// 2. The operation returned false. +// 3. They can hit resource limitations, which invalidates _all normalized types_. +enum class NormalizationResult +{ + // The operation returned true or succeeded. + True, + // The operation returned false or failed. + False, + // Resource limits were hit, invalidating all normalized types. + HitLimits, +}; // A normalized type is either any, unknown, or one of the form P | T | F | G where // * P is a union of primitive types (including singletons, classes and the error type) // * T is a union of table types // * F is a union of an intersection of function types -// * G is a union of generic/free normalized types, intersected with a normalized type +// * G is a union of generic/free/blocked types, intersected with a normalized type struct NormalizedType { // The top part of the type. @@ -228,10 +247,6 @@ struct NormalizedType NormalizedClassType classes; - // The class part of the type. - // Each element of this set is a class, and none of the classes are subclasses of each other. - TypeIds DEPRECATED_classes; - // The error part of the type. // This type is either never or the error type. TypeId errors; @@ -252,6 +267,10 @@ struct NormalizedType // This type is either never or thread. TypeId threads; + // The buffer part of the type. + // This type is either never or buffer. + TypeId buffers; + // The (meta)table part of the type. // Each element of this set is a (meta)table type, or the top `table` type. // An empty set denotes never. @@ -263,6 +282,11 @@ struct NormalizedType // The generic/free part of the type. NormalizedTyvars tyvars; + // Free types, blocked types, and certain other types change shape as type + // inference is done. If we were to cache the normalization of these types, + // we'd be reusing bad, stale data. + bool isCacheable = true; + NormalizedType(NotNull builtinTypes); NormalizedType() = delete; @@ -273,22 +297,64 @@ struct NormalizedType NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; + + // IsType functions + bool isUnknown() const; + /// Returns true if the type is exactly a number. Behaves like Type::isNumber() + bool isExactlyNumber() const; + + /// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString() + bool isSubtypeOfString() const; + + /// Returns true if the type is a subtype of boolean(it could be a singleton). Behaves like Type::isBoolean() + bool isSubtypeOfBooleans() const; + + /// Returns true if this type should result in error suppressing behavior. + bool shouldSuppressErrors() const; + + /// Returns true if this type contains the primitve top table type, `table`. + bool hasTopTable() const; + + // Helpers that improve readability of the above (they just say if the component is present) + bool hasTops() const; + bool hasBooleans() const; + bool hasClasses() const; + bool hasErrors() const; + bool hasNils() const; + bool hasNumbers() const; + bool hasStrings() const; + bool hasThreads() const; + bool hasBuffers() const; + bool hasTables() const; + bool hasFunctions() const; + bool hasTyvars() const; + + bool isFalsy() const; + bool isTruthy() const; }; + +using SeenTablePropPairs = Set, TypeIdPairHash>; + class Normalizer { - std::unordered_map> cachedNormals; + std::unordered_map> cachedNormals; std::unordered_map cachedIntersections; std::unordered_map cachedUnions; std::unordered_map> cachedTypeIds; + + DenseHashMap cachedIsInhabited{nullptr}; + DenseHashMap, bool, TypeIdPairHash> cachedIsInhabitedIntersection{{nullptr, nullptr}}; + bool withinResourceLimits(); public: TypeArena* arena; NotNull builtinTypes; NotNull sharedState; + bool cacheInhabitance = false; - Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState); + Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState, bool cacheInhabitance = false); Normalizer(const Normalizer&) = delete; Normalizer(Normalizer&&) = delete; Normalizer() = delete; @@ -297,7 +363,7 @@ class Normalizer Normalizer& operator=(Normalizer&) = delete; // If this returns null, the typechecker should emit a "too complex" error - const NormalizedType* normalize(TypeId ty); + std::shared_ptr normalize(TypeId ty); void clearNormal(NormalizedType& norm); // ------- Cached TypeIds @@ -322,8 +388,14 @@ class Normalizer void unionFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); void unionTablesWithTable(TypeIds& heres, TypeId there); void unionTables(TypeIds& heres, const TypeIds& theres); - bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + NormalizationResult unionNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes, + int ignoreSmallerTyvars = -1 + ); // ------- Negations std::optional negateNormal(const NormalizedType& here); @@ -331,29 +403,45 @@ class Normalizer TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); void subtractSingleton(NormalizedType& here, TypeId ty); + NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect); // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); - void DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres); - void DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there); void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres); void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); - std::optional intersectionOfTables(TypeId here, TypeId there); - void intersectTablesWithTable(TypeIds& heres, TypeId there); + std::optional intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSet); + void intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes); void intersectTables(TypeIds& heres, const TypeIds& theres); std::optional intersectionOfFunctions(TypeId here, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); - bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); - bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - bool intersectNormalWithTy(NormalizedType& here, TypeId there); + NormalizationResult intersectTyvarsWithTy( + NormalizedTyvars& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes + ); + NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes); + NormalizationResult normalizeIntersections( + const std::vector& intersections, + NormalizedType& outType, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSet + ); // Check for inhabitance - bool isInhabited(TypeId ty, std::unordered_set seen = {}); - bool isInhabited(const NormalizedType* norm, std::unordered_set seen = {}); + NormalizationResult isInhabited(TypeId ty); + NormalizationResult isInhabited(TypeId ty, Set& seen); + NormalizationResult isInhabited(const NormalizedType* norm); + NormalizationResult isInhabited(const NormalizedType* norm, Set& seen); + + // Check for intersections being inhabited + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set& seenSet); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/Analysis/include/Luau/OverloadResolution.h b/Analysis/include/Luau/OverloadResolution.h new file mode 100644 index 000000000..83a33215a --- /dev/null +++ b/Analysis/include/Luau/OverloadResolution.h @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/InsertionOrderedMap.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" +#include "Luau/Location.h" +#include "Luau/Error.h" +#include "Luau/Subtyping.h" + +namespace Luau +{ + +struct BuiltinTypes; +struct TypeArena; +struct Scope; +struct InternalErrorReporter; +struct TypeCheckLimits; +struct Subtyping; + +class Normalizer; + +struct OverloadResolver +{ + enum Analysis + { + Ok, + TypeIsNotAFunction, + ArityMismatch, + OverloadIsNonviable, // Arguments were incompatible with the overloads parameters but were otherwise compatible by arity + }; + + OverloadResolver( + NotNull builtinTypes, + NotNull arena, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull scope, + NotNull reporter, + NotNull limits, + Location callLocation + ); + + NotNull builtinTypes; + NotNull arena; + NotNull normalizer; + NotNull typeFunctionRuntime; + NotNull scope; + NotNull ice; + NotNull limits; + Subtyping subtyping; + Location callLoc; + + // Resolver results + std::vector ok; + std::vector nonFunctions; + std::vector> arityMismatches; + std::vector> nonviableOverloads; + InsertionOrderedMap> resolution; + + + std::pair selectOverload(TypeId ty, TypePackId args); + void resolve(TypeId fnTy, const TypePack* args, AstExpr* selfExpr, const std::vector* argExprs); + +private: + std::optional testIsSubtype(const Location& location, TypeId subTy, TypeId superTy); + std::optional testIsSubtype(const Location& location, TypePackId subTy, TypePackId superTy); + std::pair checkOverload( + TypeId fnTy, + const TypePack* args, + AstExpr* fnLoc, + const std::vector* argExprs, + bool callMetamethodOk = true + ); + static bool isLiteral(AstExpr* expr); + LUAU_NOINLINE + std::pair checkOverload_( + TypeId fnTy, + const FunctionType* fn, + const TypePack* args, + AstExpr* fnExpr, + const std::vector* argExprs + ); + size_t indexof(Analysis analysis); + void add(Analysis analysis, TypeId ty, ErrorVec&& errors); +}; + +struct SolveResult +{ + enum OverloadCallResult + { + Ok, + CodeTooComplex, + OccursCheckFailed, + NoMatchingOverload, + }; + + OverloadCallResult result; + std::optional typePackId; // nullopt if result != Ok + + TypeId overloadToUse = nullptr; + TypeId inferredTy = nullptr; + DenseHashMap> expandedFreeTypes{nullptr}; +}; + +// Helper utility, presently used for binary operator type functions. +// +// Given a function and a set of arguments, select a suitable overload. +SolveResult solveFunctionCall( + NotNull arena, + NotNull builtinTypes, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull iceReporter, + NotNull limits, + NotNull scope, + const Location& location, + TypeId fn, + TypePackId argsPack +); + +} // namespace Luau diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h index 50fd7edd8..52ee1f298 100644 --- a/Analysis/include/Luau/Predicate.h +++ b/Analysis/include/Luau/Predicate.h @@ -4,15 +4,13 @@ #include "Luau/Location.h" #include "Luau/LValue.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" #include namespace Luau { -struct Type; -using TypeId = const Type*; - struct TruthyPredicate; struct IsAPredicate; struct TypeGuardPredicate; diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index b350fab52..bae3751da 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -1,7 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" +#include "Luau/DenseHash.h" +#include "Luau/Unifiable.h" + +#include +#include namespace Luau { @@ -10,6 +15,29 @@ struct TypeArena; struct Scope; void quantify(TypeId ty, TypeLevel level); -TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope); + +// TODO: This is eerily similar to the pattern that NormalizedClassType +// implements. We could, and perhaps should, merge them together. +template +struct OrderedMap +{ + std::vector keys; + DenseHashMap pairings{nullptr}; + + void push(K k, V v) + { + keys.push_back(k); + pairings[k] = v; + } +}; + +struct QuantifierResult +{ + TypeId result; + OrderedMap insertedGenerics; + OrderedMap insertedGenericPacks; +}; + +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope); } // namespace Luau diff --git a/Analysis/include/Luau/Refinement.h b/Analysis/include/Luau/Refinement.h index fecf459ad..3fea78688 100644 --- a/Analysis/include/Luau/Refinement.h +++ b/Analysis/include/Luau/Refinement.h @@ -4,14 +4,13 @@ #include "Luau/NotNull.h" #include "Luau/TypedAllocator.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" namespace Luau { -using BreadcrumbId = NotNull; - -struct Type; -using TypeId = const Type*; +struct RefinementKey; +using DefId = NotNull; struct Variadic; struct Negation; @@ -52,7 +51,7 @@ struct Equivalence struct Proposition { - BreadcrumbId breadcrumb; + const RefinementKey* key; TypeId discriminantTy; }; @@ -69,7 +68,7 @@ struct RefinementArena RefinementId conjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs); - RefinementId proposition(BreadcrumbId breadcrumb, TypeId discriminantTy); + RefinementId proposition(const RefinementKey* key, TypeId discriminantTy); private: TypedAllocator allocator; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 0d3972672..0e6eff56d 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -2,9 +2,13 @@ #pragma once #include "Luau/Def.h" +#include "Luau/LValue.h" #include "Luau/Location.h" #include "Luau/NotNull.h" #include "Luau/Type.h" +#include "Luau/DenseHash.h" +#include "Luau/Symbol.h" +#include "Luau/Unifiable.h" #include #include @@ -41,6 +45,8 @@ struct Scope TypeLevel level; + Location location; // the spanning location associated with this scope + std::unordered_map exportedTypeBindings; std::unordered_map privateTypeBindings; std::unordered_map typeAliasLocations; @@ -52,20 +58,32 @@ struct Scope void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); std::optional lookup(Symbol sym) const; + std::optional lookupUnrefinedType(DefId def) const; std::optional lookup(DefId def) const; + std::optional> lookupEx(DefId def); std::optional> lookupEx(Symbol sym); - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + std::optional lookupType(const Name& name) const; + std::optional lookupImportedType(const Name& moduleAlias, const Name& name) const; std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); + std::optional lookupPack(const Name& name) const; // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; RefinementMap refinements; - DenseHashMap dcrRefinements{nullptr}; + + // This can be viewed as the "unrefined" type of each binding. + DenseHashMap lvalueTypes{nullptr}; + + // Luau values are routinely refined more narrowly than their actual + // inferred type through control flow statements. We retain those refined + // types here. + DenseHashMap rvalueRefinements{nullptr}; + + void inheritAssignments(const ScopePtr& childScope); + void inheritRefinements(const ScopePtr& childScope); // For mutually recursive type aliases, it's important that // they use the same types for the same names. @@ -84,4 +102,12 @@ bool subsumesStrict(Scope* left, Scope* right); // outermost-possible scope. bool subsumes(Scope* left, Scope* right); +inline Scope* max(Scope* left, Scope* right) +{ + if (subsumes(left, right)) + return right; + else + return left; +} + } // namespace Luau diff --git a/Analysis/include/Luau/Set.h b/Analysis/include/Luau/Set.h new file mode 100644 index 000000000..613e5aa5e --- /dev/null +++ b/Analysis/include/Luau/Set.h @@ -0,0 +1,194 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/DenseHash.h" + +LUAU_FASTFLAG(LuauSolverV2) + +namespace Luau +{ + +template +using SetHashDefault = std::conditional_t, DenseHashPointer, std::hash>; + +// This is an implementation of `unordered_set` using `DenseHashMap` to support erasure. +// This lets us work around `DenseHashSet` limitations and get a more traditional set interface. +template> +class Set +{ +private: + using Impl = DenseHashMap; + Impl mapping; + size_t entryCount = 0; + +public: + class const_iterator; + using iterator = const_iterator; + + Set(const T& empty_key) + : mapping{empty_key} + { + } + + bool insert(const T& element) + { + bool& entry = mapping[element]; + bool fresh = !entry; + + if (fresh) + { + entry = true; + entryCount++; + } + + return fresh; + } + + template + void insert(Iterator begin, Iterator end) + { + for (Iterator it = begin; it != end; ++it) + insert(*it); + } + + void erase(T&& element) + { + bool& entry = mapping[element]; + + if (entry) + { + entry = false; + entryCount--; + } + } + + void erase(const T& element) + { + bool& entry = mapping[element]; + + if (entry) + { + entry = false; + entryCount--; + } + } + + void clear() + { + mapping.clear(); + entryCount = 0; + } + + size_t size() const + { + return entryCount; + } + + bool empty() const + { + return entryCount == 0; + } + + size_t count(const T& element) const + { + const bool* entry = mapping.find(element); + return (entry && *entry) ? 1 : 0; + } + + bool contains(const T& element) const + { + return count(element) != 0; + } + + const_iterator begin() const + { + return const_iterator(mapping.begin(), mapping.end()); + } + + const_iterator end() const + { + return const_iterator(mapping.end(), mapping.end()); + } + + bool operator==(const Set& there) const + { + // if the sets are unequal sizes, then they cannot possibly be equal. + if (size() != there.size()) + return false; + + // otherwise, we'll need to check that every element we have here is in `there`. + for (auto [elem, present] : mapping) + { + // if it's not, we'll return `false` + if (present && there.contains(elem)) + return false; + } + + // otherwise, we've proven the two equal! + return true; + } + + class const_iterator + { + public: + using value_type = T; + using reference = T&; + using pointer = T*; + using difference_type = ptrdiff_t; + using iterator_category = std::forward_iterator_tag; + + const_iterator(typename Impl::const_iterator impl_, typename Impl::const_iterator end_) + : impl(impl_) + , end(end_) + { + while (impl != end && impl->second == false) + ++impl; + } + + const T& operator*() const + { + return impl->first; + } + + const T* operator->() const + { + return &impl->first; + } + + bool operator==(const const_iterator& other) const + { + return impl == other.impl; + } + + bool operator!=(const const_iterator& other) const + { + return impl != other.impl; + } + + + const_iterator& operator++() + { + do + { + impl++; + } while (impl != end && impl->second == false); + // keep iterating past pairs where the value is `false` + + return *this; + } + + const_iterator operator++(int) + { + const_iterator res = *this; + ++*this; + return res; + } + + private: + typename Impl::const_iterator impl; + typename Impl::const_iterator end; + }; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h new file mode 100644 index 000000000..5b363e964 --- /dev/null +++ b/Analysis/include/Luau/Simplify.h @@ -0,0 +1,38 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" +#include + +namespace Luau +{ + +struct TypeArena; + +struct SimplifyResult +{ + TypeId result; + + DenseHashSet blockedTypes; +}; + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts); + +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); + +enum class Relation +{ + Disjoint, // No A is a B or vice versa + Coincident, // Every A is in B and vice versa + Intersects, // Some As are in B and some Bs are in A. ex (number | string) <-> (string | boolean) + Subset, // Every A is in B + Superset, // Every B is in A +}; + +Relation relate(TypeId left, TypeId right); + +} // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 2efca2df5..28ebc93d6 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -2,8 +2,7 @@ #pragma once #include "Luau/TypeArena.h" -#include "Luau/TypePack.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include "Luau/DenseHash.h" // We provide an implementation of substitution on types, @@ -69,24 +68,34 @@ struct TarjanWorklistVertex int lastEdge; }; +struct TarjanNode +{ + TypeId ty; + TypePackId tp; + + bool onStack; + bool dirty; + + // Tarjan calculates the lowlink for each vertex, + // which is the lowest ancestor index reachable from the vertex. + int lowlink; +}; + // Tarjan's algorithm for finding the SCCs in a cyclic structure. // https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm struct Tarjan { + Tarjan(); + // Vertices (types and type packs) are indexed, using pre-order traversal. DenseHashMap typeToIndex{nullptr}; DenseHashMap packToIndex{nullptr}; - std::vector indexToType; - std::vector indexToPack; + + std::vector nodes; // Tarjan keeps a stack of vertices where we're still in the process // of finding their SCC. std::vector stack; - std::vector onStack; - - // Tarjan calculates the lowlink for each vertex, - // which is the lowest ancestor index reachable from the vertex. - std::vector lowlink; int childCount = 0; int childLimit = 0; @@ -98,6 +107,7 @@ struct Tarjan std::vector edgesTy; std::vector edgesTp; std::vector worklist; + // This is hot code, so we optimize recursion to a stack. TarjanResult loop(); @@ -113,45 +123,53 @@ struct Tarjan void visitChild(TypeId ty); void visitChild(TypePackId ty); + template + void visitChild(std::optional ty) + { + if (ty) + visitChild(*ty); + } + // Visit the root vertex. TarjanResult visitRoot(TypeId ty); TarjanResult visitRoot(TypePackId ty); - // Each subclass gets called back once for each edge, - // and once for each SCC. - virtual void visitEdge(int index, int parentIndex) {} - virtual void visitSCC(int index) {} + // Used to reuse the object for a new operation + void clearTarjan(const TxnLog* log); + + // Get/set the dirty bit for an index (grows the vector if needed) + bool getDirty(int index); + void setDirty(int index, bool d); + + // Find all the dirty vertices reachable from `t`. + TarjanResult findDirty(TypeId t); + TarjanResult findDirty(TypePackId t); + + // We find dirty vertices using Tarjan + void visitEdge(int index, int parentIndex); + void visitSCC(int index); // Each subclass can decide to ignore some nodes. virtual bool ignoreChildren(TypeId ty) { return false; } + virtual bool ignoreChildren(TypePackId ty) { return false; } -}; - -// We use Tarjan to calculate dirty bits. We set `dirty[i]` true -// if the vertex with index `i` can reach a dirty vertex. -struct FindDirty : Tarjan -{ - std::vector dirty; - void clearTarjan(); - - // Get/set the dirty bit for an index (grows the vector if needed) - bool getDirty(int index); - void setDirty(int index, bool d); - - // Find all the dirty vertices reachable from `t`. - TarjanResult findDirty(TypeId t); - TarjanResult findDirty(TypePackId t); + // Some subclasses might ignore children visit, but not other actions like replacing the children + virtual bool ignoreChildrenVisit(TypeId ty) + { + return ignoreChildren(ty); + } - // We find dirty vertices using Tarjan - void visitEdge(int index, int parentIndex) override; - void visitSCC(int index) override; + virtual bool ignoreChildrenVisit(TypePackId ty) + { + return ignoreChildren(ty); + } // Subclasses should say which vertices are dirty, // and what to do with dirty vertices. @@ -163,16 +181,24 @@ struct FindDirty : Tarjan // And finally substitution, which finds all the reachable dirty vertices // and replaces them with clean ones. -struct Substitution : FindDirty +struct Substitution : Tarjan { protected: - Substitution(const TxnLog* log_, TypeArena* arena) - : arena(arena) - { - log = log_; - LUAU_ASSERT(log); - LUAU_ASSERT(arena); - } + Substitution(const TxnLog* log_, TypeArena* arena); + + /* + * By default, Substitution assumes that the types produced by clean() are + * freshly allocated types that are safe to mutate. + * + * If your clean() implementation produces a type that is not safe to + * mutate, you must call dontTraverseInto on this type (or type pack) to + * prevent Substitution from attempting to perform substitutions within the + * cleaned type. + * + * See the test weird_cyclic_instantiation for an example. + */ + void dontTraverseInto(TypeId ty); + void dontTraverseInto(TypePackId tp); public: TypeArena* arena; @@ -181,13 +207,20 @@ struct Substitution : FindDirty DenseHashSet replacedTypes{nullptr}; DenseHashSet replacedTypePacks{nullptr}; + DenseHashSet noTraverseTypes{nullptr}; + DenseHashSet noTraverseTypePacks{nullptr}; + std::optional substitute(TypeId ty); std::optional substitute(TypePackId tp); + void resetState(const TxnLog* log, TypeArena* arena); + TypeId replace(TypeId ty); TypePackId replace(TypePackId tp); + void replaceChildren(TypeId ty); void replaceChildren(TypePackId tp); + TypeId clone(TypeId ty); TypePackId clone(TypePackId tp); @@ -211,6 +244,16 @@ struct Substitution : FindDirty { return arena->addTypePack(TypePackVar{tp}); } + +private: + template + std::optional replace(std::optional ty) + { + if (ty) + return replace(*ty); + else + return std::nullopt; + } }; } // namespace Luau diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h new file mode 100644 index 000000000..1e7810560 --- /dev/null +++ b/Analysis/include/Luau/Subtyping.h @@ -0,0 +1,311 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Set.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePairHash.h" +#include "Luau/TypePath.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/DenseHash.h" + +#include +#include + +namespace Luau +{ + +template +struct TryPair; +struct InternalErrorReporter; + +class TypeIds; +class Normalizer; +struct NormalizedClassType; +struct NormalizedFunctionType; +struct NormalizedStringType; +struct NormalizedType; +struct Property; +struct Scope; +struct TableIndexer; +struct TypeArena; +struct TypeCheckLimits; + +enum class SubtypingVariance +{ + // Used for an empty key. Should never appear in actual code. + Invalid, + Covariant, + // This is used to identify cases where we have a covariant + a + // contravariant reason and we need to merge them. + Contravariant, + Invariant, +}; + +struct SubtypingReasoning +{ + // The path, relative to the _root subtype_, where subtyping failed. + Path subPath; + // The path, relative to the _root supertype_, where subtyping failed. + Path superPath; + SubtypingVariance variance = SubtypingVariance::Covariant; + + bool operator==(const SubtypingReasoning& other) const; +}; + +struct SubtypingReasoningHash +{ + size_t operator()(const SubtypingReasoning& r) const; +}; + +using SubtypingReasonings = DenseHashSet; +static const SubtypingReasoning kEmptyReasoning = SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Invalid}; + +struct SubtypingResult +{ + bool isSubtype = false; + bool normalizationTooComplex = false; + bool isCacheable = true; + ErrorVec errors; + /// The reason for isSubtype to be false. May not be present even if + /// isSubtype is false, depending on the input types. + SubtypingReasonings reasoning{kEmptyReasoning}; + + SubtypingResult& andAlso(const SubtypingResult& other); + SubtypingResult& orElse(const SubtypingResult& other); + SubtypingResult& withBothComponent(TypePath::Component component); + SubtypingResult& withSuperComponent(TypePath::Component component); + SubtypingResult& withSubComponent(TypePath::Component component); + SubtypingResult& withBothPath(TypePath::Path path); + SubtypingResult& withSubPath(TypePath::Path path); + SubtypingResult& withSuperPath(TypePath::Path path); + SubtypingResult& withErrors(ErrorVec& err); + SubtypingResult& withError(TypeError err); + + // Only negates the `isSubtype`. + static SubtypingResult negate(const SubtypingResult& result); + static SubtypingResult all(const std::vector& results); + static SubtypingResult any(const std::vector& results); +}; + +struct SubtypingEnvironment +{ + struct GenericBounds + { + DenseHashSet lowerBound{nullptr}; + DenseHashSet upperBound{nullptr}; + }; + + /* For nested subtyping relationship tests of mapped generic bounds, we keep the outer environment immutable */ + SubtypingEnvironment* parent = nullptr; + + /// Applies `mappedGenerics` to the given type. + /// This is used specifically to substitute for generics in type function instances. + std::optional applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty); + + const TypeId* tryFindSubstitution(TypeId ty) const; + const SubtypingResult* tryFindSubtypingResult(std::pair subAndSuper) const; + + bool containsMappedType(TypeId ty) const; + bool containsMappedPack(TypePackId tp) const; + + GenericBounds& getMappedTypeBounds(TypeId ty); + TypePackId* getMappedPackBounds(TypePackId tp); + + /* + * When we encounter a generic over the course of a subtyping test, we need + * to tentatively map that generic onto a type on the other side. + */ + DenseHashMap mappedGenerics{nullptr}; + DenseHashMap mappedGenericPacks{nullptr}; + + /* + * See the test cyclic_tables_are_assumed_to_be_compatible_with_classes for + * details. + * + * An empty value is equivalent to a nonexistent key. + */ + DenseHashMap substitutions{nullptr}; + + DenseHashMap, SubtypingResult, TypePairHash> ephemeralCache{{}}; +}; + +struct Subtyping +{ + NotNull builtinTypes; + NotNull arena; + NotNull normalizer; + NotNull typeFunctionRuntime; + NotNull iceReporter; + + TypeCheckLimits limits; + + enum class Variance + { + Covariant, + Contravariant + }; + + Variance variance = Variance::Covariant; + + using SeenSet = Set, TypePairHash>; + + SeenSet seenTypes{{}}; + + Subtyping( + NotNull builtinTypes, + NotNull typeArena, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull iceReporter + ); + + Subtyping(const Subtyping&) = delete; + Subtyping& operator=(const Subtyping&) = delete; + + Subtyping(Subtyping&&) = default; + Subtyping& operator=(Subtyping&&) = default; + + // Only used by unit tests to test that the cache works. + const DenseHashMap, SubtypingResult, TypePairHash>& peekCache() const + { + return resultCache; + } + + // TODO cache + // TODO cyclic types + // TODO recursion limits + + SubtypingResult isSubtype(TypeId subTy, TypeId superTy, NotNull scope); + SubtypingResult isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope); + +private: + DenseHashMap, SubtypingResult, TypePairHash> resultCache{{}}; + + SubtypingResult cache(SubtypingEnvironment& env, SubtypingResult res, TypeId subTy, TypeId superTy); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp, NotNull scope); + + template + SubtypingResult isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy, NotNull scope); + + template + SubtypingResult isInvariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy, NotNull scope); + + template + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TryPair& pair, NotNull scope); + + template + SubtypingResult isContravariantWith(SubtypingEnvironment& env, const TryPair& pair, NotNull); + + template + SubtypingResult isInvariantWith(SubtypingEnvironment& env, const TryPair& pair, NotNull); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const UnionType* superUnion, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const UnionType* subUnion, TypeId superTy, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const IntersectionType* superIntersection, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const IntersectionType* subIntersection, TypeId superTy, NotNull scope); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NegationType* subNegation, TypeId superTy, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation, NotNull scope); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const PrimitiveType* superPrim, NotNull scope); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const SingletonType* subSingleton, + const PrimitiveType* superPrim, + NotNull scope + ); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const SingletonType* subSingleton, + const SingletonType* superSingleton, + NotNull scope + ); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TableType* subTable, const TableType* superTable, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const ClassType* subClass, const ClassType* superClass, NotNull scope); + SubtypingResult + isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const ClassType* subClass, TypeId superTy, const TableType* superTable, NotNull); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const FunctionType* subFunction, + const FunctionType* superFunction, + NotNull scope + ); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TableType* subTable, const PrimitiveType* superPrim, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const TableType* superTable, NotNull scope); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const TableType* superTable, NotNull scope); + + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const TableIndexer& subIndexer, + const TableIndexer& superIndexer, + NotNull scope + ); + SubtypingResult + isCovariantWith(SubtypingEnvironment& env, const Property& subProperty, const Property& superProperty, const std::string& name, NotNull); + + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const std::shared_ptr& subNorm, + const std::shared_ptr& superNorm, + NotNull scope + ); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const NormalizedClassType& subClass, + const NormalizedClassType& superClass, + NotNull scope + ); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables, NotNull scope); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const NormalizedStringType& subString, + const NormalizedStringType& superString, + NotNull scope + ); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const NormalizedStringType& subString, + const TypeIds& superTables, + NotNull scope + ); + SubtypingResult + isCovariantWith(SubtypingEnvironment& env, const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction, NotNull); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes, NotNull scope); + + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const VariadicTypePack* subVariadic, + const VariadicTypePack* superVariadic, + NotNull scope + ); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const TypeFunctionInstanceType* subFunctionInstance, + const TypeId superTy, + NotNull scope + ); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const TypeId subTy, + const TypeFunctionInstanceType* superFunctionInstance, + NotNull scope + ); + + bool bindGeneric(SubtypingEnvironment& env, TypeId subTp, TypeId superTp); + bool bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp); + + template + TypeId makeAggregateType(const Container& container, TypeId orElse); + + std::pair handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull scope); + + [[noreturn]] void unexpected(TypeId ty); + [[noreturn]] void unexpected(TypePackId tp); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index b47554e0d..337e2a9f2 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) - namespace Luau { @@ -42,17 +40,7 @@ struct Symbol return local != nullptr || global.value != nullptr; } - bool operator==(const Symbol& rhs) const - { - if (local) - return local == rhs.local; - else if (global.value) - return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - else if (FFlag::DebugLuauDeferredConstraintResolution) - return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. - else - return false; - } + bool operator==(const Symbol& rhs) const; bool operator!=(const Symbol& rhs) const { diff --git a/Analysis/include/Luau/TableLiteralInference.h b/Analysis/include/Luau/TableLiteralInference.h new file mode 100644 index 000000000..dd9ecf971 --- /dev/null +++ b/Analysis/include/Luau/TableLiteralInference.h @@ -0,0 +1,28 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +struct TypeArena; +struct BuiltinTypes; +struct Unifier2; +class AstExpr; + +TypeId matchLiteralType( + NotNull> astTypes, + NotNull> astExpectedTypes, + NotNull builtinTypes, + NotNull arena, + NotNull unifier, + TypeId expectedType, + TypeId exprType, + const AstExpr* expr, + std::vector& toBlock +); +} // namespace Luau diff --git a/Analysis/include/Luau/ToDot.h b/Analysis/include/Luau/ToDot.h index 1a9c2811a..6fa99ec3f 100644 --- a/Analysis/include/Luau/ToDot.h +++ b/Analysis/include/Luau/ToDot.h @@ -2,16 +2,12 @@ #pragma once #include "Luau/Common.h" +#include "Luau/TypeFwd.h" #include namespace Luau { -struct Type; -using TypeId = const Type*; - -struct TypePackVar; -using TypePackId = const TypePackVar*; struct ToDotOptions { diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 7758e8f99..f8001e088 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Common.h" +#include "Luau/TypeFwd.h" #include #include @@ -19,13 +20,6 @@ class AstExpr; struct Scope; -struct Type; -using TypeId = const Type*; - -struct TypePackVar; -using TypePackId = const TypePackVar*; - -struct FunctionType; struct Constraint; struct Position; @@ -39,6 +33,11 @@ struct ToStringNameMap struct ToStringOptions { + ToStringOptions(bool exhaustive = false) + : exhaustive(exhaustive) + { + } + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. bool functionTypeArguments = false; // If true, output function type argument names when they are available @@ -47,6 +46,7 @@ struct ToStringOptions bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); + size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections ToStringNameMap nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden @@ -99,10 +99,7 @@ inline std::string toString(const Constraint& c, ToStringOptions&& opts) return toString(c, opts); } -inline std::string toString(const Constraint& c) -{ - return toString(c, ToStringOptions{}); -} +std::string toString(const Constraint& c); std::string toString(const Type& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); @@ -142,6 +139,16 @@ std::string dump(const std::shared_ptr& scope, const char* name); std::string generateName(size_t n); std::string toString(const Position& position); -std::string toString(const Location& location); +std::string toString(const Location& location, int offset = 0, bool useBegin = true); + +std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts); + +inline std::string toString(const TypeOrPack& tyOrTp) +{ + ToStringOptions opts{}; + return toString(tyOrTp, opts); +} + +std::string dump(const TypeOrPack& tyOrTp); } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 0ed8a49ad..951f89ee5 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -19,6 +19,10 @@ struct PendingType // The pending Type state. Type pending; + // On very rare occasions, we need to delete an entry from the TxnLog. + // DenseHashMap does not afford that so we note its deadness here. + bool dead = false; + explicit PendingType(Type state) : pending(std::move(state)) { @@ -61,10 +65,11 @@ T* getMutable(PendingTypePack* pending) // Log of what TypeIds we are rebinding, to be committed later. struct TxnLog { - TxnLog() + explicit TxnLog(bool useScopes = false) : typeVarChanges(nullptr) , typePackChanges(nullptr) , ownedSeen() + , useScopes(useScopes) , sharedSeen(&ownedSeen) { } @@ -297,6 +302,18 @@ struct TxnLog void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); public: + // There is one spot in the code where TxnLog has to reconcile collisions + // between parallel logs. In that codepath, we have to work out which of two + // FreeTypes subsumes the other. If useScopes is false, the TypeLevel is + // used. Else we use the embedded Scope*. + bool useScopes = false; + + // It is sometimes the case under DCR that we speculatively rebind + // GenericTypes to other types as though they were free. We mark logs that + // contain these kinds of substitutions as radioactive so that we know that + // we must never commit one. + bool radioactive = false; + // Used to avoid infinite recursion when types are cyclic. // Shared with all the descendent TxnLogs. std::vector>* sharedSeen; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index cf1f8dae4..d100fa4d7 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/TypeFwd.h" + #include "Luau/Ast.h" #include "Luau/Common.h" #include "Luau/Refinement.h" @@ -9,15 +11,15 @@ #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" +#include "Luau/VecDeque.h" -#include +#include #include #include #include #include #include #include -#include #include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) @@ -30,6 +32,11 @@ struct TypeArena; struct Scope; using ScopePtr = std::shared_ptr; +struct TypeFunction; +struct Constraint; +struct Subtyping; +struct TypeChecker2; + /** * There are three kinds of type variables: * - `Free` variables are metavariables, which stand for unconstrained types. @@ -56,32 +63,53 @@ using ScopePtr = std::shared_ptr; * ``` */ -// So... why `const T*` here rather than `T*`? -// It's because we've had problems caused by the type graph being mutated -// in ways it shouldn't be, for example mutating types from other modules. -// To try to control this, we make the use of types immutable by default, -// then provide explicit mutable access via getMutable and asMutable. -// This means we can grep for all the places we're mutating the type graph, -// and it makes it possible to provide other APIs (e.g. the txn log) -// which control mutable access to the type graph. -struct TypePackVar; -using TypePackId = const TypePackVar*; +using Name = std::string; -struct Type; +// A free type is one whose exact shape has yet to be fully determined. +struct FreeType +{ + explicit FreeType(TypeLevel level); + explicit FreeType(Scope* scope); + FreeType(Scope* scope, TypeLevel level); -// Should never be null -using TypeId = const Type*; + FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); -using Name = std::string; + int index; + TypeLevel level; + Scope* scope = nullptr; -// A free type var is one whose exact shape has yet to be fully determined. -using FreeType = Unifiable::Free; + // True if this free type variable is part of a mutually + // recursive type alias whose definitions haven't been + // resolved yet. + bool forwardedTypeAlias = false; -// When a free type var is unified with any other, it is then "bound" -// to that type var, indicating that the two types are actually the same type. -using BoundType = Unifiable::Bound; + // Only used under local type inference + TypeId lowerBound = nullptr; + TypeId upperBound = nullptr; +}; + +struct GenericType +{ + // By default, generics are global, with a synthetic name + GenericType(); -using GenericType = Unifiable::Generic; + explicit GenericType(TypeLevel level); + explicit GenericType(const Name& name); + explicit GenericType(Scope* scope); + + GenericType(TypeLevel level, const Name& name); + GenericType(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; + +// When an equality constraint is found, it is then "bound" to that type, +// indicating that the two types are actually the same type. +using BoundType = Unifiable::Bound; using Tags = std::vector; @@ -102,7 +130,14 @@ struct BlockedType BlockedType(); int index; - static int nextIndex; + Constraint* getOwner() const; + void setOwner(Constraint* newOwner); + void replaceOwner(Constraint* newOwner); + +private: + // The constraint that is intended to unblock this type. Other constraints + // should block on this constraint if present. + Constraint* owner = nullptr; }; struct PrimitiveType @@ -116,6 +151,7 @@ struct PrimitiveType Thread, Function, Table, + Buffer, }; Type type; @@ -133,7 +169,7 @@ struct PrimitiveType } }; -// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Singleton types https://github.com/luau-lang/rfcs/blob/master/docs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton { @@ -206,22 +242,6 @@ const T* get(const SingletonType* stv) return nullptr; } -struct GenericTypeDefinition -{ - TypeId ty; - std::optional defaultValue; - - bool operator==(const GenericTypeDefinition& rhs) const; -}; - -struct GenericTypePackDefinition -{ - TypePackId tp; - std::optional defaultValue; - - bool operator==(const GenericTypePackDefinition& rhs) const; -}; - struct FunctionArgument { Name name; @@ -258,19 +278,19 @@ struct WithPredicate } }; -using MagicFunction = std::function>( - struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; +using MagicFunction = std::function>(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; struct MagicFunctionCallContext { NotNull solver; + NotNull constraint; const class AstExprCall* callSite; TypePackId arguments; TypePackId result; }; -using DcrMagicFunction = bool (*)(MagicFunctionCallContext); - +using DcrMagicFunction = std::function; struct MagicRefinementContext { NotNull scope; @@ -278,27 +298,63 @@ struct MagicRefinementContext std::vector> discriminantTypes; }; -using DcrMagicRefinement = void (*)(const MagicRefinementContext&); +struct MagicFunctionTypeCheckContext +{ + NotNull typechecker; + NotNull builtinTypes; + const class AstExprCall* callSite; + TypePackId arguments; + NotNull checkScope; +}; +using DcrMagicRefinement = void (*)(const MagicRefinementContext&); +using DcrMagicFunctionTypeCheck = std::function; struct FunctionType { // Global monomorphic function FunctionType(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionType(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, - std::optional defn = {}, bool hasSelf = false); + FunctionType( + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); // Local monomorphic function FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); FunctionType( - TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + TypeLevel level, + Scope* scope, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); // Local polymorphic function - FunctionType(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, - std::optional defn = {}, bool hasSelf = false); - FunctionType(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionType( + TypeLevel level, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); + FunctionType( + TypeLevel level, + Scope* scope, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); std::optional definition; /// These should all be generic @@ -313,8 +369,19 @@ struct FunctionType MagicFunction magicFunction = nullptr; DcrMagicFunction dcrMagicFunction = nullptr; DcrMagicRefinement dcrMagicRefinement = nullptr; + + // Callback to allow custom typechecking of builtin function calls whose argument types + // will only be resolved after constraint solving. For example, the arguments to string.format + // have types that can only be decided after parsing the format string and unifying + // with the passed in values, but the correctness of the call can only be decided after + // all the types have been finalized. + DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr; + bool hasSelf; - bool hasNoGenerics = false; + // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. + // this flag is used as an optimization to exit early from procedures that manipulate free or generic types. + bool hasNoFreeOrGenericTypes = false; + bool isCheckedFunction = false; }; enum class TableState @@ -347,12 +414,60 @@ struct TableIndexer struct Property { - TypeId type; + static Property readonly(TypeId ty); + static Property writeonly(TypeId ty); + static Property rw(TypeId ty); // Shared read-write type. + static Property rw(TypeId read, TypeId write); // Separate read-write type. + + // Invariant: at least one of the two optionals are not nullopt! + // If the read type is not nullopt, but the write type is, then the property is readonly. + // If the read type is nullopt, but the write type is not, then the property is writeonly. + // If the read and write types are not nullopt, then the property is read and write. + // Otherwise, an assertion where read and write types are both nullopt will be tripped. + static Property create(std::optional read, std::optional write); + bool deprecated = false; std::string deprecatedSuggestion; + + // If this property was inferred from an expression, this field will be + // populated with the source location of the corresponding table property. std::optional location = std::nullopt; + + // If this property was built from an explicit type annotation, this field + // will be populated with the source location of that table property. + std::optional typeLocation = std::nullopt; + Tags tags; std::optional documentationSymbol; + + // DEPRECATED + // TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends. + Property(); + Property( + TypeId readTy, + bool deprecated = false, + const std::string& deprecatedSuggestion = "", + std::optional location = std::nullopt, + const Tags& tags = {}, + const std::optional& documentationSymbol = std::nullopt, + std::optional typeLocation = std::nullopt + ); + + // DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt. + // TODO: Kill once we don't have non-RWP. + TypeId type() const; + void setType(TypeId ty); + + // Sets the write type of this property to the read type. + void makeShared(); + + bool isShared() const; + bool isReadOnly() const; + bool isWriteOnly() const; + bool isReadWrite() const; + + std::optional readTy; + std::optional writeTy; }; struct TableType @@ -390,14 +505,21 @@ struct TableType // Methods of this table that have an untyped self will use the same shared self type. std::optional selfTy; + + // We track the number of as-yet-unadded properties to unsealed tables. + // Some constraints will use this information to decide whether or not they + // are able to dispatch. + size_t remainingProps = 0; }; // Represents a metatable attached to a table type. Somewhat analogous to a bound type. struct MetatableType { - // Always points to a TableType. + // Should always be a TableType. TypeId table; - // Always points to either a TableType or a MetatableType. + // Should almost always either be a TableType or another MetatableType, + // though it is possible for other types (like AnyType and ErrorType) to + // find their way here sometimes. TypeId metatable; std::optional syntheticName; @@ -428,9 +550,41 @@ struct ClassType Tags tags; std::shared_ptr userData; ModuleName definitionModuleName; + std::optional definitionLocation; + std::optional indexer; + + ClassType( + Name name, + Props props, + std::optional parent, + std::optional metatable, + Tags tags, + std::shared_ptr userData, + ModuleName definitionModuleName, + std::optional definitionLocation + ) + : name(name) + , props(props) + , parent(parent) + , metatable(metatable) + , tags(tags) + , userData(userData) + , definitionModuleName(definitionModuleName) + , definitionLocation(definitionLocation) + { + } - ClassType(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, - std::shared_ptr userData, ModuleName definitionModuleName) + ClassType( + Name name, + Props props, + std::optional parent, + std::optional metatable, + Tags tags, + std::shared_ptr userData, + ModuleName definitionModuleName, + std::optional definitionLocation, + std::optional indexer + ) : name(name) , props(props) , parent(parent) @@ -438,44 +592,54 @@ struct ClassType , tags(tags) , userData(userData) , definitionModuleName(definitionModuleName) + , definitionLocation(definitionLocation) + , indexer(indexer) { } }; -struct TypeFun +/** + * An instance of a type function that has not yet been reduced to a more concrete + * type. The constraint solver receives a constraint to reduce each + * TypeFunctionInstanceType to a concrete type. A design detail is important to + * note here: the parameters for this instantiation of the type function are + * contained within this type, so that they can be substituted. + */ +struct TypeFunctionInstanceType { - // These should all be generic - std::vector typeParams; - std::vector typePackParams; - - /** The underlying type. - * - * WARNING! This is not safe to use as a type if typeParams is not empty!! - * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. - */ - TypeId type; + NotNull function; - TypeFun() = default; + std::vector typeArguments; + std::vector packArguments; - explicit TypeFun(TypeId ty) - : type(ty) + std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs + + TypeFunctionInstanceType( + NotNull function, + std::vector typeArguments, + std::vector packArguments, + std::optional userFuncName = std::nullopt + ) + : function(function) + , typeArguments(typeArguments) + , packArguments(packArguments) + , userFuncName(userFuncName) { } - TypeFun(std::vector typeParams, TypeId type) - : typeParams(std::move(typeParams)) - , type(type) + TypeFunctionInstanceType(const TypeFunction& function, std::vector typeArguments) + : function{&function} + , typeArguments(typeArguments) + , packArguments{} { } - TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) - : typeParams(std::move(typeParams)) - , typePackParams(std::move(typePackParams)) - , type(type) + TypeFunctionInstanceType(const TypeFunction& function, std::vector typeArguments, std::vector packArguments) + : function{&function} + , typeArguments(typeArguments) + , packArguments(packArguments) { } - - bool operator==(const TypeFun& rhs) const; }; /** Represents a pending type alias instantiation. @@ -503,6 +667,11 @@ struct AnyType { }; +// A special, trivial type for the refinement system that is always eliminated from intersections. +struct NoRefineType +{ +}; + // `T | U` struct UnionType { @@ -517,7 +686,43 @@ struct IntersectionType struct LazyType { - std::function thunk; + LazyType() = default; + LazyType(std::function unwrap) + : unwrap(unwrap) + { + } + + // std::atomic is sad and requires a manual copy + LazyType(const LazyType& rhs) + : unwrap(rhs.unwrap) + , unwrapped(rhs.unwrapped.load()) + { + } + + LazyType(LazyType&& rhs) noexcept + : unwrap(std::move(rhs.unwrap)) + , unwrapped(rhs.unwrapped.load()) + { + } + + LazyType& operator=(const LazyType& rhs) + { + unwrap = rhs.unwrap; + unwrapped = rhs.unwrapped.load(); + + return *this; + } + + LazyType& operator=(LazyType&& rhs) noexcept + { + unwrap = std::move(rhs.unwrap); + unwrapped = rhs.unwrapped.load(); + + return *this; + } + + std::function unwrap; + std::atomic unwrapped = nullptr; }; struct UnknownType @@ -536,8 +741,27 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant< + TypeId, + FreeType, + GenericType, + PrimitiveType, + SingletonType, + BlockedType, + PendingExpansionType, + FunctionType, + TableType, + MetatableType, + ClassType, + AnyType, + UnionType, + IntersectionType, + LazyType, + UnknownType, + NeverType, + NegationType, + NoRefineType, + TypeFunctionInstanceType>; struct Type final { @@ -582,15 +806,82 @@ struct Type final Type& operator=(const TypeVariant& rhs); Type& operator=(TypeVariant&& rhs); + Type(Type&&) = default; + Type& operator=(Type&&) = default; + + Type clone() const; + +private: + Type(const Type&) = default; Type& operator=(const Type& rhs); }; +struct GenericTypeDefinition +{ + TypeId ty; + std::optional defaultValue; + + bool operator==(const GenericTypeDefinition& rhs) const; +}; + +struct GenericTypePackDefinition +{ + TypePackId tp; + std::optional defaultValue; + + bool operator==(const GenericTypePackDefinition& rhs) const; +}; + +struct TypeFun +{ + // These should all be generic + std::vector typeParams; + std::vector typePackParams; + + /** The underlying type. + * + * WARNING! This is not safe to use as a type if typeParams is not empty!! + * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. + */ + TypeId type; + + TypeFun() = default; + + explicit TypeFun(TypeId ty) + : type(ty) + { + } + + TypeFun(std::vector typeParams, TypeId type) + : typeParams(std::move(typeParams)) + , type(type) + { + } + + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + : typeParams(std::move(typeParams)) + , typePackParams(std::move(typePackParams)) + , type(type) + { + } + + bool operator==(const TypeFun& rhs) const; +}; + using SeenSet = std::set>; bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs); +enum class FollowOption +{ + Normal, + DisableLazyTypeThunks, +}; + // Follow BoundTypes until we get to something real TypeId follow(TypeId t); -TypeId follow(TypeId t, std::function mapper); +TypeId follow(TypeId t, FollowOption followOption); +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)); +TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId)); std::vector flattenIntersection(TypeId ty); @@ -600,8 +891,10 @@ bool isBoolean(TypeId ty); bool isNumber(TypeId ty); bool isString(TypeId ty); bool isThread(TypeId ty); +bool isBuffer(TypeId ty); bool isOptional(TypeId ty); bool isTableIntersection(TypeId ty); +bool isTableUnion(TypeId ty); bool isOverloadedFunction(TypeId ty); // True when string is a subtype of ty @@ -640,23 +933,25 @@ struct BuiltinTypes BuiltinTypes(const BuiltinTypes&) = delete; void operator=(const BuiltinTypes&) = delete; - TypeId errorRecoveryType(TypeId guess); - TypePackId errorRecoveryTypePack(TypePackId guess); - TypeId errorRecoveryType(); - TypePackId errorRecoveryTypePack(); + TypeId errorRecoveryType(TypeId guess) const; + TypePackId errorRecoveryTypePack(TypePackId guess) const; + TypeId errorRecoveryType() const; + TypePackId errorRecoveryTypePack() const; + + friend TypeId makeStringMetatable(NotNull builtinTypes); + friend struct GlobalTypes; private: std::unique_ptr arena; bool debugFreezeArena = false; - TypeId makeStringMetatable(); - public: const TypeId nilType; const TypeId numberType; const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId bufferType; const TypeId functionType; const TypeId classType; const TypeId tableType; @@ -667,13 +962,16 @@ struct BuiltinTypes const TypeId unknownType; const TypeId neverType; const TypeId errorType; + const TypeId noRefineType; const TypeId falsyType; const TypeId truthyType; const TypeId optionalNumberType; const TypeId optionalStringType; + const TypePackId emptyTypePack; const TypePackId anyTypePack; + const TypePackId unknownTypePack; const TypePackId neverTypePack; const TypePackId uninhabitableTypePack; const TypePackId errorTypePack; @@ -694,6 +992,18 @@ bool isSubclass(const ClassType* cls, const ClassType* parent); Type* asMutable(TypeId ty); +template +bool is(T&& tv) +{ + if (!tv) + return false; + + if constexpr (std::is_same_v && !(std::is_same_v || ...)) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); + + return (get(tv) || ...); +} + template const T* get(TypeId tv) { @@ -705,6 +1015,7 @@ const T* get(TypeId tv) return get_if(&tv->ty); } + template T* getMutable(TypeId tv) { @@ -764,7 +1075,7 @@ struct TypeIterator TypeIterator operator++(int) { TypeIterator copy = *this; - ++copy; + ++*this; return copy; } @@ -810,8 +1121,8 @@ struct TypeIterator // (T* t, size_t currentIndex) using SavedIterInfo = std::pair; - std::deque stack; - std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. + VecDeque stack; + DenseHashSet seen{nullptr}; // Only needed to protect the iterator from hanging the thread. void advance() { @@ -838,7 +1149,7 @@ struct TypeIterator { // If we're about to descend into a cyclic type, we should skip over this. // Ideally this should never happen, but alas it does from time to time. :( - if (seen.find(inner) != seen.end()) + if (seen.contains(inner)) advance(); else { @@ -854,9 +1165,15 @@ struct TypeIterator } }; +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope); + using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); +// A tag to mark a type which doesn't derive directly from the root type as overriding the return of `typeof`. +// Any classes which derive from this type will have typeof return this type. +static constexpr char kTypeofRootTag[] = "typeofRoot"; + void attachTag(TypeId ty, const std::string& tagName); void attachTag(Property& prop, const std::string& tagName); @@ -864,6 +1181,19 @@ bool hasTag(TypeId ty, const std::string& tagName); bool hasTag(const Property& prop, const std::string& tagName); bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. +template +bool hasTypeInIntersection(TypeId ty) +{ + TypeId tf = follow(ty); + if (get(tf)) + return true; + for (auto t : flattenIntersection(tf)) + if (get(follow(t))) + return true; + return false; +} + +bool hasPrimitiveTypeInIntersection(TypeId ty, PrimitiveType::Type primTy); /* * Use this to change the kind of a particular type. * @@ -875,4 +1205,7 @@ LUAU_NOINLINE T* emplaceType(Type* ty, Args&&... args) return &ty->ty.emplace(std::forward(args)...); } +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceType(Type* ty, TypeId& tyArg); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 0e69bb4aa..4f8aea879 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -9,12 +9,16 @@ namespace Luau { +struct Module; struct TypeArena { TypedAllocator types; TypedAllocator typePacks; + // Owning module, if any + Module* owningModule = nullptr; + void clear(); template @@ -44,6 +48,11 @@ struct TypeArena { return addTypePack(TypePackVar(std::move(tp))); } + + TypeId addTypeFunction(const TypeFunction& function, std::initializer_list types); + TypeId addTypeFunction(const TypeFunction& function, std::vector typeArguments, std::vector packArguments = {}); + TypePackId addTypePackFunction(const TypePackFunction& function, std::initializer_list types); + TypePackId addTypePackFunction(const TypePackFunction& function, std::vector typeArguments, std::vector packArguments = {}); }; void freeze(TypeArena& arena); diff --git a/Analysis/include/Luau/TypeCheckLimits.h b/Analysis/include/Luau/TypeCheckLimits.h new file mode 100644 index 000000000..9eabe0ff6 --- /dev/null +++ b/Analysis/include/Luau/TypeCheckLimits.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Cancellation.h" +#include "Luau/Error.h" + +#include +#include +#include + +namespace Luau +{ + +class TimeLimitError : public InternalCompilerError +{ +public: + explicit TimeLimitError(const std::string& moduleName) + : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) + { + } +}; + +class UserCancelError : public InternalCompilerError +{ +public: + explicit UserCancelError(const std::string& moduleName) + : InternalCompilerError("Analysis has been cancelled by user", moduleName) + { + } +}; + +struct TypeCheckLimits +{ + std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; + + std::shared_ptr cancellationToken; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 6045aecff..3ede5ca71 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -2,16 +2,222 @@ #pragma once -#include "Luau/Ast.h" -#include "Luau/Module.h" +#include "Luau/Error.h" #include "Luau/NotNull.h" +#include "Luau/Common.h" +#include "Luau/TypeUtils.h" +#include "Luau/Type.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeOrPack.h" +#include "Luau/Normalize.h" +#include "Luau/Subtyping.h" namespace Luau { -struct DcrLogger; struct BuiltinTypes; +struct DcrLogger; +struct TypeCheckLimits; +struct UnifierSharedState; +struct SourceModule; +struct Module; +struct InternalErrorReporter; +struct Scope; +struct PropertyType; +struct PropertyTypes; +struct StackPusher; + +struct Reasonings +{ + // the list of reasons + std::vector reasons; + + // this should be true if _all_ of the reasons have an error suppressing type, and false otherwise. + bool suppressed; + + std::string toString() + { + // DenseHashSet ordering is entirely undefined, so we want to + // sort the reasons here to achieve a stable error + // stringification. + std::sort(reasons.begin(), reasons.end()); + std::string allReasons; + bool first = true; + for (const std::string& reason : reasons) + { + if (first) + first = false; + else + allReasons += "\n\t"; + + allReasons += reason; + } + + return allReasons; + } +}; + + +void check( + NotNull builtinTypes, + NotNull typeFunctionRuntime, + NotNull sharedState, + NotNull limits, + DcrLogger* logger, + const SourceModule& sourceModule, + Module* module +); + +struct TypeChecker2 +{ + NotNull builtinTypes; + NotNull typeFunctionRuntime; + DcrLogger* logger; + const NotNull limits; + const NotNull ice; + const SourceModule* sourceModule; + Module* module; + + TypeContext typeContext = TypeContext::Default; + std::vector> stack; + std::vector functionDeclStack; + + DenseHashSet seenTypeFunctionInstances{nullptr}; + + Normalizer normalizer; + Subtyping _subtyping; + NotNull subtyping; + + TypeChecker2( + NotNull builtinTypes, + NotNull typeFunctionRuntime, + NotNull unifierState, + NotNull limits, + DcrLogger* logger, + const SourceModule* sourceModule, + Module* module + ); + + void visit(AstStatBlock* block); + void reportError(TypeErrorData data, const Location& location); + Reasonings explainReasonings(TypeId subTy, TypeId superTy, Location location, const SubtypingResult& r); + Reasonings explainReasonings(TypePackId subTp, TypePackId superTp, Location location, const SubtypingResult& r); + +private: + static bool allowsNoReturnValues(const TypePackId tp); + static Location getEndLocation(const AstExprFunction* function); + bool isErrorCall(const AstExprCall* call); + bool hasBreak(AstStat* node); + const AstStat* getFallthrough(const AstStat* node); + std::optional pushStack(AstNode* node); + void checkForInternalTypeFunction(TypeId ty, Location location); + TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location); + TypePackId lookupPack(AstExpr* expr); + TypeId lookupType(AstExpr* expr); + TypeId lookupAnnotation(AstType* annotation); + std::optional lookupPackAnnotation(AstTypePack* annotation); + TypeId lookupExpectedType(AstExpr* expr); + TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena); + TypePackId reconstructPack(AstArray exprs, TypeArena& arena); + Scope* findInnermostScope(Location location); + void visit(AstStat* stat); + void visit(AstStatIf* ifStatement); + void visit(AstStatWhile* whileStatement); + void visit(AstStatRepeat* repeatStatement); + void visit(AstStatBreak*); + void visit(AstStatContinue*); + void visit(AstStatReturn* ret); + void visit(AstStatExpr* expr); + void visit(AstStatLocal* local); + void visit(AstStatFor* forStatement); + void visit(AstStatForIn* forInStatement); + std::optional getBindingType(AstExpr* expr); + void reportErrorsFromAssigningToNever(AstExpr* lhs, TypeId rhsType); + void visit(AstStatAssign* assign); + void visit(AstStatCompoundAssign* stat); + void visit(AstStatFunction* stat); + void visit(AstStatLocalFunction* stat); + void visit(const AstTypeList* typeList); + void visit(AstStatTypeAlias* stat); + void visit(AstStatTypeFunction* stat); + void visit(AstTypeList types); + void visit(AstStatDeclareFunction* stat); + void visit(AstStatDeclareGlobal* stat); + void visit(AstStatDeclareClass* stat); + void visit(AstStatError* stat); + void visit(AstExpr* expr, ValueContext context); + void visit(AstExprGroup* expr, ValueContext context); + void visit(AstExprConstantNil* expr); + void visit(AstExprConstantBool* expr); + void visit(AstExprConstantNumber* expr); + void visit(AstExprConstantString* expr); + void visit(AstExprLocal* expr); + void visit(AstExprGlobal* expr); + void visit(AstExprVarargs* expr); + void visitCall(AstExprCall* call); + void visit(AstExprCall* call); + std::optional tryStripUnionFromNil(TypeId ty); + TypeId stripFromNilAndReport(TypeId ty, const Location& location); + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy); + void visit(AstExprIndexName* indexName, ValueContext context); + void indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const MetatableType* metaTable, TypeId exprType, TypeId indexType); + void visit(AstExprIndexExpr* indexExpr, ValueContext context); + void visit(AstExprFunction* fn); + void visit(AstExprTable* expr); + void visit(AstExprUnary* expr); + TypeId visit(AstExprBinary* expr, AstNode* overrideKey = nullptr); + void visit(AstExprTypeAssertion* expr); + void visit(AstExprIfElse* expr); + void visit(AstExprInterpString* interpString); + void visit(AstExprError* expr); + TypeId flattenPack(TypePackId pack); + void visitGenerics(AstArray generics, AstArray genericPacks); + void visit(AstType* ty); + void visit(AstTypeReference* ty); + void visit(AstTypeTable* table); + void visit(AstTypeFunction* ty); + void visit(AstTypeTypeof* ty); + void visit(AstTypeUnion* ty); + void visit(AstTypeIntersection* ty); + void visit(AstTypePack* pack); + void visit(AstTypePackExplicit* tp); + void visit(AstTypePackVariadic* tp); + void visit(AstTypePackGeneric* tp); + + template + Reasonings explainReasonings_(TID subTy, TID superTy, Location location, const SubtypingResult& r); + + void explainError(TypeId subTy, TypeId superTy, Location location, const SubtypingResult& result); + void explainError(TypePackId subTy, TypePackId superTy, Location location, const SubtypingResult& result); + bool testIsSubtype(TypeId subTy, TypeId superTy, Location location); + bool testIsSubtype(TypePackId subTy, TypePackId superTy, Location location); + void reportError(TypeError e); + void reportErrors(ErrorVec errors); + PropertyTypes lookupProp( + const NormalizedType* norm, + const std::string& prop, + ValueContext context, + const Location& location, + TypeId astIndexExprType, + std::vector& errors + ); + // If the provided type does not have the named property, report an error. + void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, ValueContext context, const Location& location, TypeId astIndexExprType); + PropertyType hasIndexTypeFromType( + TypeId ty, + const std::string& prop, + ValueContext context, + const Location& location, + DenseHashSet& seen, + TypeId astIndexExprType, + std::vector& errors + ); -void check(NotNull builtinTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module); + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const; + bool isErrorSuppressing(Location loc, TypeId ty); + bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2); + bool isErrorSuppressing(Location loc, TypePackId tp); + bool isErrorSuppressing(Location loc1, TypePackId tp1, Location loc2, TypePackId tp2); +}; } // namespace Luau diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h new file mode 100644 index 000000000..df696b62f --- /dev/null +++ b/Analysis/include/Luau/TypeFunction.h @@ -0,0 +1,224 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Constraint.h" +#include "Luau/Error.h" +#include "Luau/NotNull.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypeFwd.h" + +#include +#include +#include + +struct lua_State; + +namespace Luau +{ + +struct TypeArena; +struct TxnLog; +struct ConstraintSolver; +class Normalizer; + +using StateRef = std::unique_ptr; + +struct TypeFunctionRuntime +{ + TypeFunctionRuntime(NotNull ice, NotNull limits); + ~TypeFunctionRuntime(); + + // Return value is an error message if registration failed + std::optional registerFunction(AstStatTypeFunction* function); + + // For user-defined type functions, we store all generated types and packs for the duration of the typecheck + TypedAllocator typeArena; + TypedAllocator typePackArena; + + NotNull ice; + NotNull limits; + + StateRef state; + + // Evaluation of type functions should only be performed in the absence of parse errors in the source module + bool allowEvaluation = true; + +private: + void prepareState(); +}; + +struct TypeFunctionContext +{ + NotNull arena; + NotNull builtins; + NotNull scope; + NotNull normalizer; + NotNull typeFunctionRuntime; + NotNull ice; + NotNull limits; + + // nullptr if the type function is being reduced outside of the constraint solver. + ConstraintSolver* solver; + // The constraint being reduced in this run of the reduction + const Constraint* constraint; + + std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs + + TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint); + + TypeFunctionContext( + NotNull arena, + NotNull builtins, + NotNull scope, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull ice, + NotNull limits + ) + : arena(arena) + , builtins(builtins) + , scope(scope) + , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) + , ice(ice) + , limits(limits) + , solver(nullptr) + , constraint(nullptr) + { + } + + NotNull pushConstraint(ConstraintV&& c) const; +}; + +/// Represents a reduction result, which may have successfully reduced the type, +/// may have concretely failed to reduce the type, or may simply be stuck +/// without more information. +template +struct TypeFunctionReductionResult +{ + /// The result of the reduction, if any. If this is nullopt, the type function + /// could not be reduced. + std::optional result; + /// Whether the result is uninhabited: whether we know, unambiguously and + /// permanently, whether this type function reduction results in an + /// uninhabitable type. This will trigger an error to be reported. + bool uninhabited; + /// Any types that need to be progressed or mutated before the reduction may + /// proceed. + std::vector blockedTypes; + /// Any type packs that need to be progressed or mutated before the + /// reduction may proceed. + std::vector blockedPacks; + /// A runtime error message from user-defined type functions + std::optional error; +}; + +template +using ReducerFunction = + std::function(T, const std::vector&, const std::vector&, NotNull)>; + +/// Represents a type function that may be applied to map a series of types and +/// type packs to a single output type. +struct TypeFunction +{ + /// The human-readable name of the type function. Used to stringify instance + /// types. + std::string name; + + /// The reducer function for the type function. + ReducerFunction reducer; +}; + +/// Represents a type function that may be applied to map a series of types and +/// type packs to a single output type pack. +struct TypePackFunction +{ + /// The human-readable name of the type pack function. Used to stringify + /// instance packs. + std::string name; + + /// The reducer function for the type pack function. + ReducerFunction reducer; +}; + +struct FunctionGraphReductionResult +{ + ErrorVec errors; + DenseHashSet blockedTypes{nullptr}; + DenseHashSet blockedPacks{nullptr}; + DenseHashSet reducedTypes{nullptr}; + DenseHashSet reducedPacks{nullptr}; +}; + +/** + * Attempt to reduce all instances of any type or type pack functions in the type + * graph provided. + * + * @param entrypoint the entry point to the type graph. + * @param location the location the reduction is occurring at; used to populate + * type errors. + * @param arena an arena to allocate types into. + * @param builtins the built-in types. + * @param normalizer the normalizer to use when normalizing types + * @param ice the internal error reporter to use for ICEs + */ +FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location location, TypeFunctionContext, bool force = false); + +/** + * Attempt to reduce all instances of any type or type pack functions in the type + * graph provided. + * + * @param entrypoint the entry point to the type graph. + * @param location the location the reduction is occurring at; used to populate + * type errors. + * @param arena an arena to allocate types into. + * @param builtins the built-in types. + * @param normalizer the normalizer to use when normalizing types + * @param ice the internal error reporter to use for ICEs + */ +FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location location, TypeFunctionContext, bool force = false); + +struct BuiltinTypeFunctions +{ + BuiltinTypeFunctions(); + + TypeFunction userFunc; + + TypeFunction notFunc; + TypeFunction lenFunc; + TypeFunction unmFunc; + + TypeFunction addFunc; + TypeFunction subFunc; + TypeFunction mulFunc; + TypeFunction divFunc; + TypeFunction idivFunc; + TypeFunction powFunc; + TypeFunction modFunc; + + TypeFunction concatFunc; + + TypeFunction andFunc; + TypeFunction orFunc; + + TypeFunction ltFunc; + TypeFunction leFunc; + TypeFunction eqFunc; + + TypeFunction refineFunc; + TypeFunction singletonFunc; + TypeFunction unionFunc; + TypeFunction intersectFunc; + + TypeFunction keyofFunc; + TypeFunction rawkeyofFunc; + TypeFunction indexFunc; + TypeFunction rawgetFunc; + + void addToScope(NotNull arena, NotNull scope) const; +}; + +const BuiltinTypeFunctions& builtinTypeFunctions(); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeFunctionReductionGuesser.h b/Analysis/include/Luau/TypeFunctionReductionGuesser.h new file mode 100644 index 000000000..b6d4a74c8 --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionReductionGuesser.h @@ -0,0 +1,85 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/VecDeque.h" +#include "Luau/DenseHash.h" +#include "Luau/TypeFunction.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/Normalize.h" +#include "Luau/TypeFwd.h" +#include "Luau/VisitType.h" +#include "Luau/NotNull.h" +#include "TypeArena.h" + +namespace Luau +{ + +struct TypeFunctionReductionGuessResult +{ + std::vector> guessedFunctionAnnotations; + TypeId guessedReturnType; + bool shouldRecommendAnnotation = true; +}; + +// An Inference result for a type function is a list of types corresponding to the guessed argument types, followed by a type for the result +struct TypeFunctionInferenceResult +{ + std::vector operandInference; + TypeId functionResultInference; +}; + +struct TypeFunctionReductionGuesser +{ + // Tracks our hypothesis about what a type function reduces to + DenseHashMap functionReducesTo{nullptr}; + // Tracks our constraints on type function operands + DenseHashMap substitutable{nullptr}; + // List of instances to try progress + VecDeque toInfer; + DenseHashSet cyclicInstances{nullptr}; + + // Utilities + NotNull arena; + NotNull builtins; + NotNull normalizer; + + TypeFunctionReductionGuesser(NotNull arena, NotNull builtins, NotNull normalizer); + + std::optional guess(TypeId typ); + std::optional guess(TypePackId typ); + TypeFunctionReductionGuessResult guessTypeFunctionReductionForFunctionExpr(const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy); + +private: + std::optional guessType(TypeId arg); + void dumpGuesses(); + + bool isNumericBinopFunction(const TypeFunctionInstanceType& instance); + bool isComparisonFunction(const TypeFunctionInstanceType& instance); + bool isOrAndFunction(const TypeFunctionInstanceType& instance); + bool isNotFunction(const TypeFunctionInstanceType& instance); + bool isLenFunction(const TypeFunctionInstanceType& instance); + bool isUnaryMinus(const TypeFunctionInstanceType& instance); + + // Operand is assignable if it looks like a cyclic type function instance, or a generic type + bool operandIsAssignable(TypeId ty); + std::optional tryAssignOperandType(TypeId ty); + + std::shared_ptr normalize(TypeId ty); + void step(); + void infer(); + bool done(); + + bool isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet& instanceArgs); + void inferTypeFunctionSubstitutions(TypeId ty, const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferNumericBinopFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferComparisonFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferOrAndFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferNotFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferLenFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferUnaryMinusFunction(const TypeFunctionInstanceType* instance); +}; +} // namespace Luau diff --git a/Analysis/include/Luau/TypeFunctionRuntime.h b/Analysis/include/Luau/TypeFunctionRuntime.h new file mode 100644 index 000000000..44eef1360 --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntime.h @@ -0,0 +1,268 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Variant.h" + +#include +#include +#include +#include + +using lua_State = struct lua_State; + +namespace Luau +{ + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize); + +// Replica of types from Type.h +struct TypeFunctionType; +using TypeFunctionTypeId = const TypeFunctionType*; + +struct TypeFunctionTypePackVar; +using TypeFunctionTypePackId = const TypeFunctionTypePackVar*; + +struct TypeFunctionPrimitiveType +{ + enum Type + { + NilType, + Boolean, + Number, + String, + }; + + Type type; + + TypeFunctionPrimitiveType(Type type) + : type(type) + { + } +}; + +struct TypeFunctionBooleanSingleton +{ + bool value = false; +}; + +struct TypeFunctionStringSingleton +{ + std::string value; +}; + +using TypeFunctionSingletonVariant = Variant; + +struct TypeFunctionSingletonType +{ + TypeFunctionSingletonVariant variant; + + explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant) + : variant(std::move(variant)) + { + } +}; + +template +const T* get(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->variant) : nullptr; +} + +template +T* getMutable(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->variant) : nullptr; +} + +struct TypeFunctionUnionType +{ + std::vector components; +}; + +struct TypeFunctionIntersectionType +{ + std::vector components; +}; + +struct TypeFunctionAnyType +{ +}; + +struct TypeFunctionUnknownType +{ +}; + +struct TypeFunctionNeverType +{ +}; + +struct TypeFunctionNegationType +{ + TypeFunctionTypeId type; +}; + +struct TypeFunctionTypePack +{ + std::vector head; + std::optional tail; +}; + +struct TypeFunctionVariadicTypePack +{ + TypeFunctionTypeId type; +}; + +using TypeFunctionTypePackVariant = Variant; + +struct TypeFunctionTypePackVar +{ + TypeFunctionTypePackVariant type; + + TypeFunctionTypePackVar(TypeFunctionTypePackVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionTypePackVar& rhs) const; +}; + +struct TypeFunctionFunctionType +{ + TypeFunctionTypePackId argTypes; + TypeFunctionTypePackId retTypes; +}; + +template +const T* get(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->type) : nullptr; +} + +struct TypeFunctionTableIndexer +{ + TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType) + : keyType(keyType) + , valueType(valueType) + { + } + + TypeFunctionTypeId keyType; + TypeFunctionTypeId valueType; +}; + +struct TypeFunctionProperty +{ + static TypeFunctionProperty readonly(TypeFunctionTypeId ty); + static TypeFunctionProperty writeonly(TypeFunctionTypeId ty); + static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type. + static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type. + + bool isReadOnly() const; + bool isWriteOnly() const; + + std::optional readTy; + std::optional writeTy; +}; + +struct TypeFunctionTableType +{ + using Name = std::string; + using Props = std::map; + + Props props; + + std::optional indexer; + + // Should always be a TypeFunctionTableType + std::optional metatable; +}; + +struct TypeFunctionClassType +{ + using Name = std::string; + using Props = std::map; + + Props props; + + std::optional indexer; + + std::optional metatable; // metaclass? + + std::optional parent; + + std::string name; +}; + +using TypeFunctionTypeVariant = Luau::Variant< + TypeFunctionPrimitiveType, + TypeFunctionAnyType, + TypeFunctionUnknownType, + TypeFunctionNeverType, + TypeFunctionSingletonType, + TypeFunctionUnionType, + TypeFunctionIntersectionType, + TypeFunctionNegationType, + TypeFunctionFunctionType, + TypeFunctionTableType, + TypeFunctionClassType>; + +struct TypeFunctionType +{ + TypeFunctionTypeVariant type; + + TypeFunctionType(TypeFunctionTypeVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionType& rhs) const; +}; + +template +const T* get(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&const_cast(tv)->type) : nullptr; +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult); + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type); +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type); + +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type); + +bool isTypeUserData(lua_State* L, int idx); +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx); +std::optional optionalTypeUserData(lua_State* L, int idx); + +void registerTypesLibrary(lua_State* L); +void registerTypeUserData(lua_State* L); + +void setTypeFunctionEnvironment(lua_State* L); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h new file mode 100644 index 000000000..c9e1152f9 --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFunctionRuntime.h" + +namespace Luau +{ + +using Kind = Variant; + +template +const T* get(const Kind& kind) +{ + return get_if(&kind); +} + +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& tfkind) +{ + return get_if(&tfkind); +} + +struct TypeFunctionRuntimeBuilderState +{ + NotNull ctx; + + // Mapping of class name to ClassType + // Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function + // Using this invariant, whenever a ClassType is serialized, we can put it into this map + // whenever a ClassType is deserialized, we can use this map to return the corresponding value + DenseHashMap classesSerialized{{}}; + + // List of errors that occur during serialization/deserialization + // At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process + std::vector errors{}; + + TypeFunctionRuntimeBuilderState(NotNull ctx) + : ctx(ctx) + , classesSerialized({}) + , errors({}) + { + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state); +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeFwd.h b/Analysis/include/Luau/TypeFwd.h new file mode 100644 index 000000000..42d582fea --- /dev/null +++ b/Analysis/include/Luau/TypeFwd.h @@ -0,0 +1,59 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +// So... why `const T*` here rather than `T*`? +// It's because we've had problems caused by the type graph being mutated +// in ways it shouldn't be, for example mutating types from other modules. +// To try to control this, we make the use of types immutable by default, +// then provide explicit mutable access via getMutable and asMutable. +// This means we can grep for all the places we're mutating the type graph, +// and it makes it possible to provide other APIs (e.g. the txn log) +// which control mutable access to the type graph. + +struct Type; +using TypeId = const Type*; + +struct FreeType; +struct GenericType; +struct PrimitiveType; +struct BlockedType; +struct PendingExpansionType; +struct SingletonType; +struct FunctionType; +struct TableType; +struct MetatableType; +struct ClassType; +struct AnyType; +struct UnionType; +struct IntersectionType; +struct LazyType; +struct UnknownType; +struct NeverType; +struct NegationType; +struct TypeFunctionInstanceType; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct FreeTypePack; +struct GenericTypePack; +struct TypePack; +struct VariadicTypePack; +struct BlockedTypePack; +struct TypeFunctionInstanceTypePack; + +using Name = std::string; +using ModuleName = std::string; + +struct BuiltinTypes; + +using TypeOrPack = Variant; + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 678bd419d..7f2e29b5a 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -2,14 +2,17 @@ #pragma once #include "Luau/Anyification.h" -#include "Luau/Predicate.h" +#include "Luau/ControlFlow.h" #include "Luau/Error.h" +#include "Luau/Instantiation.h" #include "Luau/Module.h" -#include "Luau/Symbol.h" +#include "Luau/Predicate.h" #include "Luau/Substitution.h" +#include "Luau/Symbol.h" #include "Luau/TxnLog.h" -#include "Luau/TypePack.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/UnifierSharedState.h" @@ -17,18 +20,24 @@ #include #include -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { struct Scope; struct TypeChecker; struct ModuleResolver; +struct FrontendCancellationToken; using Name = std::string; using ScopePtr = std::shared_ptr; -using OverloadErrorEntry = std::tuple, std::vector, const FunctionType*>; + +struct OverloadErrorEntry +{ + TxnLog log; + ErrorVec errors; + std::vector arguments; + const FunctionType* fnTy; +}; bool doesCallError(const AstExprCall* call); bool hasBreak(AstStat* node); @@ -48,26 +57,16 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -class TimeLimitError : public InternalCompilerError -{ -public: - explicit TimeLimitError(const std::string& moduleName) - : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) - { - } -}; - -enum class ValueContext -{ - LValue, - RValue -}; - // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker { - explicit TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); + explicit TypeChecker( + const ScopePtr& globalScope, + ModuleResolver* resolver, + NotNull builtinTypes, + InternalErrorReporter* iceHandler + ); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -76,32 +75,37 @@ struct TypeChecker std::vector> getScopes() const; - void check(const ScopePtr& scope, const AstStat& statement); - void check(const ScopePtr& scope, const AstStatBlock& statement); - void check(const ScopePtr& scope, const AstStatIf& statement); - void check(const ScopePtr& scope, const AstStatWhile& statement); - void check(const ScopePtr& scope, const AstStatRepeat& statement); - void check(const ScopePtr& scope, const AstStatReturn& return_); - void check(const ScopePtr& scope, const AstStatAssign& assign); - void check(const ScopePtr& scope, const AstStatCompoundAssign& assign); - void check(const ScopePtr& scope, const AstStatLocal& local); - void check(const ScopePtr& scope, const AstStatFor& local); - void check(const ScopePtr& scope, const AstStatForIn& forin); - void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); - void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias); - void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); - void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); + ControlFlow check(const ScopePtr& scope, const AstStat& statement); + ControlFlow check(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow check(const ScopePtr& scope, const AstStatIf& statement); + ControlFlow check(const ScopePtr& scope, const AstStatWhile& statement); + ControlFlow check(const ScopePtr& scope, const AstStatRepeat& statement); + ControlFlow check(const ScopePtr& scope, const AstStatReturn& return_); + ControlFlow check(const ScopePtr& scope, const AstStatAssign& assign); + ControlFlow check(const ScopePtr& scope, const AstStatCompoundAssign& assign); + ControlFlow check(const ScopePtr& scope, const AstStatLocal& local); + ControlFlow check(const ScopePtr& scope, const AstStatFor& local); + ControlFlow check(const ScopePtr& scope, const AstStatForIn& forin); + ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); + ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); + ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias); + ControlFlow check(const ScopePtr& scope, const AstStatTypeFunction& typefunction); + ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); + ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); - void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); - void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); WithPredicate checkExpr( - const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); + const ScopePtr& scope, + const AstExpr& expr, + std::optional expectedType = std::nullopt, + bool forceSingleton = false + ); WithPredicate checkExpr(const ScopePtr& scope, const AstExprLocal& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); @@ -112,17 +116,31 @@ struct TypeChecker WithPredicate checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates = {} + ); TypeId checkBinaryOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates = {} + ); WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprInterpString& expr); - TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, - std::optional expectedType); + TypeId checkExprTable( + const ScopePtr& scope, + const AstExprTable& expr, + const std::vector>& fieldTypes, + std::optional expectedType + ); // Returns the type of the lvalue. TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx); @@ -135,34 +153,79 @@ struct TypeChecker TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); - std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalNameLoc, std::optional selfType, std::optional expectedType); + std::pair checkFunctionSignature( + const ScopePtr& scope, + int subLevel, + const AstExprFunction& expr, + std::optional originalNameLoc, + std::optional selfType, + std::optional expectedType + ); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); - void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack, - const std::vector& argLocations); + void checkArgumentList( + const ScopePtr& scope, + const AstExpr& funName, + Unifier& state, + TypePackId paramPack, + TypePackId argPack, + const std::vector& argLocations + ); WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); WithPredicate checkExprPackHelper2( - const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack); + const ScopePtr& scope, + const AstExprCall& expr, + TypeId selfType, + TypeId actualFunctionType, + TypeId functionType, + TypePackId retPack + ); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::unique_ptr> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); - bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, - const std::vector& errors); - void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, - const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, - const std::vector& errors); - - WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, - bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, - const std::vector>& expectedTypes = {}); + std::unique_ptr> checkCallOverload( + const ScopePtr& scope, + const AstExprCall& expr, + TypeId fn, + TypePackId retPack, + TypePackId argPack, + TypePack* args, + const std::vector* argLocations, + const WithPredicate& argListResult, + std::vector& overloadsThatMatchArgCount, + std::vector& overloadsThatDont, + std::vector& errors + ); + bool handleSelfCallMismatch( + const ScopePtr& scope, + const AstExprCall& expr, + TypePack* args, + const std::vector& argLocations, + const std::vector& errors + ); + void reportOverloadResolutionError( + const ScopePtr& scope, + const AstExprCall& expr, + TypePackId retPack, + TypePackId argPack, + const std::vector& argLocations, + const std::vector& overloads, + const std::vector& overloadsThatMatchArgCount, + std::vector& errors + ); + + WithPredicate checkExprList( + const ScopePtr& scope, + const Location& location, + const AstArray& exprs, + bool substituteFreeForNil = false, + const std::vector& lhsAnnotations = {}, + const std::vector>& expectedTypes = {} + ); static std::optional matchRequire(const AstExprCall& call); TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location); @@ -180,8 +243,13 @@ struct TypeChecker */ bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options); - bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, - CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify( + TypePackId subTy, + TypePackId superTy, + const ScopePtr& scope, + const Location& location, + CountMismatch::Context ctx = CountMismatch::Context::Arg + ); /** Attempt to unify the types. * If this fails, and the subTy type can be instantiated, do so and try unification again. @@ -249,6 +317,7 @@ struct TypeChecker [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); [[noreturn]] void throwTimeLimitError(); + [[noreturn]] void throwUserCancelError(); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childScope(const ScopePtr& parent, const Location& location); @@ -317,12 +386,23 @@ struct TypeChecker TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); - TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, - const std::vector& typePackParams, const Location& location); + TypeId instantiateTypeFun( + const ScopePtr& scope, + const TypeFun& tf, + const std::vector& typeParams, + const std::vector& typePackParams, + const Location& location + ); // Note: `scope` must be a fresh scope. - GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); + GenericTypeDefinitions createGenericTypes( + const ScopePtr& scope, + std::optional levelOpt, + const AstNode& node, + const AstArray& genericNames, + const AstArray& genericPackNames, + bool useCache = false + ); public: void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -355,13 +435,10 @@ struct TypeChecker */ std::vector unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location); - TypeArena globalTypes; + const ScopePtr& globalScope; ModuleResolver* resolver; - SourceModule globalNames; // names for symbols entered into globalScope - ScopePtr globalScope; // shared by all modules ModulePtr currentModule; - ModuleName currentModuleName; std::function prepareModuleScope; NotNull builtinTypes; @@ -370,6 +447,8 @@ struct TypeChecker UnifierSharedState unifierState; Normalizer normalizer; + Instantiation reusableInstantiation; + std::vector requireCycles; // Type inference limits @@ -377,12 +456,15 @@ struct TypeChecker std::optional instantiationChildLimit; std::optional unifierIterationLimit; + std::shared_ptr cancellationToken; + public: const TypeId nilType; const TypeId numberType; const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId bufferType; const TypeId anyType; const TypeId unknownType; const TypeId neverType; diff --git a/Analysis/include/Luau/TypeOrPack.h b/Analysis/include/Luau/TypeOrPack.h new file mode 100644 index 000000000..870019109 --- /dev/null +++ b/Analysis/include/Luau/TypeOrPack.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +const void* ptr(TypeOrPack ty); + +template, bool> = true> +const T* get(const TypeOrPack& tyOrTp) +{ + return tyOrTp.get_if(); +} + +template, bool> = true> +const T* get(const TypeOrPack& tyOrTp) +{ + if (const TypeId* ty = get(tyOrTp)) + return get(*ty); + else + return nullptr; +} + +template, bool> = true> +const T* get(const TypeOrPack& tyOrTp) +{ + if (const TypePackId* tp = get(tyOrTp)) + return get(*tp); + else + return nullptr; +} + +TypeOrPack follow(TypeOrPack ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 4831f2338..1065b9475 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -1,31 +1,61 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Type.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" +#include "Luau/NotNull.h" +#include "Luau/Common.h" #include #include +#include namespace Luau { struct TypeArena; +struct TypePackFunction; +struct TxnLog; struct TypePack; struct VariadicTypePack; struct BlockedTypePack; +struct TypeFunctionInstanceTypePack; -struct TypePackVar; +struct FreeTypePack +{ + explicit FreeTypePack(TypeLevel level); + explicit FreeTypePack(Scope* scope); + FreeTypePack(Scope* scope, TypeLevel level); -struct TxnLog; + int index; + TypeLevel level; + Scope* scope = nullptr; +}; + +struct GenericTypePack +{ + // By default, generics are global, with a synthetic name + GenericTypePack(); + explicit GenericTypePack(TypeLevel level); + explicit GenericTypePack(const Name& name); + explicit GenericTypePack(Scope* scope); + GenericTypePack(TypeLevel level, const Name& name); + GenericTypePack(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; -using TypePackId = const TypePackVar*; -using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; -using GenericTypePack = Unifiable::Generic; -using TypePackVariant = Unifiable::Variant; +using ErrorTypePack = Unifiable::Error; + +using TypePackVariant = + Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more @@ -52,9 +82,22 @@ struct BlockedTypePack BlockedTypePack(); size_t index; + struct Constraint* owner = nullptr; + static size_t nextIndex; }; +/** + * Analogous to a TypeFunctionInstanceType. + */ +struct TypeFunctionInstanceTypePack +{ + NotNull function; + + std::vector typeArguments; + std::vector packArguments; +}; + struct TypePackVar { explicit TypePackVar(const TypePackVariant& ty); @@ -141,7 +184,7 @@ using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -TypePackId follow(TypePackId tp, std::function mapper); +TypePackId follow(TypePackId t, const void* context, TypePackId (*mapper)(const void*, TypePackId)); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); @@ -190,4 +233,18 @@ bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadic bool containsNever(TypePackId tp); +/* + * Use this to change the kind of a particular type pack. + * + * LUAU_NOINLINE so that the calling frame doesn't have to pay the stack storage for the new variant. + */ +template +LUAU_NOINLINE T* emplaceTypePack(TypePackVar* ty, Args&&... args) +{ + return &ty->ty.emplace(std::forward(args)...); +} + +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceTypePack(TypePackVar* ty, TypePackId& tyArg); + } // namespace Luau diff --git a/Analysis/include/Luau/TypePairHash.h b/Analysis/include/Luau/TypePairHash.h new file mode 100644 index 000000000..591f20f11 --- /dev/null +++ b/Analysis/include/Luau/TypePairHash.h @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeFwd.h" + +#include +#include + +namespace Luau +{ + +struct TypePairHash +{ + size_t hashOne(TypeId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t hashOne(TypePackId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/TypePath.h b/Analysis/include/Luau/TypePath.h new file mode 100644 index 000000000..2af5185d6 --- /dev/null +++ b/Analysis/include/Luau/TypePath.h @@ -0,0 +1,239 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeFwd.h" +#include "Luau/Variant.h" +#include "Luau/NotNull.h" + +#include +#include +#include + +namespace Luau +{ + +namespace TypePath +{ + +/// Represents a property of a class, table, or anything else with a concept of +/// a named property. +struct Property +{ + /// The name of the property. + std::string name; + /// Whether to look at the read or the write type. + bool isRead = true; + + explicit Property(std::string name); + Property(std::string name, bool read) + : name(std::move(name)) + , isRead(read) + { + } + + static Property read(std::string name); + static Property write(std::string name); + + bool operator==(const Property& other) const; +}; + +/// Represents an index into a type or a pack. For a type, this indexes into a +/// union or intersection's list. For a pack, this indexes into the pack's nth +/// element. +struct Index +{ + /// The 0-based index to use for the lookup. + size_t index; + + bool operator==(const Index& other) const; +}; + +/// Represents fields of a type or pack that contain a type. +enum class TypeField +{ + /// The table of a metatable type. + Table, + /// The metatable of a type. This could be a metatable type, a primitive + /// type, a class type, or perhaps even a string singleton type. + Metatable, + /// The lower bound of this type, if one is present. + LowerBound, + /// The upper bound of this type, if present. + UpperBound, + /// The index type. + IndexLookup, + /// The indexer result type. + IndexResult, + /// The negated type, for negations. + Negated, + /// The variadic type for a type pack. + Variadic, +}; + +/// Represents fields of a type or type pack that contain a type pack. +enum class PackField +{ + /// What arguments this type accepts. + Arguments, + /// What this type returns when called. + Returns, + /// The tail of a type pack. + Tail, +}; + +/// Component that represents the result of a reduction +/// `resultType` is `never` if the reduction could not proceed +struct Reduction +{ + TypeId resultType; + + bool operator==(const Reduction& other) const; +}; + +/// A single component of a path, representing one inner type or type pack to +/// traverse into. +using Component = Luau::Variant; + +/// A path through a type or type pack accessing a particular type or type pack +/// contained within. +/// +/// Paths are always relative; to make use of a Path, you need to specify an +/// entry point. They are not canonicalized; two Paths may not compare equal but +/// may point to the same result, depending on the layout of the entry point. +/// +/// Paths always descend through an entry point. This doesn't mean that they +/// cannot reach "upwards" in the actual type hierarchy in some cases, but it +/// does mean that there is no equivalent to `../` in file system paths. This is +/// intentional and unavoidable, because types and type packs don't have a +/// concept of a parent - they are a directed cyclic graph, with no hierarchy +/// that actually holds in all cases. +struct Path +{ + /// The Components of this Path. + std::vector components; + + /// Creates a new empty Path. + Path() {} + + /// Creates a new Path from a list of components. + explicit Path(std::vector components) + : components(std::move(components)) + { + } + + /// Creates a new single-component Path. + explicit Path(Component component) + : components({component}) + { + } + + /// Creates a new Path by appending another Path to this one. + /// @param suffix the Path to append + /// @return a new Path representing `this + suffix` + Path append(const Path& suffix) const; + + /// Creates a new Path by appending a Component to this Path. + /// @param component the Component to append + /// @return a new Path with `component` appended to it. + Path push(Component component) const; + + /// Creates a new Path by prepending a Component to this Path. + /// @param component the Component to prepend + /// @return a new Path with `component` prepended to it. + Path push_front(Component component) const; + + /// Creates a new Path by removing the last Component of this Path. + /// If the Path is empty, this is a no-op. + /// @return a Path with the last component removed. + Path pop() const; + + /// Returns the last Component of this Path, if present. + std::optional last() const; + + /// Returns whether this Path is empty, meaning it has no components at all. + /// Traversing an empty Path results in the type you started with. + bool empty() const; + + bool operator==(const Path& other) const; + bool operator!=(const Path& other) const + { + return !(*this == other); + } +}; + +struct PathHash +{ + size_t operator()(const Property& prop) const; + size_t operator()(const Index& idx) const; + size_t operator()(const TypeField& field) const; + size_t operator()(const PackField& field) const; + size_t operator()(const Reduction& reduction) const; + size_t operator()(const Component& component) const; + size_t operator()(const Path& path) const; +}; + +/// The canonical "empty" Path, meaning a Path with no components. +static const Path kEmpty{}; + +struct PathBuilder +{ + std::vector components; + + Path build(); + + PathBuilder& readProp(std::string name); + PathBuilder& writeProp(std::string name); + PathBuilder& prop(std::string name); + PathBuilder& index(size_t i); + PathBuilder& mt(); + PathBuilder& lb(); + PathBuilder& ub(); + PathBuilder& indexKey(); + PathBuilder& indexValue(); + PathBuilder& negated(); + PathBuilder& variadic(); + PathBuilder& args(); + PathBuilder& rets(); + PathBuilder& tail(); +}; + +} // namespace TypePath + +using Path = TypePath::Path; + +/// Converts a Path to a string for debugging purposes. This output may not be +/// terribly clear to end users of the Luau type system. +std::string toString(const TypePath::Path& path, bool prefixDot = false); + +std::optional traverse(TypeId root, const Path& path, NotNull builtinTypes); +std::optional traverse(TypePackId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type to its end point, which must be a type. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypeId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForType(TypeId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type pack to its end point, which must be a type. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypeId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForType(TypePackId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type to its end point, which must be a type pack. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForPack(TypeId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type pack to its end point, which must be a type pack. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForPack(TypePackId root, const Path& path, NotNull builtinTypes); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h deleted file mode 100644 index 3f64870ab..000000000 --- a/Analysis/include/Luau/TypeReduction.h +++ /dev/null @@ -1,85 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/Type.h" -#include "Luau/TypeArena.h" -#include "Luau/TypePack.h" -#include "Luau/Variant.h" - -namespace Luau -{ - -namespace detail -{ -template -struct ReductionEdge -{ - T type = nullptr; - bool irreducible = false; -}; - -struct TypeReductionMemoization -{ - TypeReductionMemoization() = default; - - TypeReductionMemoization(const TypeReductionMemoization&) = delete; - TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete; - - TypeReductionMemoization(TypeReductionMemoization&&) = default; - TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default; - - DenseHashMap> types{nullptr}; - DenseHashMap> typePacks{nullptr}; - - bool isIrreducible(TypeId ty); - bool isIrreducible(TypePackId tp); - - TypeId memoize(TypeId ty, TypeId reducedTy); - TypePackId memoize(TypePackId tp, TypePackId reducedTp); - - // Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C. - // Because reduction should always be transitive, A should point to C if A points to B and B points to C. - std::optional> memoizedof(TypeId ty) const; - std::optional> memoizedof(TypePackId tp) const; -}; -} // namespace detail - -struct TypeReductionOptions -{ - /// If it's desirable for type reduction to allocate into a different arena than the TypeReduction instance you have, you will need - /// to create a temporary TypeReduction in that case, and set [`TypeReductionOptions::allowTypeReductionsFromOtherArenas`] to true. - /// This is because TypeReduction caches the reduced type. - bool allowTypeReductionsFromOtherArenas = false; -}; - -struct TypeReduction -{ - explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle, - const TypeReductionOptions& opts = {}); - - TypeReduction(const TypeReduction&) = delete; - TypeReduction& operator=(const TypeReduction&) = delete; - - TypeReduction(TypeReduction&&) = default; - TypeReduction& operator=(TypeReduction&&) = default; - - std::optional reduce(TypeId ty); - std::optional reduce(TypePackId tp); - std::optional reduce(const TypeFun& fun); - -private: - NotNull arena; - NotNull builtinTypes; - NotNull handle; - - TypeReductionOptions options; - detail::TypeReductionMemoization memoization; - - // Computes an *estimated length* of the cartesian product of the given type. - size_t cartesianProductSize(TypeId ty) const; - - bool hasExceededCartesianProductLimit(TypeId ty) const; - bool hasExceededCartesianProductLimit(TypePackId tp) const; -}; - -} // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 42ba40522..de9660ef6 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -14,13 +14,79 @@ namespace Luau struct TxnLog; struct TypeArena; +class Normalizer; + +enum class ValueContext +{ + LValue, + RValue +}; + +/// the current context of the type checker +enum class TypeContext +{ + /// the default context + Default, + /// inside of a condition + Condition, +}; + +bool inConditional(const TypeContext& context); + +// sets the given type context to `Condition` and restores it to its original +// value when the struct drops out of scope +struct InConditionalContext +{ + TypeContext* typeContext; + TypeContext oldValue; + + InConditionalContext(TypeContext* c) + : typeContext(c) + , oldValue(*c) + { + *typeContext = TypeContext::Condition; + } + + ~InConditionalContext() + { + *typeContext = oldValue; + } +}; using ScopePtr = std::shared_ptr; +std::optional findTableProperty( + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + Location location +); + std::optional findMetatableEntry( - NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); + NotNull builtinTypes, + ErrorVec& errors, + TypeId type, + const std::string& entry, + Location location +); std::optional findTablePropertyRespectingMeta( - NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + Location location +); +std::optional findTablePropertyRespectingMeta( + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + ValueContext context, + Location location +); + +bool occursCheck(TypeId needle, TypeId haystack); // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); @@ -28,7 +94,12 @@ std::pair> getParameterExtents(const TxnLog* log, // Extend the provided pack to at least `length` types. // Returns a temporary TypePack that contains those types plus a tail. TypePack extendTypePack( - TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides = {}); + TypeArena& arena, + NotNull builtinTypes, + TypePackId pack, + size_t length, + std::vector> overrides = {} +); /** * Reduces a union by decomposing to the any/error type if it appears in the @@ -49,4 +120,164 @@ std::vector reduceUnion(const std::vector& types); */ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty); +struct ErrorSuppression +{ + enum Value + { + Suppress, + DoNotSuppress, + NormalizationFailed, + }; + + ErrorSuppression() = default; + constexpr ErrorSuppression(Value enumValue) + : value(enumValue) + { + } + + constexpr operator Value() const + { + return value; + } + explicit operator bool() const = delete; + + ErrorSuppression orElse(const ErrorSuppression& other) const + { + switch (value) + { + case DoNotSuppress: + return other; + default: + return *this; + } + } + +private: + Value value; +}; + +/** + * Normalizes the given type using the normalizer to determine if the type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param ty the type to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty); + +/** + * Flattens and normalizes the given typepack using the normalizer to determine if the type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param tp the typepack to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp); + +/** + * Normalizes the two given type using the normalizer to determine if either type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param ty1 the first type to check for error suppression + * @param ty2 the second type to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty1, TypeId ty2); + +/** + * Flattens and normalizes the two given typepacks using the normalizer to determine if either type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param tp1 the first typepack to check for error suppression + * @param tp2 the second typepack to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp1, TypePackId tp2); + +// Similar to `std::optional>`, but whose `sizeof()` is the same as `std::pair` +// and cooperates with C++'s `if (auto p = ...)` syntax without the extra fatness of `std::optional`. +template +struct TryPair +{ + A first; + B second; + + explicit operator bool() const + { + return bool(first) && bool(second); + } +}; + +template +TryPair get2(Ty one, Ty two) +{ + static_assert(std::is_pointer_v, "argument must be a pointer type"); + + const A* a = get(one); + const B* b = get(two); + if (a && b) + return {a, b}; + else + return {nullptr, nullptr}; +} + +template +const T* get(std::optional ty) +{ + if (ty) + return get(*ty); + else + return nullptr; +} + +template +T* getMutable(std::optional ty) +{ + if (ty) + return getMutable(*ty); + else + return nullptr; +} + +template +std::optional follow(std::optional ty) +{ + if (ty) + return follow(*ty); + else + return std::nullopt; +} + +/** + * Returns whether or not expr is a literal expression, for example: + * - Scalar literals (numbers, booleans, strings, nil) + * - Table literals + * - Lambdas (a "function literal") + */ +bool isLiteral(const AstExpr* expr); + +/** + * Given a table literal and a mapping from expression to type, determine + * whether any literal expression in this table depends on any blocked types. + * This is used as a precondition for bidirectional inference: be warned that + * the behavior of this algorithm is tightly coupled to that of bidirectional + * inference. + * @param expr Expression to search + * @param astTypes Mapping from AST node to TypeID + * @returns A vector of blocked types + */ +std::vector findBlockedTypesIn(AstExprTable* expr, NotNull> astTypes); + +/** + * Given a function call and a mapping from expression to type, determine + * whether the type of any argument in said call in depends on a blocked types. + * This is used as a precondition for bidirectional inference: be warned that + * the behavior of this algorithm is tightly coupled to that of bidirectional + * inference. + * @param expr Expression to search + * @param astTypes Mapping from AST node to TypeID + * @returns A vector of blocked types + */ +std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull> astTypes); + } // namespace Luau diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 15e501f02..79b3b7dea 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -81,23 +81,7 @@ namespace Luau::Unifiable using Name = std::string; -struct Free -{ - explicit Free(TypeLevel level); - explicit Free(Scope* scope); - explicit Free(Scope* scope, TypeLevel level); - - int index; - TypeLevel level; - Scope* scope = nullptr; - // True if this free type variable is part of a mutually - // recursive type alias whose definitions haven't been - // resolved yet. - bool forwardedTypeAlias = false; - -private: - static int DEPRECATED_nextIndex; -}; +int freshIndex(); template struct Bound @@ -110,26 +94,6 @@ struct Bound Id boundTo; }; -struct Generic -{ - // By default, generics are global, with a synthetic name - Generic(); - explicit Generic(TypeLevel level); - explicit Generic(const Name& name); - explicit Generic(Scope* scope); - Generic(TypeLevel level, const Name& name); - Generic(Scope* scope, const Name& name); - - int index; - TypeLevel level; - Scope* scope = nullptr; - Name name; - bool explicitName = false; - -private: - static int DEPRECATED_nextIndex; -}; - struct Error { // This constructor has to be public, since it's used in Type and TypePack, @@ -143,6 +107,6 @@ struct Error }; template -using Variant = Luau::Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Error, Value...>; } // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 50024e3fd..3de841ed2 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -9,7 +9,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" #include "Luau/UnifierSharedState.h" -#include "Normalize.h" +#include "Luau/Normalize.h" #include @@ -43,6 +43,21 @@ struct Widen : Substitution TypePackId operator()(TypePackId ty); }; +/** + * Normally, when we unify table properties, we must do so invariantly, but we + * can introduce a special exception: If the table property in the subtype + * position arises from a literal expression, it is safe to instead perform a + * covariant check. + * + * This is very useful for typechecking cases where table literals (and trees of + * table literals) are passed directly to functions. + * + * In this case, we know that the property has no other name referring to it and + * so it is perfectly safe for the function to mutate the table any way it + * wishes. + */ +using LiteralProperties = DenseHashSet; + // TODO: Use this more widely. struct UnifierOptions { @@ -54,18 +69,20 @@ struct Unifier TypeArena* const types; NotNull builtinTypes; NotNull normalizer; - Mode mode; NotNull scope; // const Scope maybe TxnLog log; + bool failure = false; ErrorVec errors; Location location; Variance variance = Covariant; bool normalize = true; // Normalize unions and intersections if necessary bool checkInhabited = true; // Normalize types to check if they are inhabited - bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; + // If true, generics act as free types when unifying. + bool hideousFixMeGenericsAreActuallyFree = false; + UnifierSharedState& sharedState; // When the Unifier is forced to unify two blocked types (or packs), they @@ -74,8 +91,11 @@ struct Unifier std::vector blockedTypes; std::vector blockedTypePacks; - Unifier( - NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); + Unifier(NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); + + // Configure the Unifier to test for scope subsumption via embedded Scope + // pointers rather than TypeLevels. + void enableNewSolver(); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -85,20 +105,43 @@ struct Unifier * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify( + TypeId subTy, + TypeId superTy, + bool isFunctionCall = false, + bool isIntersection = false, + const LiteralProperties* aliasableMap = nullptr + ); private: - void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify_( + TypeId subTy, + TypeId superTy, + bool isFunctionCall = false, + bool isIntersection = false, + const LiteralProperties* aliasableMap = nullptr + ); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); + + // Traverse the two types provided and block on any BlockedTypes we find. + // Returns true if any types were blocked on. + bool DEPRECATED_blockOnBlockedTypes(TypeId subTy, TypeId superTy); + void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); - void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, - std::optional error = std::nullopt); + void tryUnifyNormalizedTypes( + TypeId subTy, + TypeId superTy, + const NormalizedType& subNorm, + const NormalizedType& superNorm, + std::string reason, + std::optional error = std::nullopt + ); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); - void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); @@ -131,18 +174,17 @@ struct Unifier public: // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" - bool occursCheck(TypeId needle, TypeId haystack); + bool occursCheck(TypeId needle, TypeId haystack, bool reversed); bool occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); - bool occursCheck(TypePackId needle, TypePackId haystack); + bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed); bool occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); - Unifier makeChildUnifier(); + std::unique_ptr makeChildUnifier(); void reportError(TypeError err); LUAU_NOINLINE void reportError(Location location, TypeErrorData data); private: - bool isNonstrictMode() const; TypeMismatch::Context mismatchContext(); void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); @@ -153,9 +195,15 @@ struct Unifier // Available after regular type pack unification errors std::optional firstPackErrorPos; + + // If true, we do a bunch of small things differently to work better with + // the new type inference engine. Most notably, we use the Scope hierarchy + // directly rather than using TypeLevels. + bool useNewSolver = false; }; void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); std::optional hasUnificationTooComplex(const ErrorVec& errors); +std::optional hasCountMismatch(const ErrorVec& errors); } // namespace Luau diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h new file mode 100644 index 000000000..8734aeec2 --- /dev/null +++ b/Analysis/include/Luau/Unifier2.h @@ -0,0 +1,115 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Constraint.h" +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePairHash.h" + +#include +#include +#include + +namespace Luau +{ + +struct InternalErrorReporter; +struct Scope; +struct TypeArena; + +enum class OccursCheckResult +{ + Pass, + Fail +}; + +struct Unifier2 +{ + NotNull arena; + NotNull builtinTypes; + NotNull scope; + NotNull ice; + TypeCheckLimits limits; + + DenseHashSet, TypePairHash> seenTypePairings{{nullptr, nullptr}}; + DenseHashSet, TypePairHash> seenTypePackPairings{{nullptr, nullptr}}; + + DenseHashMap> expandedFreeTypes{nullptr}; + + // Mapping from generic types to free types to be used in instantiation. + DenseHashMap genericSubstitutions{nullptr}; + // Mapping from generic type packs to `TypePack`s of free types to be used in instantiation. + DenseHashMap genericPackSubstitutions{nullptr}; + + int recursionCount = 0; + int recursionLimit = 0; + + std::vector incompleteSubtypes; + // null if not in a constraint solving context + DenseHashSet* uninhabitedTypeFunctions; + + Unifier2(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull ice); + Unifier2( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull ice, + DenseHashSet* uninhabitedTypeFunctions + ); + + /** Attempt to commit the subtype relation subTy <: superTy to the type + * graph. + * + * @returns true if successful. + * + * Note that incoherent types can and will successfully be unified. We stop + * when we *cannot know* how to relate the provided types, not when doing so + * would narrow something down to never or broaden it to unknown. + * + * Presently, the only way unification can fail is if we attempt to bind one + * free TypePack to another and encounter an occurs check violation. + */ + bool unify(TypeId subTy, TypeId superTy); + bool unifyFreeWithType(TypeId subTy, TypeId superTy); + bool unify(TypeId subTy, const FunctionType* superFn); + bool unify(const UnionType* subUnion, TypeId superTy); + bool unify(TypeId subTy, const UnionType* superUnion); + bool unify(const IntersectionType* subIntersection, TypeId superTy); + bool unify(TypeId subTy, const IntersectionType* superIntersection); + bool unify(TableType* subTable, const TableType* superTable); + bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable); + + bool unify(const AnyType* subAny, const FunctionType* superFn); + bool unify(const FunctionType* subFn, const AnyType* superAny); + bool unify(const AnyType* subAny, const TableType* superTable); + bool unify(const TableType* subTable, const AnyType* superAny); + + // TODO think about this one carefully. We don't do unions or intersections of type packs + bool unify(TypePackId subTp, TypePackId superTp); + + std::optional generalize(TypeId ty); + +private: + /** + * @returns simplify(left | right) + */ + TypeId mkUnion(TypeId left, TypeId right); + + /** + * @returns simplify(left & right) + */ + TypeId mkIntersection(TypeId left, TypeId right); + + // Returns true if needle occurs within haystack already. ie if we bound + // needle to haystack, would a cyclic type result? + OccursCheckResult occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); + + // Returns true if needle occurs within haystack already. ie if we bound + // needle to haystack, would a cyclic TypePack result? + OccursCheckResult occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index ada56ec56..de69c17cf 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -3,8 +3,7 @@ #include "Luau/DenseHash.h" #include "Luau/Error.h" -#include "Luau/Type.h" -#include "Luau/TypePack.h" +#include "Luau/TypeFwd.h" #include diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index ff4dfc3c3..7202c1005 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -7,8 +7,11 @@ #include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) +LUAU_FASTFLAG(LuauBoundLazyTypes2) +LUAU_FASTFLAG(LuauSolverV2) namespace Luau { @@ -62,6 +65,9 @@ inline void unsee(DenseHashSet& seen, const void* tv) } // namespace visit_detail +// recursion counter is equivalent here, but we'd like a better name to express the intent. +using TypeFunctionDepthCounter = RecursionCounter; + template struct GenericTypeVisitor { @@ -70,6 +76,7 @@ struct GenericTypeVisitor Set seen; bool skipBoundTypes = false; int recursionCounter = 0; + int typeFunctionDepth = 0; GenericTypeVisitor() = default; @@ -126,6 +133,10 @@ struct GenericTypeVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const NoRefineType& nrt) + { + return visit(ty); + } virtual bool visit(TypeId ty, const UnknownType& utv) { return visit(ty); @@ -158,6 +169,10 @@ struct GenericTypeVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) + { + return visit(ty); + } virtual bool visit(TypePackId tp) { @@ -191,6 +206,10 @@ struct GenericTypeVisitor { return visit(tp); } + virtual bool visit(TypePackId tp, const TypeFunctionInstanceTypePack& tfitp) + { + return visit(tp); + } void traverse(TypeId ty) { @@ -210,7 +229,26 @@ struct GenericTypeVisitor traverse(btv->boundTo); } else if (auto ftv = get(ty)) - visit(ty, *ftv); + { + if (FFlag::LuauSolverV2) + { + if (visit(ty, *ftv)) + { + // TODO: Replace these if statements with assert()s when we + // delete FFlag::LuauSolverV2. + // + // When the old solver is used, these pointers are always + // unused. When the new solver is used, they are never null. + if (ftv->lowerBound) + traverse(ftv->lowerBound); + + if (ftv->upperBound) + traverse(ftv->upperBound); + } + } + else + visit(ty, *ftv); + } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) @@ -241,7 +279,22 @@ struct GenericTypeVisitor else { for (auto& [_name, prop] : ttv->props) - traverse(prop.type); + { + if (FFlag::LuauSolverV2) + { + if (auto ty = prop.readTy) + traverse(*ty); + + // In the case that the readType and the writeType + // are the same pointer, just traverse once. + // Traversing each property twice has pretty + // significant performance consequences. + if (auto ty = prop.writeTy; ty && !prop.isShared()) + traverse(*ty); + } + else + traverse(prop.type()); + } if (ttv->indexer) { @@ -264,36 +317,84 @@ struct GenericTypeVisitor if (visit(ty, *ctv)) { for (const auto& [name, prop] : ctv->props) - traverse(prop.type); + { + if (FFlag::LuauSolverV2) + { + if (auto ty = prop.readTy) + traverse(*ty); + + // In the case that the readType and the writeType are + // the same pointer, just traverse once. Traversing each + // property twice would have pretty significant + // performance consequences. + if (auto ty = prop.writeTy; ty && !prop.isShared()) + traverse(*ty); + } + else + traverse(prop.type()); + } if (ctv->parent) traverse(*ctv->parent); if (ctv->metatable) traverse(*ctv->metatable); + + if (ctv->indexer) + { + traverse(ctv->indexer->indexType); + traverse(ctv->indexer->indexResultType); + } } } else if (auto atv = get(ty)) visit(ty, *atv); + else if (auto nrt = get(ty)) + visit(ty, *nrt); else if (auto utv = get(ty)) { if (visit(ty, *utv)) { + bool unionChanged = false; for (TypeId optTy : utv->options) + { traverse(optTy); + if (!get(follow(ty))) + { + unionChanged = true; + break; + } + } + + if (unionChanged) + traverse(ty); } } else if (auto itv = get(ty)) { if (visit(ty, *itv)) { + bool intersectionChanged = false; for (TypeId partTy : itv->parts) + { traverse(partTy); + if (!get(follow(ty))) + { + intersectionChanged = true; + break; + } + } + + if (intersectionChanged) + traverse(ty); } } - else if (get(ty)) + else if (auto ltv = get(ty)) { - // Visiting into LazyType may necessarily cause infinite expansion, so we don't do that on purpose. + if (TypeId unwrapped = ltv->unwrapped) + traverse(unwrapped); + + // Visiting into LazyType that hasn't been unwrapped may necessarily cause infinite expansion, so we don't do that on purpose. // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassType // that doesn't need to be expanded. } @@ -321,6 +422,19 @@ struct GenericTypeVisitor if (visit(ty, *ntv)) traverse(ntv->ty); } + else if (auto tfit = get(ty)) + { + TypeFunctionDepthCounter tfdc{&typeFunctionDepth}; + + if (visit(ty, *tfit)) + { + for (TypeId p : tfit->typeArguments) + traverse(p); + + for (TypePackId p : tfit->packArguments) + traverse(p); + } + } else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypeId) is not exhaustive!"); @@ -341,10 +455,10 @@ struct GenericTypeVisitor traverse(btv->boundTo); } - else if (auto ftv = get(tp)) + else if (auto ftv = get(tp)) visit(tp, *ftv); - else if (auto gtv = get(tp)) + else if (auto gtv = get(tp)) visit(tp, *gtv); else if (auto etv = get(tp)) @@ -370,6 +484,19 @@ struct GenericTypeVisitor } else if (auto btp = get(tp)) visit(tp, *btp); + else if (auto tfitp = get(tp)) + { + TypeFunctionDepthCounter tfdc{&typeFunctionDepth}; + + if (visit(tp, *tfitp)) + { + for (TypeId t : tfitp->typeArguments) + traverse(t); + + for (TypePackId t : tfitp->packArguments) + traverse(t); + } + } else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypePackId) is not exhaustive!"); diff --git a/Analysis/src/AnyTypeSummary.cpp b/Analysis/src/AnyTypeSummary.cpp new file mode 100644 index 000000000..85f567aff --- /dev/null +++ b/Analysis/src/AnyTypeSummary.cpp @@ -0,0 +1,903 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AnyTypeSummary.h" + +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/Config.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/DcrLogger.h" +#include "Luau/Module.h" +#include "Luau/Parser.h" +#include "Luau/Scope.h" +#include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" +#include "Luau/ToString.h" +#include "Luau/Transpiler.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeChecker2.h" +#include "Luau/NonStrictTypeChecker.h" +#include "Luau/TypeInfer.h" +#include "Luau/Variant.h" +#include "Luau/VisitType.h" +#include "Luau/TypePack.h" +#include "Luau/TypeOrPack.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + +LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2, false); +LUAU_FASTINTVARIABLE(LuauAnySummaryRecursionLimit, 300); + +LUAU_FASTFLAG(DebugLuauMagicTypes); + +namespace Luau +{ + +void AnyTypeSummary::traverse(const Module* module, AstStat* src, NotNull builtinTypes) +{ + visit(findInnerMostScope(src->location, module), src, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStat* stat, const Module* module, NotNull builtinTypes) +{ + RecursionLimiter limiter{&recursionCount, FInt::LuauAnySummaryRecursionLimit}; + + if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto i = stat->as()) + return visit(scope, i, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto r = stat->as()) + return visit(scope, r, module, builtinTypes); + else if (auto e = stat->as()) + return visit(scope, e, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto a = stat->as()) + return visit(scope, a, module, builtinTypes); + else if (auto a = stat->as()) + return visit(scope, a, module, builtinTypes); + else if (auto f = stat->as()) + return visit(scope, f, module, builtinTypes); + else if (auto f = stat->as()) + return visit(scope, f, module, builtinTypes); + else if (auto a = stat->as()) + return visit(scope, a, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull builtinTypes) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauAnySummaryRecursionLimit) + return; // don't report + + for (AstStat* stat : block->body) + visit(scope, stat, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull builtinTypes) +{ + if (ifStatement->thenbody) + { + const Scope* thenScope = findInnerMostScope(ifStatement->thenbody->location, module); + visit(thenScope, ifStatement->thenbody, module, builtinTypes); + } + + if (ifStatement->elsebody) + { + const Scope* elseScope = findInnerMostScope(ifStatement->elsebody->location, module); + visit(elseScope, ifStatement->elsebody, module, builtinTypes); + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull builtinTypes) +{ + const Scope* whileScope = findInnerMostScope(while_->location, module); + visit(whileScope, while_->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull builtinTypes) +{ + const Scope* repeatScope = findInnerMostScope(repeat->location, module); + visit(repeatScope, repeat->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull builtinTypes) +{ + const Scope* retScope = findInnerMostScope(ret->location, module); + + auto ctxNode = getNode(rootSrc, ret); + bool seenTP = false; + + for (auto val : ret->list) + { + if (isAnyCall(retScope, val, module, builtinTypes)) + { + TelemetryTypePair types; + types.inferredType = toString(lookupType(val, module, builtinTypes)); + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + + if (isAnyCast(retScope, val, module, builtinTypes)) + { + if (auto cast = val->as()) + { + TelemetryTypePair types; + + types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); + types.inferredType = toString(lookupType(cast->expr, module, builtinTypes)); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + + if (ret->list.size > 1 && !seenTP) + { + if (containsAny(retScope->returnType)) + { + seenTP = true; + + TelemetryTypePair types; + + types.inferredType = toString(retScope->returnType); + + TypeInfo ti{Pattern::TypePk, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, local); + + TypePackId values = reconstructTypePack(local->values, module, builtinTypes); + auto [head, tail] = flatten(values); + + size_t posn = 0; + for (AstLocal* loc : local->vars) + { + if (local->vars.data[0] == loc && posn < local->values.size) + { + if (loc->annotation) + { + auto annot = lookupAnnotation(loc->annotation, module, builtinTypes); + if (containsAny(annot)) + { + TelemetryTypePair types; + + types.annotatedType = toString(annot); + types.inferredType = toString(lookupType(local->values.data[posn], module, builtinTypes)); + + TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + + const AstExprTypeAssertion* maybeRequire = local->values.data[posn]->as(); + if (!maybeRequire) + continue; + + if (std::min(local->values.size - 1, posn) < head.size()) + { + if (isAnyCast(scope, local->values.data[posn], module, builtinTypes)) + { + TelemetryTypePair types; + + types.inferredType = toString(head[std::min(local->values.size - 1, posn)]); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + else + { + + if (std::min(local->values.size - 1, posn) < head.size()) + { + if (loc->annotation) + { + auto annot = lookupAnnotation(loc->annotation, module, builtinTypes); + if (containsAny(annot)) + { + TelemetryTypePair types; + + types.annotatedType = toString(annot); + types.inferredType = toString(head[std::min(local->values.size - 1, posn)]); + + TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + else + { + if (tail) + { + if (containsAny(*tail)) + { + TelemetryTypePair types; + + types.inferredType = toString(*tail); + + TypeInfo ti{Pattern::VarAny, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + } + + ++posn; + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull builtinTypes) +{ + const Scope* forScope = findInnerMostScope(for_->location, module); + visit(forScope, for_->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull builtinTypes) +{ + const Scope* loopScope = findInnerMostScope(forIn->location, module); + visit(loopScope, forIn->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, assign); + + TypePackId values = reconstructTypePack(assign->values, module, builtinTypes); + auto [head, tail] = flatten(values); + + size_t posn = 0; + for (AstExpr* var : assign->vars) + { + TypeId tp = lookupType(var, module, builtinTypes); + if (containsAny(tp)) + { + TelemetryTypePair types; + + types.annotatedType = toString(tp); + + auto loc = std::min(assign->vars.size - 1, posn); + if (head.size() >= assign->vars.size && posn < head.size()) + { + types.inferredType = toString(head[posn]); + } + else if (loc < head.size()) + types.inferredType = toString(head[loc]); + else + types.inferredType = toString(builtinTypes->nilType); + + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + ++posn; + } + + for (AstExpr* val : assign->values) + { + if (isAnyCall(scope, val, module, builtinTypes)) + { + TelemetryTypePair types; + + types.inferredType = toString(lookupType(val, module, builtinTypes)); + + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + + if (isAnyCast(scope, val, module, builtinTypes)) + { + if (auto cast = val->as()) + { + TelemetryTypePair types; + + types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); + types.inferredType = toString(lookupType(val, module, builtinTypes)); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + + if (tail) + { + if (containsAny(*tail)) + { + TelemetryTypePair types; + + types.inferredType = toString(*tail); + + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, assign); + + TelemetryTypePair types; + + types.inferredType = toString(lookupType(assign->value, module, builtinTypes)); + types.annotatedType = toString(lookupType(assign->var, module, builtinTypes)); + + if (module->astTypes.contains(assign->var)) + { + if (containsAny(*module->astTypes.find(assign->var))) + { + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + else if (module->astTypePacks.contains(assign->var)) + { + if (containsAny(*module->astTypePacks.find(assign->var))) + { + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + + if (isAnyCall(scope, assign->value, module, builtinTypes)) + { + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + + if (isAnyCast(scope, assign->value, module, builtinTypes)) + { + if (auto cast = assign->value->as()) + { + types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); + types.inferredType = toString(lookupType(cast->expr, module, builtinTypes)); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull builtinTypes) +{ + TelemetryTypePair types; + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + + if (hasVariadicAnys(scope, function->func, module, builtinTypes)) + { + TypeInfo ti{Pattern::VarAny, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasArgAnys(scope, function->func, module, builtinTypes)) + { + TypeInfo ti{Pattern::FuncArg, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasAnyReturns(scope, function->func, module, builtinTypes)) + { + TypeInfo ti{Pattern::FuncRet, toString(function), types}; + typeInfo.push_back(ti); + } + + if (function->func->body->body.size > 0) + visit(scope, function->func->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull builtinTypes) +{ + TelemetryTypePair types; + + if (hasVariadicAnys(scope, function->func, module, builtinTypes)) + { + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + TypeInfo ti{Pattern::VarAny, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasArgAnys(scope, function->func, module, builtinTypes)) + { + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + TypeInfo ti{Pattern::FuncArg, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasAnyReturns(scope, function->func, module, builtinTypes)) + { + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + TypeInfo ti{Pattern::FuncRet, toString(function), types}; + typeInfo.push_back(ti); + } + + if (function->func->body->body.size > 0) + visit(scope, function->func->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, alias); + + auto annot = lookupAnnotation(alias->type, module, builtinTypes); + if (containsAny(annot)) + { + // no expr => no inference for aliases + TelemetryTypePair types; + + types.annotatedType = toString(annot); + TypeInfo ti{Pattern::Alias, toString(ctxNode), types}; + typeInfo.push_back(ti); + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, expr); + + if (isAnyCall(scope, expr->expr, module, builtinTypes)) + { + TelemetryTypePair types; + types.inferredType = toString(lookupType(expr->expr, module, builtinTypes)); + + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull builtinTypes) {} + +void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull builtinTypes) {} + +void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull builtinTypes) {} + +void AnyTypeSummary::visit(const Scope* scope, AstStatError* error, const Module* module, NotNull builtinTypes) {} + +TypeId AnyTypeSummary::checkForFamilyInhabitance(const TypeId instance, const Location location) +{ + if (seenTypeFamilyInstances.find(instance)) + return instance; + + seenTypeFamilyInstances.insert(instance); + return instance; +} + +TypeId AnyTypeSummary::lookupType(const AstExpr* expr, const Module* module, NotNull builtinTypes) +{ + const TypeId* ty = module->astTypes.find(expr); + if (ty) + return checkForFamilyInhabitance(follow(*ty), expr->location); + + const TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + { + if (auto fst = first(*tp, /*ignoreHiddenVariadics*/ false)) + return checkForFamilyInhabitance(*fst, expr->location); + else if (finite(*tp) && size(*tp) == 0) + return checkForFamilyInhabitance(builtinTypes->nilType, expr->location); + } + + return builtinTypes->errorRecoveryType(); +} + +TypePackId AnyTypeSummary::reconstructTypePack(AstArray exprs, const Module* module, NotNull builtinTypes) +{ + if (exprs.size == 0) + return arena.addTypePack(TypePack{{}, std::nullopt}); + + std::vector head; + + for (size_t i = 0; i < exprs.size - 1; ++i) + { + head.push_back(lookupType(exprs.data[i], module, builtinTypes)); + } + + const TypePackId* tail = module->astTypePacks.find(exprs.data[exprs.size - 1]); + if (tail) + return arena.addTypePack(TypePack{std::move(head), follow(*tail)}); + else + return arena.addTypePack(TypePack{std::move(head), builtinTypes->errorRecoveryTypePack()}); +} + +bool AnyTypeSummary::isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes) +{ + if (auto call = expr->as()) + { + TypePackId args = reconstructTypePack(call->args, module, builtinTypes); + if (containsAny(args)) + return true; + + TypeId func = lookupType(call->func, module, builtinTypes); + if (containsAny(func)) + return true; + } + return false; +} + +bool AnyTypeSummary::hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes) +{ + if (expr->vararg && expr->varargAnnotation) + { + auto annot = lookupPackAnnotation(expr->varargAnnotation, module); + if (annot && containsAny(*annot)) + { + return true; + } + } + return false; +} + +bool AnyTypeSummary::hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes) +{ + if (expr->args.size > 0) + { + for (const AstLocal* arg : expr->args) + { + if (arg->annotation) + { + auto annot = lookupAnnotation(arg->annotation, module, builtinTypes); + if (containsAny(annot)) + { + return true; + } + } + } + } + return false; +} + +bool AnyTypeSummary::hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes) +{ + if (!expr->returnAnnotation) + { + return false; + } + + for (AstType* ret : expr->returnAnnotation->types) + { + if (containsAny(lookupAnnotation(ret, module, builtinTypes))) + { + return true; + } + } + + if (expr->returnAnnotation->tailType) + { + auto annot = lookupPackAnnotation(expr->returnAnnotation->tailType, module); + if (annot && containsAny(*annot)) + { + return true; + } + } + + return false; +} + +bool AnyTypeSummary::isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes) +{ + if (auto cast = expr->as()) + { + auto annot = lookupAnnotation(cast->annotation, module, builtinTypes); + if (containsAny(annot)) + { + return true; + } + } + return false; +} + +TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, const Module* module, NotNull builtintypes) +{ + if (FFlag::DebugLuauMagicTypes) + { + if (auto ref = annotation->as(); ref && ref->parameters.size > 0) + { + if (auto ann = ref->parameters.data[0].type) + { + TypeId argTy = lookupAnnotation(ref->parameters.data[0].type, module, builtintypes); + return follow(argTy); + } + } + } + + const TypeId* ty = module->astResolvedTypes.find(annotation); + if (ty) + return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); + else + return checkForTypeFunctionInhabitance(builtintypes->errorRecoveryType(), annotation->location); +} + +TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(const TypeId instance, const Location location) +{ + if (seenTypeFunctionInstances.find(instance)) + return instance; + seenTypeFunctionInstances.insert(instance); + + return instance; +} + +std::optional AnyTypeSummary::lookupPackAnnotation(AstTypePack* annotation, const Module* module) +{ + const TypePackId* tp = module->astResolvedTypePacks.find(annotation); + if (tp != nullptr) + return {follow(*tp)}; + return {}; +} + +bool AnyTypeSummary::containsAny(TypeId typ) +{ + typ = follow(typ); + + if (auto t = seen.find(typ); t && !*t) + { + return false; + } + + seen[typ] = false; + + RecursionCounter counter{&recursionCount}; + if (recursionCount >= FInt::LuauAnySummaryRecursionLimit) + { + return false; + } + + bool found = false; + + if (auto ty = get(typ)) + { + found = true; + } + else if (auto ty = get(typ)) + { + found = true; + } + else if (auto ty = get(typ)) + { + for (auto& [_name, prop] : ty->props) + { + if (FFlag::LuauSolverV2) + { + if (auto newT = follow(prop.readTy)) + { + if (containsAny(*newT)) + found = true; + } + else if (auto newT = follow(prop.writeTy)) + { + if (containsAny(*newT)) + found = true; + } + } + else + { + if (containsAny(prop.type())) + found = true; + } + } + } + else if (auto ty = get(typ)) + { + for (auto part : ty->parts) + { + if (containsAny(part)) + { + found = true; + } + } + } + else if (auto ty = get(typ)) + { + for (auto option : ty->options) + { + if (containsAny(option)) + { + found = true; + } + } + } + else if (auto ty = get(typ)) + { + if (containsAny(ty->argTypes)) + found = true; + else if (containsAny(ty->retTypes)) + found = true; + } + + seen[typ] = found; + + return found; +} + +bool AnyTypeSummary::containsAny(TypePackId typ) +{ + typ = follow(typ); + + if (auto t = seen.find(typ); t && !*t) + { + return false; + } + + seen[typ] = false; + + auto [head, tail] = flatten(typ); + bool found = false; + + for (auto tp : head) + { + if (containsAny(tp)) + found = true; + } + + if (tail) + { + if (auto vtp = get(tail)) + { + if (auto ty = get(follow(vtp->ty))) + { + found = true; + } + } + else if (auto tftp = get(tail)) + { + + for (TypePackId tp : tftp->packArguments) + { + if (containsAny(tp)) + { + found = true; + } + } + + for (TypeId t : tftp->typeArguments) + { + if (containsAny(t)) + { + found = true; + } + } + } + } + + seen[typ] = found; + + return found; +} + +const Scope* AnyTypeSummary::findInnerMostScope(const Location location, const Module* module) +{ + const Scope* bestScope = module->getModuleScope().get(); + + bool didNarrow = false; + do + { + didNarrow = false; + for (auto scope : bestScope->children) + { + if (scope->location.encloses(location)) + { + bestScope = scope.get(); + didNarrow = true; + break; + } + } + } while (didNarrow && bestScope->children.size() > 0); + + return bestScope; +} + +std::optional AnyTypeSummary::matchRequire(const AstExprCall& call) +{ + const char* require = "require"; + + if (call.args.size != 1) + return std::nullopt; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != require) + return std::nullopt; + + if (call.args.size != 1) + return std::nullopt; + + return call.args.data[0]; +} + +AstNode* AnyTypeSummary::getNode(AstStatBlock* root, AstNode* node) +{ + FindReturnAncestry finder(node, root->location.end); + root->visit(&finder); + + if (!finder.currNode) + finder.currNode = node; + + LUAU_ASSERT(finder.found && finder.currNode); + return finder.currNode; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstStatLocalFunction* node) +{ + currNode = node; + return !found; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstStatFunction* node) +{ + currNode = node; + return !found; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstType* node) +{ + return !found; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstNode* node) +{ + if (node == stat) + { + found = true; + } + + if (node->location.end == rootEnd && stat->location.end >= rootEnd) + { + currNode = node; + found = true; + } + + return !found; +} + + +AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryTypePair type) + : code(code) + , node(node) + , type(type) +{ +} + +AnyTypeSummary::FindReturnAncestry::FindReturnAncestry(AstNode* stat, Position rootEnd) + : stat(stat) + , rootEnd(rootEnd) +{ +} + +AnyTypeSummary::AnyTypeSummary() {} + +} // namespace Luau \ No newline at end of file diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index 15dd25cc5..4bacec039 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -6,13 +6,17 @@ #include "Luau/Normalize.h" #include "Luau/TxnLog.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { -Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, - TypeId anyType, TypePackId anyTypePack) +Anyification::Anyification( + TypeArena* arena, + NotNull scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack +) : Substitution(TxnLog::empty(), arena) , scope(scope) , builtinTypes(builtinTypes) @@ -22,8 +26,14 @@ Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, - TypeId anyType, TypePackId anyTypePack) +Anyification::Anyification( + TypeArena* arena, + const ScopePtr& scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack +) : Anyification(arena, NotNull{scope.get()}, builtinTypes, iceHandler, anyType, anyTypePack) { } @@ -78,7 +88,7 @@ TypePackId Anyification::clean(TypePackId tp) bool Anyification::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return ty->persistent; diff --git a/Analysis/src/ApplyTypeFunction.cpp b/Analysis/src/ApplyTypeFunction.cpp index fe8cc8ac3..025e8f6db 100644 --- a/Analysis/src/ApplyTypeFunction.cpp +++ b/Analysis/src/ApplyTypeFunction.cpp @@ -2,8 +2,6 @@ #include "Luau/ApplyTypeFunction.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { @@ -33,7 +31,7 @@ bool ApplyTypeFunction::ignoreChildren(TypeId ty) { if (get(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else return false; diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index 62a9527bd..1a8edf170 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -198,6 +198,23 @@ struct AstJsonEncoder : public AstVisitor { writeString(name.value ? name.value : ""); } + void write(std::optional name) + { + if (name) + write(*name); + else + writeRaw("null"); + } + void write(AstArgumentName name) + { + writeRaw("{"); + bool c = pushComma(); + writeType("AstArgumentName"); + write("name", name.first); + write("location", name.second); + popComma(c); + writeRaw("}"); + } void write(const Position& position) { @@ -254,9 +271,14 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprGroup* node) { - writeNode(node, "AstExprGroup", [&]() { - write("expr", node->expr); - }); + writeNode( + node, + "AstExprGroup", + [&]() + { + write("expr", node->expr); + } + ); } void write(class AstExprConstantNil* node) @@ -266,37 +288,62 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprConstantBool* node) { - writeNode(node, "AstExprConstantBool", [&]() { - write("value", node->value); - }); + writeNode( + node, + "AstExprConstantBool", + [&]() + { + write("value", node->value); + } + ); } void write(class AstExprConstantNumber* node) { - writeNode(node, "AstExprConstantNumber", [&]() { - write("value", node->value); - }); + writeNode( + node, + "AstExprConstantNumber", + [&]() + { + write("value", node->value); + } + ); } void write(class AstExprConstantString* node) { - writeNode(node, "AstExprConstantString", [&]() { - write("value", node->value); - }); + writeNode( + node, + "AstExprConstantString", + [&]() + { + write("value", node->value); + } + ); } void write(class AstExprLocal* node) { - writeNode(node, "AstExprLocal", [&]() { - write("local", node->local); - }); + writeNode( + node, + "AstExprLocal", + [&]() + { + write("local", node->local); + } + ); } void write(class AstExprGlobal* node) { - writeNode(node, "AstExprGlobal", [&]() { - write("global", node->name); - }); + writeNode( + node, + "AstExprGlobal", + [&]() + { + write("global", node->name); + } + ); } void write(class AstExprVarargs* node) @@ -330,52 +377,72 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprCall* node) { - writeNode(node, "AstExprCall", [&]() { - PROP(func); - PROP(args); - PROP(self); - PROP(argLocation); - }); + writeNode( + node, + "AstExprCall", + [&]() + { + PROP(func); + PROP(args); + PROP(self); + PROP(argLocation); + } + ); } void write(class AstExprIndexName* node) { - writeNode(node, "AstExprIndexName", [&]() { - PROP(expr); - PROP(index); - PROP(indexLocation); - PROP(op); - }); + writeNode( + node, + "AstExprIndexName", + [&]() + { + PROP(expr); + PROP(index); + PROP(indexLocation); + PROP(op); + } + ); } void write(class AstExprIndexExpr* node) { - writeNode(node, "AstExprIndexExpr", [&]() { - PROP(expr); - PROP(index); - }); + writeNode( + node, + "AstExprIndexExpr", + [&]() + { + PROP(expr); + PROP(index); + } + ); } void write(class AstExprFunction* node) { - writeNode(node, "AstExprFunction", [&]() { - PROP(generics); - PROP(genericPacks); - if (node->self) - PROP(self); - PROP(args); - if (node->returnAnnotation) - PROP(returnAnnotation); - PROP(vararg); - PROP(varargLocation); - if (node->varargAnnotation) - PROP(varargAnnotation); - - PROP(body); - PROP(functionDepth); - PROP(debugname); - PROP(hasEnd); - }); + writeNode( + node, + "AstExprFunction", + [&]() + { + PROP(attributes); + PROP(generics); + PROP(genericPacks); + if (node->self) + PROP(self); + PROP(args); + if (node->returnAnnotation) + PROP(returnAnnotation); + PROP(vararg); + PROP(varargLocation); + if (node->varargAnnotation) + PROP(varargAnnotation); + + PROP(body); + PROP(functionDepth); + PROP(debugname); + } + ); } void write(const std::optional& typeList) @@ -457,28 +524,43 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprIfElse* node) { - writeNode(node, "AstExprIfElse", [&]() { - PROP(condition); - PROP(hasThen); - PROP(trueExpr); - PROP(hasElse); - PROP(falseExpr); - }); + writeNode( + node, + "AstExprIfElse", + [&]() + { + PROP(condition); + PROP(hasThen); + PROP(trueExpr); + PROP(hasElse); + PROP(falseExpr); + } + ); } void write(class AstExprInterpString* node) { - writeNode(node, "AstExprInterpString", [&]() { - PROP(strings); - PROP(expressions); - }); + writeNode( + node, + "AstExprInterpString", + [&]() + { + PROP(strings); + PROP(expressions); + } + ); } void write(class AstExprTable* node) { - writeNode(node, "AstExprTable", [&]() { - PROP(items); - }); + writeNode( + node, + "AstExprTable", + [&]() + { + PROP(items); + } + ); } void write(AstExprUnary::Op op) @@ -502,10 +584,15 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprUnary* node) { - writeNode(node, "AstExprUnary", [&]() { - PROP(op); - PROP(expr); - }); + writeNode( + node, + "AstExprUnary", + [&]() + { + PROP(op); + PROP(expr); + } + ); } void write(AstExprBinary::Op op) @@ -520,6 +607,8 @@ struct AstJsonEncoder : public AstVisitor return writeString("Mul"); case AstExprBinary::Div: return writeString("Div"); + case AstExprBinary::FloorDiv: + return writeString("FloorDiv"); case AstExprBinary::Mod: return writeString("Mod"); case AstExprBinary::Pow: @@ -542,8 +631,6 @@ struct AstJsonEncoder : public AstVisitor return writeString("And"); case AstExprBinary::Or: return writeString("Or"); - case AstExprBinary::DivInt: - return writeString("IDiv"); case AstExprBinary::MaxOf: return writeString("Max"); case AstExprBinary::MinOf: @@ -558,81 +645,117 @@ struct AstJsonEncoder : public AstVisitor return writeString("Shl"); case AstExprBinary::BinShiftR: return writeString("Shr"); + default: + LUAU_ASSERT(!"Unknown Op"); } } void write(class AstExprBinary* node) { - writeNode(node, "AstExprBinary", [&]() { - PROP(op); - PROP(left); - PROP(right); - }); + writeNode( + node, + "AstExprBinary", + [&]() + { + PROP(op); + PROP(left); + PROP(right); + } + ); } void write(class AstExprTypeAssertion* node) { - writeNode(node, "AstExprTypeAssertion", [&]() { - PROP(expr); - PROP(annotation); - }); + writeNode( + node, + "AstExprTypeAssertion", + [&]() + { + PROP(expr); + PROP(annotation); + } + ); } void write(class AstExprError* node) { - writeNode(node, "AstExprError", [&]() { - PROP(expressions); - PROP(messageIndex); - }); + writeNode( + node, + "AstExprError", + [&]() + { + PROP(expressions); + PROP(messageIndex); + } + ); } void write(class AstStatBlock* node) { - writeNode(node, "AstStatBlock", [&]() { - writeRaw(",\"body\":["); - bool comma = false; - for (AstStat* stat : node->body) + writeNode( + node, + "AstStatBlock", + [&]() { - if (comma) - writeRaw(","); - else - comma = true; - - write(stat); + writeRaw(",\"hasEnd\":"); + write(node->hasEnd); + writeRaw(",\"body\":["); + bool comma = false; + for (AstStat* stat : node->body) + { + if (comma) + writeRaw(","); + else + comma = true; + + write(stat); + } + writeRaw("]"); } - writeRaw("]"); - }); + ); } void write(class AstStatIf* node) { - writeNode(node, "AstStatIf", [&]() { - PROP(condition); - PROP(thenbody); - if (node->elsebody) - PROP(elsebody); - write("hasThen", node->thenLocation.has_value()); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatIf", + [&]() + { + PROP(condition); + PROP(thenbody); + if (node->elsebody) + PROP(elsebody); + write("hasThen", node->thenLocation.has_value()); + } + ); } void write(class AstStatWhile* node) { - writeNode(node, "AstStatWhile", [&]() { - PROP(condition); - PROP(body); - PROP(hasDo); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatWhile", + [&]() + { + PROP(condition); + PROP(body); + PROP(hasDo); + } + ); } void write(class AstStatRepeat* node) { - writeNode(node, "AstStatRepeat", [&]() { - PROP(condition); - PROP(body); - PROP(hasUntil); - }); + writeNode( + node, + "AstStatRepeat", + [&]() + { + PROP(condition); + PROP(body); + } + ); } void write(class AstStatBreak* node) @@ -647,113 +770,177 @@ struct AstJsonEncoder : public AstVisitor void write(class AstStatReturn* node) { - writeNode(node, "AstStatReturn", [&]() { - PROP(list); - }); + writeNode( + node, + "AstStatReturn", + [&]() + { + PROP(list); + } + ); } void write(class AstStatExpr* node) { - writeNode(node, "AstStatExpr", [&]() { - PROP(expr); - }); + writeNode( + node, + "AstStatExpr", + [&]() + { + PROP(expr); + } + ); } void write(class AstStatLocal* node) { - writeNode(node, "AstStatLocal", [&]() { - PROP(vars); - PROP(values); - }); + writeNode( + node, + "AstStatLocal", + [&]() + { + PROP(vars); + PROP(values); + } + ); } void write(class AstStatFor* node) { - writeNode(node, "AstStatFor", [&]() { - PROP(var); - PROP(from); - PROP(to); - if (node->step) - PROP(step); - PROP(body); - PROP(hasDo); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatFor", + [&]() + { + PROP(var); + PROP(from); + PROP(to); + if (node->step) + PROP(step); + PROP(body); + PROP(hasDo); + } + ); } void write(class AstStatForIn* node) { - writeNode(node, "AstStatForIn", [&]() { - PROP(vars); - PROP(values); - PROP(body); - PROP(hasIn); - PROP(hasDo); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatForIn", + [&]() + { + PROP(vars); + PROP(values); + PROP(body); + PROP(hasIn); + PROP(hasDo); + } + ); } void write(class AstStatAssign* node) { - writeNode(node, "AstStatAssign", [&]() { - PROP(vars); - PROP(values); - }); + writeNode( + node, + "AstStatAssign", + [&]() + { + PROP(vars); + PROP(values); + } + ); } void write(class AstStatCompoundAssign* node) { - writeNode(node, "AstStatCompoundAssign", [&]() { - PROP(op); - PROP(var); - PROP(value); - }); + writeNode( + node, + "AstStatCompoundAssign", + [&]() + { + PROP(op); + PROP(var); + PROP(value); + } + ); } void write(class AstStatFunction* node) { - writeNode(node, "AstStatFunction", [&]() { - PROP(name); - PROP(func); - }); + writeNode( + node, + "AstStatFunction", + [&]() + { + PROP(name); + PROP(func); + } + ); } void write(class AstStatLocalFunction* node) { - writeNode(node, "AstStatLocalFunction", [&]() { - PROP(name); - PROP(func); - }); + writeNode( + node, + "AstStatLocalFunction", + [&]() + { + PROP(name); + PROP(func); + } + ); } void write(class AstStatTypeAlias* node) { - writeNode(node, "AstStatTypeAlias", [&]() { - PROP(name); - PROP(generics); - PROP(genericPacks); - PROP(type); - PROP(exported); - }); + writeNode( + node, + "AstStatTypeAlias", + [&]() + { + PROP(name); + PROP(generics); + PROP(genericPacks); + write("value", node->type); + PROP(exported); + } + ); } void write(class AstStatDeclareFunction* node) { - writeNode(node, "AstStatDeclareFunction", [&]() { - PROP(name); - PROP(params); - PROP(retTypes); - PROP(generics); - PROP(genericPacks); - }); + writeNode( + node, + "AstStatDeclareFunction", + [&]() + { + PROP(attributes); + PROP(name); + PROP(nameLocation); + PROP(params); + PROP(paramNames); + PROP(vararg); + PROP(varargLocation); + PROP(retTypes); + PROP(generics); + PROP(genericPacks); + } + ); } void write(class AstStatDeclareGlobal* node) { - writeNode(node, "AstStatDeclareGlobal", [&]() { - PROP(name); - PROP(type); - }); + writeNode( + node, + "AstStatDeclareGlobal", + [&]() + { + PROP(name); + PROP(nameLocation); + PROP(type); + } + ); } void write(const AstDeclaredClassProp& prop) @@ -761,28 +948,41 @@ struct AstJsonEncoder : public AstVisitor writeRaw("{"); bool c = pushComma(); write("name", prop.name); + write("nameLocation", prop.nameLocation); writeType("AstDeclaredClassProp"); write("luauType", prop.ty); + write("location", prop.location); popComma(c); writeRaw("}"); } void write(class AstStatDeclareClass* node) { - writeNode(node, "AstStatDeclareClass", [&]() { - PROP(name); - if (node->superName) - write("superName", *node->superName); - PROP(props); - }); + writeNode( + node, + "AstStatDeclareClass", + [&]() + { + PROP(name); + if (node->superName) + write("superName", *node->superName); + PROP(props); + PROP(indexer); + } + ); } void write(class AstStatError* node) { - writeNode(node, "AstStatError", [&]() { - PROP(expressions); - PROP(statements); - }); + writeNode( + node, + "AstStatError", + [&]() + { + PROP(expressions); + PROP(statements); + } + ); } void write(struct AstTypeOrPack node) @@ -795,12 +995,20 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeReference* node) { - writeNode(node, "AstTypeReference", [&]() { - if (node->prefix) - PROP(prefix); - PROP(name); - PROP(parameters); - }); + writeNode( + node, + "AstTypeReference", + [&]() + { + if (node->prefix) + PROP(prefix); + if (node->prefixLocation) + write("prefixLocation", *node->prefixLocation); + PROP(name); + PROP(nameLocation); + PROP(parameters); + } + ); } void write(const AstTableProp& prop) @@ -819,10 +1027,15 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeTable* node) { - writeNode(node, "AstTypeTable", [&]() { - PROP(props); - PROP(indexer); - }); + writeNode( + node, + "AstTypeTable", + [&]() + { + PROP(props); + PROP(indexer); + } + ); } void write(struct AstTableIndexer* indexer) @@ -845,62 +1058,153 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeFunction* node) { - writeNode(node, "AstTypeFunction", [&]() { - PROP(generics); - PROP(genericPacks); - PROP(argTypes); - PROP(returnTypes); - }); + writeNode( + node, + "AstTypeFunction", + [&]() + { + PROP(attributes); + PROP(generics); + PROP(genericPacks); + PROP(argTypes); + PROP(argNames); + PROP(returnTypes); + } + ); } void write(class AstTypeTypeof* node) { - writeNode(node, "AstTypeTypeof", [&]() { - PROP(expr); - }); + writeNode( + node, + "AstTypeTypeof", + [&]() + { + PROP(expr); + } + ); } void write(class AstTypeUnion* node) { - writeNode(node, "AstTypeUnion", [&]() { - PROP(types); - }); + writeNode( + node, + "AstTypeUnion", + [&]() + { + PROP(types); + } + ); } void write(class AstTypeIntersection* node) { - writeNode(node, "AstTypeIntersection", [&]() { - PROP(types); - }); + writeNode( + node, + "AstTypeIntersection", + [&]() + { + PROP(types); + } + ); } void write(class AstTypeError* node) { - writeNode(node, "AstTypeError", [&]() { - PROP(types); - PROP(messageIndex); - }); + writeNode( + node, + "AstTypeError", + [&]() + { + PROP(types); + PROP(messageIndex); + } + ); } void write(class AstTypePackExplicit* node) { - writeNode(node, "AstTypePackExplicit", [&]() { - PROP(typeList); - }); + writeNode( + node, + "AstTypePackExplicit", + [&]() + { + PROP(typeList); + } + ); } void write(class AstTypePackVariadic* node) { - writeNode(node, "AstTypePackVariadic", [&]() { - PROP(variadicType); - }); + writeNode( + node, + "AstTypePackVariadic", + [&]() + { + PROP(variadicType); + } + ); } void write(class AstTypePackGeneric* node) { - writeNode(node, "AstTypePackGeneric", [&]() { - PROP(genericName); - }); + writeNode( + node, + "AstTypePackGeneric", + [&]() + { + PROP(genericName); + } + ); + } + + void write(AstAttr::Type type) + { + switch (type) + { + case AstAttr::Type::Checked: + return writeString("checked"); + case AstAttr::Type::Native: + return writeString("native"); + } + } + + void write(class AstAttr* node) + { + writeNode( + node, + "AstAttr", + [&]() + { + write("name", node->type); + } + ); + } + + bool visit(class AstTypeSingletonBool* node) override + { + writeNode( + node, + "AstTypeSingletonBool", + [&]() + { + write("value", node->value); + } + ); + return false; + } + + bool visit(class AstTypeSingletonString* node) override + { + writeNode( + node, + "AstTypeSingletonString", + [&]() + { + write("value", node->value); + } + ); + return false; } bool visit(class AstExprGroup* node) override diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index e95b0017f..6b48b16ea 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,8 +11,9 @@ #include -LUAU_FASTFLAG(LuauCompleteTableKeysBetter); -LUAU_FASTFLAGVARIABLE(SupportTypeAliasGoToDeclaration, false); +LUAU_FASTFLAG(LuauSolverV2) + +LUAU_FASTFLAGVARIABLE(LuauDocumentationAtPosition, false) namespace Luau { @@ -32,24 +33,12 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstExpr* expr) override { - if (FFlag::LuauCompleteTableKeysBetter) + if (expr->location.begin <= pos && pos <= expr->location.end) { - if (expr->location.begin <= pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - else - { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; + ancestry.push_back(expr); + return true; } + return false; } bool visit(AstStat* stat) override @@ -161,6 +150,16 @@ struct FindNode : public AstVisitor return false; } + bool visit(AstStatFunction* node) override + { + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; + } + bool visit(AstStatBlock* block) override { visit(static_cast(block)); @@ -179,87 +178,97 @@ struct FindNode : public AstVisitor } }; -struct FindFullAncestry final : public AstVisitor +} // namespace + +FindFullAncestry::FindFullAncestry(Position pos, Position documentEnd, bool includeTypes) + : pos(pos) + , documentEnd(documentEnd) + , includeTypes(includeTypes) { - std::vector nodes; - Position pos; - Position documentEnd; - bool includeTypes = false; +} - explicit FindFullAncestry(Position pos, Position documentEnd, bool includeTypes = false) - : pos(pos) - , documentEnd(documentEnd) - , includeTypes(includeTypes) - { - } +bool FindFullAncestry::visit(AstType* type) +{ + if (includeTypes) + return visit(static_cast(type)); + else + return false; +} - bool visit(AstType* type) override - { - if (FFlag::SupportTypeAliasGoToDeclaration) - { - if (includeTypes) - return visit(static_cast(type)); - else - return false; - } - else - { - return AstVisitor::visit(type); - } - } +bool FindFullAncestry::visit(AstStatFunction* node) +{ + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; +} - bool visit(AstNode* node) override +bool FindFullAncestry::visit(AstNode* node) +{ + if (node->location.contains(pos)) { - if (node->location.contains(pos)) - { - nodes.push_back(node); - return true; - } + nodes.push_back(node); + return true; + } - // Edge case: If we ask for the node at the position that is the very end of the document - // return the innermost AST element that ends at that position. + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. - if (node->location.end == documentEnd && pos >= documentEnd) - { - nodes.push_back(node); - return true; - } - - return false; + if (node->location.end == documentEnd && pos >= documentEnd) + { + nodes.push_back(node); + return true; } -}; -} // namespace + return false; +} std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) { - AutocompleteNodeFinder finder{pos, source.root}; - source.root->visit(&finder); + return findAncestryAtPositionForAutocomplete(source.root, pos); +} + +std::vector findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos) +{ + AutocompleteNodeFinder finder{pos, root}; + root->visit(&finder); return finder.ancestry; } std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes) { - const Position end = source.root->location.end; + return findAstAncestryOfPosition(source.root, pos, includeTypes); +} + +std::vector findAstAncestryOfPosition(AstStatBlock* root, Position pos, bool includeTypes) +{ + const Position end = root->location.end; if (pos > end) pos = end; FindFullAncestry finder(pos, end, includeTypes); - source.root->visit(&finder); + root->visit(&finder); return finder.nodes; } AstNode* findNodeAtPosition(const SourceModule& source, Position pos) { - const Position end = source.root->location.end; - if (pos < source.root->location.begin) - return source.root; + return findNodeAtPosition(source.root, pos); +} + +AstNode* findNodeAtPosition(AstStatBlock* root, Position pos) +{ + const Position end = root->location.end; + if (pos < root->location.begin) + return root; if (pos > end) pos = end; FindNode findNode{pos, end}; - findNode.visit(source.root); + findNode.visit(root); return findNode.best; } @@ -317,10 +326,20 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou static std::optional findBindingLocalStatement(const SourceModule& source, const Binding& binding) { + // Bindings coming from global sources (e.g., definition files) have a zero position. + // They cannot be defined from a local statement + if (binding.location == Location{{0, 0}, {0, 0}}) + return std::nullopt; + std::vector nodes = findAstAncestryOfPosition(source, binding.location.begin); - auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { - return node->is(); - }); + auto iter = std::find_if( + nodes.rbegin(), + nodes.rend(), + [](AstNode* node) + { + return node->is(); + } + ); return iter != nodes.rend() ? std::make_optional((*iter)->as()) : std::nullopt; } @@ -459,7 +478,11 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) } static std::optional checkOverloadedDocumentationSymbol( - const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) + const Module& module, + const TypeId ty, + const AstExpr* parentExpr, + const std::optional documentationSymbol +) { if (!documentationSymbol) return std::nullopt; @@ -488,6 +511,38 @@ static std::optional checkOverloadedDocumentationSymbol( return documentationSymbol; } +static std::optional getMetatableDocumentation( + const Module& module, + AstExpr* parentExpr, + const TableType* mtable, + const AstName& index +) +{ + LUAU_ASSERT(FFlag::LuauDocumentationAtPosition); + auto indexIt = mtable->props.find("__index"); + if (indexIt == mtable->props.end()) + return std::nullopt; + + TypeId followed = follow(indexIt->second.type()); + const TableType* ttv = get(followed); + if (!ttv) + return std::nullopt; + + auto propIt = ttv->props.find(index.value); + if (propIt == ttv->props.end()) + return std::nullopt; + + if (FFlag::LuauSolverV2) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + + return std::nullopt; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -508,12 +563,63 @@ std::optional getDocumentationSymbolAtPosition(const Source if (const TableType* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + { + if (FFlag::LuauSolverV2) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + } } else if (const ClassType* ctv = get(parentTy)) { - if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + if (FFlag::LuauDocumentationAtPosition) + { + while (ctv) + { + if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) + { + if (FFlag::LuauSolverV2) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol( + module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol + ); + } + ctv = ctv->parent ? Luau::get(*ctv->parent) : nullptr; + } + } + else + { + if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) + { + if (FFlag::LuauSolverV2) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol( + module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol + ); + } + } + } + else if (FFlag::LuauDocumentationAtPosition) + { + if (const PrimitiveType* ptv = get(parentTy); ptv && ptv->metatable) + { + if (auto mtable = get(*ptv->metatable)) + { + if (std::optional docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index)) + return docSymbol; + } + } } } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1e0949711..c89d77931 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -3,23 +3,27 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/ToString.h" +#include "Luau/Subtyping.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include #include #include -LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInWhile, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInFor, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions, false) -static const std::unordered_set kStatementStartingKeywords = { - "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) + +static const std::unordered_set kStatementStartingKeywords = + {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; namespace Luau { @@ -144,20 +148,41 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; - Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); - if (FFlag::LuauAutocompleteSkipNormalization) + if (FFlag::LuauSolverV2) + { + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&iceReporter}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + + return subtyping.isSubtype(subTy, superTy, scope).isSubtype; + } + else { + Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); + // Cost of normalization can be too high for autocomplete response time requirements unifier.normalize = false; unifier.checkInhabited = false; - } - return unifier.canUnify(subTy, superTy).empty(); + return unifier.canUnify(subTy, superTy).empty(); + } } static TypeCorrectKind checkTypeCorrectKind( - const Module& module, TypeArena* typeArena, NotNull builtinTypes, AstNode* node, Position position, TypeId ty) + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + AstNode* node, + Position position, + TypeId ty +) { ty = follow(ty); @@ -172,7 +197,8 @@ static TypeCorrectKind checkTypeCorrectKind( TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) { + auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) + { if (std::optional firstRetTy = first(ftv->retTypes)) return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); @@ -188,6 +214,8 @@ static TypeCorrectKind checkTypeCorrectKind( { for (TypeId id : itv->parts) { + id = follow(id); + if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) { return TypeCorrectKind::CorrectFunctionResult; @@ -205,9 +233,18 @@ enum class PropIndexType Key, }; -static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId rootTy, TypeId ty, - PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, - std::optional containingClass = std::nullopt) +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId rootTy, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result, + std::unordered_set& seen, + std::optional containingClass = std::nullopt +) { rootTy = follow(rootTy); ty = follow(ty); @@ -216,13 +253,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul return; seen.insert(ty); - auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) { + auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) + { if (indexType == PropIndexType::Key) return false; bool calledWithSelf = indexType == PropIndexType::Colon; - auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) { + auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) + { // Strong match with definition is a success if (calledWithSelf == ftv->hasSelf) return true; @@ -261,17 +300,30 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul return calledWithSelf; }; - auto fillProps = [&](const ClassType::Props& props) { + auto fillProps = [&](const ClassType::Props& props) + { for (const auto& [name, prop] : props) { // We are walking up the class hierarchy, so if we encounter a property that we have // already populated, it takes precedence over the property we found just now. if (result.count(name) == 0 && name != kParseNameError) { - Luau::TypeId type = Luau::follow(prop.type); + Luau::TypeId type; + + if (FFlag::LuauSolverV2) + { + if (auto ty = prop.readTy) + type = follow(*ty); + else + continue; + } + else + type = follow(prop.type()); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); + ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -286,16 +338,19 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul prop.documentationSymbol, {}, parens, + {}, + indexType == PropIndexType::Colon }; } } }; - auto fillMetatableProps = [&](const TableType* mtable) { + auto fillMetatableProps = [&](const TableType* mtable) + { auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - TypeId followed = follow(indexIt->second.type); + TypeId followed = follow(indexIt->second.type()); if (get(followed) || get(followed)) { autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); @@ -403,7 +458,11 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul } static void autocompleteKeywords( - const SourceModule& sourceModule, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) + const SourceModule& sourceModule, + const std::vector& ancestry, + Position position, + AutocompleteEntryMap& result +) { LUAU_ASSERT(!ancestry.empty()); @@ -423,15 +482,28 @@ static void autocompleteKeywords( } } -static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId ty, PropIndexType indexType, - const std::vector& nodes, AutocompleteEntryMap& result) +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result +) { std::unordered_set seen; autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); } -AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId ty, - PropIndexType indexType, const std::vector& nodes) +AutocompleteEntryMap autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes +) { AutocompleteEntryMap result; autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); @@ -456,9 +528,18 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } -static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) { - auto formatKey = [addQuotes](const std::string& key) { + if (position == node->location.begin || position == node->location.end) + { + if (auto str = node->as(); str && str->quoteStyle == AstExprConstantString::Quoted) + return; + else if (node->is()) + return; + } + + auto formatKey = [addQuotes](const std::string& key) + { if (addQuotes) return "\"" + escape(key) + "\""; @@ -586,14 +667,13 @@ std::optional getLocalTypeInScopeAt(const Module& module, Position posit return {}; } -static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty) +template +static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) { - if (!canSuggestInferredType(scope, ty)) - return std::nullopt; - ToStringOptions opts; opts.useLineBreaks = false; opts.hideTableKind = true; + opts.functionTypeArguments = functionTypeArguments; opts.scope = scope; ToStringResult name = toStringDetailed(ty, opts); @@ -603,6 +683,14 @@ static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty) return name.name; } +static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) +{ + if (!canSuggestInferredType(scope, ty)) + return std::nullopt; + + return tryToStringDetailed(scope, ty, functionTypeArguments); +} + static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) { std::optional ty; @@ -684,9 +772,14 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n if (const IntersectionType* itv = get(expectedType)) { - return std::all_of(begin(itv->parts), end(itv->parts), [](auto&& ty) { - return get(Luau::follow(ty)) != nullptr; - }); + return std::all_of( + begin(itv->parts), + end(itv->parts), + [](auto&& ty) + { + return get(Luau::follow(ty)) != nullptr; + } + ); } if (const UnionType* utv = get(expectedType)) @@ -706,15 +799,31 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi for (const auto& [name, ty] : scope->exportedTypeBindings) { if (!result.count(name)) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, - std::nullopt, ty.type->documentationSymbol}; + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; } for (const auto& [name, ty] : scope->privateTypeBindings) { if (!result.count(name)) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, - std::nullopt, ty.type->documentationSymbol}; + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; } for (const auto& [name, _] : scope->importedTypeBindings) @@ -804,7 +913,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi else if (AstExprFunction* node = parent->as()) { // For lookup inside expected function type if that's available - auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* { + auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* + { auto it = module.astExpectedTypes.find(expr); if (!it) @@ -983,25 +1093,14 @@ T* extractStat(const std::vector& ancestry) AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; - if (FFlag::LuauCompleteTableKeysBetter) - { - if (!grandParent) - return nullptr; - - if (T* t = parent->as(); t && grandParent->is()) - return t; + if (!grandParent) + return nullptr; - if (!greatGrandParent) - return nullptr; - } - else - { - if (T* t = parent->as(); t && parent->is()) - return t; + if (T* t = parent->as(); t && grandParent->is()) + return t; - if (!grandParent || !greatGrandParent) - return nullptr; - } + if (!greatGrandParent) + return nullptr; if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) return t; @@ -1019,7 +1118,11 @@ static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& } static AutocompleteEntryMap autocompleteStatement( - const SourceModule& sourceModule, const Module& module, const std::vector& ancestry, Position position) + const SourceModule& sourceModule, + const Module& module, + const std::vector& ancestry, + Position position +) { // This is inefficient. :( ScopePtr scope = findScopeAtPosition(module, position); @@ -1041,8 +1144,18 @@ static AutocompleteEntryMap autocompleteStatement( std::string n = toString(name); if (!result.count(n)) - result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, TypeCorrectKind::None, std::nullopt, - std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None)}; + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) + }; } scope = scope->parent; @@ -1053,15 +1166,27 @@ static AutocompleteEntryMap autocompleteStatement( for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->hasEnd) + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->hasEnd) + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatIf* statIf = (*it)->as(); statIf && !statIf->hasEnd) + else if (AstStatIf* statIf = (*it)->as()) + { + bool hasEnd = statIf->thenbody->hasEnd; + if (statIf->elsebody) + { + if (AstStatBlock* elseBlock = statIf->elsebody->as()) + hasEnd = elseBlock->hasEnd; + } + + if (!hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->hasEnd) + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->hasEnd) + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } @@ -1077,7 +1202,7 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->hasUntil) + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } @@ -1092,7 +1217,7 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->hasUntil) + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); return result; @@ -1100,7 +1225,11 @@ static AutocompleteEntryMap autocompleteStatement( // Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) static bool autocompleteIfElseExpression( - const AstNode* node, const std::vector& ancestry, const Position& position, AutocompleteEntryMap& outResult) + const AstNode* node, + const std::vector& ancestry, + const Position& position, + AutocompleteEntryMap& outResult +) { AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; if (!parent) @@ -1139,8 +1268,15 @@ static bool autocompleteIfElseExpression( } } -static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull builtinTypes, - TypeArena* typeArena, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +static AutocompleteContext autocompleteExpression( + const SourceModule& sourceModule, + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + Position position, + AutocompleteEntryMap& result +) { LUAU_ASSERT(!ancestry.empty()); @@ -1175,8 +1311,18 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu { TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); - result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, - binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + typeCorrect, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, typeCorrect) + }; } } @@ -1197,14 +1343,20 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, result); + autocompleteStringSingleton(*ty, true, node, position, result); } return AutocompleteContext::Expression; } -static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull builtinTypes, - TypeArena* typeArena, const std::vector& ancestry, Position position) +static AutocompleteResult autocompleteExpression( + const SourceModule& sourceModule, + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + Position position +) { AutocompleteEntryMap result; AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result); @@ -1290,8 +1442,27 @@ static std::optional getStringContents(const AstNode* node) } } -static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, - const std::vector& nodes, Position position, StringCompletionCallback callback) +static std::optional convertRequireSuggestionsToAutocompleteEntryMap(std::optional suggestions) +{ + if (!suggestions) + return std::nullopt; + + AutocompleteEntryMap result; + for (const RequireSuggestion& suggestion : *suggestions) + { + result[suggestion] = {AutocompleteEntryKind::RequirePath}; + } + return result; +} + +static std::optional autocompleteStringParams( + const SourceModule& sourceModule, + const ModulePtr& module, + const std::vector& nodes, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) { if (nodes.size() < 2) { @@ -1303,6 +1474,14 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } + if (!nodes.back()->is()) + { + if (nodes.back()->location.end == position || nodes.back()->location.begin == position) + { + return std::nullopt; + } + } + AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); if (!candidate) { @@ -1324,9 +1503,17 @@ static std::optional autocompleteStringParams(const Source std::optional candidateString = getStringContents(nodes.back()); - auto performCallback = [&](const FunctionType* funcType) -> std::optional { + auto performCallback = [&](const FunctionType* funcType) -> std::optional + { for (const std::string& tag : funcType->tags) { + if (FFlag::AutocompleteRequirePathSuggestions) + { + if (tag == kRequireTagName && fileResolver) + { + return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString)); + } + } if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) { return ret; @@ -1367,8 +1554,153 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector an return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; } -static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull builtinTypes, - TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback) +static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) +{ + std::string result = "function("; + + auto [args, tail] = Luau::flatten(funcTy.argTypes); + + bool first = true; + // Skip the implicit 'self' argument if call is indexed with ':' + for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) + { + if (!first) + result += ", "; + else + first = false; + + std::string name; + if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) + name = funcTy.argNames[argIdx]->name; + else + name = "a" + std::to_string(argIdx); + + if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) + result += name + ": " + *type; + else + result += name; + } + + if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) + { + if (!first) + result += ", "; + + std::optional varArgType; + if (const VariadicTypePack* pack = get(follow(*tail))) + { + if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) + varArgType = std::move(res); + } + + if (varArgType) + result += "...: " + *varArgType; + else + result += "..."; + } + + result += ")"; + + auto [rets, retTail] = Luau::flatten(funcTy.retTypes); + if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) + { + if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) + { + result += ": "; + bool wrap = totalRetSize != 1; + if (wrap) + result += "("; + result += *returnTypes; + if (wrap) + result += ")"; + } + } + result += " end"; + return result; +} + +static std::optional makeAnonymousAutofilled( + const ModulePtr& module, + Position position, + const AstNode* node, + const std::vector& ancestry +) +{ + const AstExprCall* call = node->as(); + if (!call && ancestry.size() > 1) + call = ancestry[ancestry.size() - 2]->as(); + + if (!call) + return std::nullopt; + + if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) + return std::nullopt; + + TypeId* typeIter = module->astTypes.find(call->func); + if (!typeIter) + return std::nullopt; + + const FunctionType* outerFunction = get(follow(*typeIter)); + if (!outerFunction) + return std::nullopt; + + size_t argument = 0; + for (size_t i = 0; i < call->args.size; ++i) + { + if (call->args.data[i]->location.containsClosed(position)) + { + argument = i; + break; + } + } + + if (call->self) + argument++; + + std::optional argType; + auto [args, tail] = flatten(outerFunction->argTypes); + if (argument < args.size()) + argType = args[argument]; + + if (!argType) + return std::nullopt; + + TypeId followed = follow(*argType); + const FunctionType* type = get(followed); + if (!type) + { + if (const UnionType* unionType = get(followed)) + { + if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) + type = *nonnullFunction; + } + } + + if (!type) + return std::nullopt; + + const ScopePtr scope = findScopeAtPosition(*module, position); + if (!scope) + return std::nullopt; + + AutocompleteEntry entry; + entry.kind = AutocompleteEntryKind::GeneratedFunction; + entry.typeCorrect = TypeCorrectKind::Correct; + entry.type = argType; + entry.insertText = makeAnonymous(scope, *type); + return std::make_optional(std::move(entry)); +} + +static AutocompleteResult autocomplete( + const SourceModule& sourceModule, + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + Scope* globalScope, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) { if (isWithinComment(sourceModule, position)) return {}; @@ -1425,24 +1757,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (!statFor->hasDo || position < statFor->doLocation.begin) { - if (FFlag::LuauFixAutocompleteInFor) - { - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - else - { - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - } + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; return {}; } @@ -1493,14 +1813,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) { - if (FFlag::LuauFixAutocompleteInWhile) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else - { - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } + return autocompleteWhileLoopKeywords(ancestry); } if (!statWhile->hasDo || position < statWhile->doLocation.begin) @@ -1511,23 +1824,18 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (AstStatWhile* statWhile = extractStat(ancestry); - FFlag::LuauFixAutocompleteInWhile ? (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && - statWhile->condition && !statWhile->condition->location.containsClosed(position)) - : (statWhile && !statWhile->hasDo)) + (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && + !statWhile->condition->location.containsClosed(position))) { - if (FFlag::LuauFixAutocompleteInWhile) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else - { - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } + return autocompleteWhileLoopKeywords(ancestry); } else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { - return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - ancestry, AutocompleteContext::Keyword}; + return { + {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, + ancestry, + AutocompleteContext::Keyword + }; } else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { @@ -1562,23 +1870,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - if (FFlag::LuauCompleteTableKeysBetter) - { - if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), result); + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); - if (!key) + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) { - // If there is "no key," it may be that the user - // intends for the current token to be the key, but - // has yet to type the `=` sign. - // - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, result); - } + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); } } @@ -1603,10 +1908,41 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } } + else if (AstExprTable* exprTable = node->as()) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(exprTable)) + { + result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + } + + // Also offer general expression suggestions + autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } else if (isIdentifier(node) && (parent->is() || parent->is())) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) + if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, fileResolver, callback)) { return {*ret, ancestry, AutocompleteContext::String}; } @@ -1615,7 +1951,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteEntryMap result; if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); if (ancestry.size() >= 2) { @@ -1629,7 +1965,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); } } } @@ -1648,7 +1984,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; if (node->asExpr()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + { + AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; + } else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1657,25 +1998,28 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(moduleName, opts); - const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + ModulePtr module; + if (FFlag::LuauSolverV2) + module = frontend.moduleResolver.getModule(moduleName); + else + module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + if (!module) return {}; NotNull builtinTypes = frontend.builtinTypes; - Scope* globalScope = frontend.typeCheckerForAutocomplete.globalScope.get(); + Scope* globalScope; + if (FFlag::LuauSolverV2) + globalScope = frontend.globals.globalScope.get(); + else + globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; - return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); + return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, frontend.fileResolver, callback); } } // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index b111c504a..84d2d6e9d 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -2,45 +2,77 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Ast.h" +#include "Luau/Clone.h" +#include "Luau/Error.h" #include "Luau/Frontend.h" #include "Luau/Symbol.h" #include "Luau/Common.h" #include "Luau/ToString.h" #include "Luau/ConstraintSolver.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/NotNull.h" #include "Luau/TypeInfer.h" +#include "Luau/TypeChecker2.h" +#include "Luau/TypeFunction.h" #include "Luau/TypePack.h" #include "Luau/Type.h" #include "Luau/TypeUtils.h" +#include "Luau/Subtyping.h" #include -LUAU_FASTFLAGVARIABLE(LuauDeprecateTableGetnForeach, false) - /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk * about a function that takes any number of values, but where each value must have some specific type. */ +LUAU_FASTFLAG(LuauSolverV2) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins, false) +LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix, false) + +LUAU_FASTFLAG(AutocompleteRequirePathSuggestions) + namespace Luau { static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionPack(MagicFunctionCallContext context); +static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -52,37 +84,58 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types) return arena.addType(IntersectionType{std::move(types)}); } -TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t) -{ - return makeUnion(arena, {frontend.typeChecker.nilType, t}); -} - -TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t) +TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t) { - return makeUnion(arena, {typeChecker.nilType, t}); + LUAU_ASSERT(t); + return makeUnion(arena, {builtinTypes->nilType, t}); } TypeId makeFunction( - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes) + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked +) { - return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes); + return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes, checked); } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, - std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list retTypes) +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked +) { - return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes); + return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes, checked); } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, - std::initializer_list paramNames, std::initializer_list retTypes) +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked +) { - return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes); + return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes, checked); } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, - std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes) +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked +) { std::vector params; if (selfType) @@ -109,6 +162,8 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi ftv.argNames.push_back(std::nullopt); } + ftv.isCheckedFunction = checked; + return arena.addType(std::move(ftv)); } @@ -136,6 +191,14 @@ void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn) LUAU_ASSERT(!"Got a non functional type"); } +void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn) +{ + if (auto ftv = getMutable(ty)) + ftv->dcrMagicTypeCheck = fn; + else + LUAU_ASSERT(!"Got a non functional type"); +} + Property makeProperty(TypeId ty, std::optional documentationSymbol) { return { @@ -148,85 +211,52 @@ Property makeProperty(TypeId ty, std::optional documentationSymbol) }; } -void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName) { - addGlobalBinding(frontend, frontend.getGlobalScope(), name, ty, packageName); + addGlobalBinding(globals, globals.globalScope, name, ty, packageName); } -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); - -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding) { - addGlobalBinding(typeChecker, typeChecker.globalScope, name, ty, packageName); + addGlobalBinding(globals, globals.globalScope, name, binding); } -void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding) -{ - addGlobalBinding(frontend, frontend.getGlobalScope(), name, binding); -} - -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding) -{ - addGlobalBinding(typeChecker, typeChecker.globalScope, name, binding); -} - -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) -{ - std::string documentationSymbol = packageName + "/global/" + name; - addGlobalBinding(frontend, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); -} - -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) { std::string documentationSymbol = packageName + "/global/" + name; - addGlobalBinding(typeChecker, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); + addGlobalBinding(globals, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); } -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding) +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, Binding binding) { - addGlobalBinding(frontend.typeChecker, scope, name, binding); + scope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = binding; } -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding) +std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::string& name) { - scope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = binding; -} - -std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name) -{ - AstName astName = typeChecker.globalNames.names->getOrAdd(name.c_str()); - auto it = typeChecker.globalScope->bindings.find(astName); - if (it != typeChecker.globalScope->bindings.end()) + AstName astName = globals.globalNames.names->getOrAdd(name.c_str()); + auto it = globals.globalScope->bindings.find(astName); + if (it != globals.globalScope->bindings.end()) return it->second; return std::nullopt; } -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) +TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name) { - auto t = tryGetGlobalBinding(typeChecker, name); + auto t = tryGetGlobalBinding(globals, name); LUAU_ASSERT(t.has_value()); return t->typeId; } -TypeId getGlobalBinding(Frontend& frontend, const std::string& name) -{ - return getGlobalBinding(frontend.typeChecker, name); -} - -std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name) -{ - return tryGetGlobalBinding(frontend.typeChecker, name); -} - -Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name) +Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name) { - AstName astName = typeChecker.globalNames.names->get(name.c_str()); + AstName astName = globals.globalNames.names->get(name.c_str()); if (astName == AstName()) return nullptr; - auto it = typeChecker.globalScope->bindings.find(astName); - if (it != typeChecker.globalScope->bindings.end()) + auto it = globals.globalScope->bindings.find(astName); + if (it != globals.globalScope->bindings.end()) return &it->second; return nullptr; @@ -240,34 +270,25 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& } } -void registerBuiltinTypes(Frontend& frontend) +void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) { - frontend.getGlobalScope()->addBuiltinTypeBinding("any", TypeFun{{}, frontend.builtinTypes->anyType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("nil", TypeFun{{}, frontend.builtinTypes->nilType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("number", TypeFun{{}, frontend.builtinTypes->numberType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.builtinTypes->stringType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.builtinTypes->booleanType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.builtinTypes->threadType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.builtinTypes->unknownType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.builtinTypes->neverType}); -} - -void registerBuiltinGlobals(TypeChecker& typeChecker) -{ - LUAU_ASSERT(!typeChecker.globalTypes.types.isFrozen()); - LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); + LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); + LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - TypeId nilType = typeChecker.nilType; + TypeArena& arena = globals.globalTypes; + NotNull builtinTypes = globals.builtinTypes; - TypeArena& arena = typeChecker.globalTypes; - NotNull builtinTypes = typeChecker.builtinTypes; + if (FFlag::LuauSolverV2) + builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( + globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete + ); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); + TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); LUAU_ASSERT(stringMetatableTy); @@ -277,45 +298,68 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) auto it = stringMetatableTable->props.find("__index"); LUAU_ASSERT(it != stringMetatableTable->props.end()); - addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); + addGlobalBinding(globals, "string", it->second.type(), "@luau"); // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); - addGlobalBinding(typeChecker, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); + addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericType{"MT"}); - TableType tab{TableState::Generic, typeChecker.globalScope->level}; + TableType tab{TableState::Generic, globals.globalScope->level}; TypeId tabTy = arena.addType(tab); TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(typeChecker, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on + // getmetatable : ({ @metatable MT, {+ +} }) -> MT + addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - for (const auto& pair : typeChecker.globalScope->bindings) + if (FFlag::LuauSolverV2) + { + TypeId genericT = arena.addType(GenericType{"T"}); + TypeId tMetaMT = arena.addType(MetatableType{genericT, genericMT}); + + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericT, genericMT}, + {}, + arena.addTypePack(TypePack{{genericT, genericMT}}), + arena.addTypePack(TypePack{{tMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } + else + { + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericMT}, + {}, + arena.addTypePack(TypePack{{tabTy, genericMT}}), + arena.addTypePack(TypePack{{tableMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } + + for (const auto& pair : globals.globalScope->bindings) { persist(pair.second.typeId); @@ -326,135 +370,691 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) } } - attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(typeChecker, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (FFlag::LuauSolverV2) { - // tabTy is a generic table type which we can't express via declaration syntax yet - ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + // declare function assert(value: T, errorMessage: string?): intersect + TypeId genericT = arena.addType(GenericType{"T"}); + TypeId refinedTy = arena.addType(TypeFunctionInstanceType{ + NotNull{&builtinTypeFunctions().intersectFunc}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {} + }); + + TypeId assertTy = arena.addType(FunctionType{ + {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}}) + }); + addGlobalBinding(globals, "assert", assertTy, "@luau"); + } + + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - if (FFlag::LuauDeprecateTableGetnForeach) + if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) + { + if (FFlag::LuauSolverV2) { - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; + // CLI-114044 - The new solver does not yet support generic tables, + // which act, in an odd way, like generics that are constrained to + // the top table type. We do the best we can by modelling these + // functions using unconstrained generics. It's not quite right, + // but it'll be ok for now. + TypeId genericTy = arena.addType(GenericType{"T"}); + TypePackId thePack = arena.addTypePack({genericTy}); + TypeId idTyWithMagic = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); + ttv->props["freeze"] = makeProperty(idTyWithMagic, "@luau/global/table.freeze"); + + TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); + ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone"); } + else + { + // tabTy is a generic table type which we can't express via declaration syntax yet + ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + } + + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; - attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); + attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); + if (FFlag::LuauTypestateBuiltins) + attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); } - attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(typeChecker, "require"), dcrMagicFunctionRequire); + if (FFlag::AutocompleteRequirePathSuggestions) + { + TypeId requireTy = getGlobalBinding(globals, "require"); + attachTag(requireTy, kRequireTagName); + attachMagicFunction(requireTy, magicFunctionRequire); + attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire); + } + else + { + attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); + } } -void registerBuiltinGlobals(Frontend& frontend) +static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) { - LUAU_ASSERT(!frontend.globalTypes.types.isFrozen()); - LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); + const char* options = "cdiouxXeEfgGqs*"; - registerBuiltinTypes(frontend); + std::vector result; - TypeArena& arena = frontend.globalTypes; - NotNull builtinTypes = frontend.builtinTypes; + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i < size && data[i] == '%') + continue; + + // we just ignore all characters (including flags/precision) up until first alphabetic character + while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) + i++; + + if (i == size) + break; + + if (data[i] == 'q' || data[i] == 's') + result.push_back(builtinTypes->stringType); + else if (data[i] == '*') + result.push_back(builtinTypes->unknownType); + else if (strchr(options, data[i])) + result.push_back(builtinTypes->numberType); + else + result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); + } + } - LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau"); - LUAU_ASSERT(loadResult.success); + return result; +} - TypeId genericK = arena.addType(GenericType{"K"}); - TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), frontend.getGlobalScope()->level, TableState::Generic}); +std::optional> magicFunctionFormat( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; - std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); - LUAU_ASSERT(stringMetatableTy); - const TableType* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); + TypeArena& arena = typechecker.currentModule->internalTypes; - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); + AstExprConstantString* fmt = nullptr; + if (auto index = expr.func->as(); index && expr.self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } - addGlobalBinding(frontend, "string", it->second.type, "@luau"); + if (!expr.self && expr.args.size > 0) + fmt = expr.args.data[0]->as(); - // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); - addGlobalBinding(frontend, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + if (!fmt) + return std::nullopt; - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(paramPack); - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; - // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - TypeId genericMT = arena.addType(GenericType{"MT"}); + typechecker.unify(params[i + paramOffset], expected[i], scope, location); + } - TableType tab{TableState::Generic, frontend.getGlobalScope()->level}; - TypeId tabTy = arena.addType(tab); + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - addGlobalBinding(frontend, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(frontend, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on + return WithPredicate{arena.addTypePack({typechecker.stringType})}; +} - for (const auto& pair : frontend.getGlobalScope()->bindings) +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) { - persist(pair.second.typeId); + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } - if (TableType* ttv = getMutable(pair.second.typeId)) + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time - needed if any of the involved types are free + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(context.constraint, params[i + paramOffset], expected[i]); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + // This is invoked at solve time, so we just need to provide a type for the result of :/.format + TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context) +{ + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + { + if (FFlag::LuauStringFormatArityFix) + context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location); + return; + } + + std::vector expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + // Compare the expressions passed with the types the function expects to determine whether this function was called with : or . + bool calledWithSelf = expected.size() == context.callSite->args.size; + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + TypeId actualTy = params[i + paramOffset]; + TypeId expectedTy = expected[i]; + Location location = context.callSite->args.data[i + (calledWithSelf ? 0 : paramOffset)]->location; + // use subtyping instead here + SubtypingResult result = context.typechecker->subtyping->isSubtype(actualTy, expectedTy, context.checkScope); + if (!result.isSubtype) { - if (!ttv->name) - ttv->name = "typeof(" + toString(pair.first) + ")"; + Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); + context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); } } +} - attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); +static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; - if (TableType* ttv = getMutable(getGlobalBinding(frontend, "table"))) + for (size_t i = 0; i < size; ++i) { - // tabTy is a generic table type which we can't express via declaration syntax yet - ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(builtinTypes->optionalNumberType); + continue; + } - if (FFlag::LuauDeprecateTableGetnForeach) + ++depth; + result.push_back(builtinTypes->optionalStringType); + } + else if (data[i] == ')') { - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(builtinTypes->optionalStringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} - attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.constraint, params[0], context.solver->builtinTypes->stringType); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 3) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() == 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.constraint, params[0], context.solver->builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.constraint, params[2], optionalNumber); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 4) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + bool plain = false; + size_t plainIndex = expr.self ? 2 : 3; + if (expr.args.size > plainIndex) + { + AstExprConstantBool* p = expr.args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + } + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() >= 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + if (params.size() == 4 && expr.args.size > plainIndex) + typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull builtinTypes = context.solver->builtinTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; } - attachMagicFunction(getGlobalBinding(frontend, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(frontend, "require"), dcrMagicFunctionRequire); + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(context.constraint, params[0], builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.constraint, params[2], optionalNumber); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(context.constraint, params[3], optionalBoolean); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; +} + +TypeId makeStringMetatable(NotNull builtinTypes) +{ + NotNull arena{builtinTypes->arena.get()}; + + const TypeId nilType = builtinTypes->nilType; + const TypeId numberType = builtinTypes->numberType; + const TypeId booleanType = builtinTypes->booleanType; + const TypeId stringType = builtinTypes->stringType; + + const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); + const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); + + const TypePackId oneStringPack = arena->addTypePack({stringType}); + const TypePackId anyTypePack = builtinTypes->anyTypePack; + + const TypePackId variadicTailPack = FFlag::LuauSolverV2 ? builtinTypes->unknownTypePack : anyTypePack; + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); + const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); + + + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + formatFTV.isCheckedFunction = true; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat); + + + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); + + const TypeId replArgType = arena->addType(UnionType{ + {stringType, + arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)} + }); + const TypeId gsubFunc = + makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + + FunctionType matchFuncTy{ + arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}) + }; + matchFuncTy.isCheckedFunction = true; + const TypeId matchFunc = arena->addType(matchFuncTy); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + + FunctionType findFuncTy{ + arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList}) + }; + findFuncTy.isCheckedFunction = true; + const TypeId findFunc = arena->addType(findFuncTy); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + + // string.byte : string -> number? -> number? -> ...number + FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; + stringDotByte.isCheckedFunction = true; + + // string.char : .... number -> string + FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})}; + stringDotChar.isCheckedFunction = true; + + // string.unpack : string -> string -> number? -> ...any + FunctionType stringDotUnpack{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + variadicTailPack, + }; + stringDotUnpack.isCheckedFunction = true; + + TableType::Props stringLib = { + {"byte", {arena->addType(stringDotByte)}}, + {"char", {arena->addType(stringDotChar)}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, + {"upper", {stringToStringType}}, + {"split", + {makeFunction( + *arena, + stringType, + {}, + {}, + {optionalString}, + {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, + /* checked */ true + )}}, + {"pack", + {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, variadicTailPack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"unpack", {arena->addType(stringDotUnpack)}}, + }; + + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; + + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, _predicates] = withPredicate; @@ -540,7 +1140,11 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) } static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, _predicates] = withPredicate; @@ -597,6 +1201,18 @@ static std::optional> magicFunctionSetMetaTable( else if (get(target) || get(target) || isTableIntersection(target)) { } + else if (isTableUnion(target)) + { + const UnionType* ut = get(target); + LUAU_ASSERT(ut); + + std::vector resultParts; + + for (TypeId ty : ut) + resultParts.push_back(arena.addType(MetatableType{ty, mt})); + + return WithPredicate{arena.addTypePack({arena.addType(UnionType{std::move(resultParts)})})}; + } else { typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); @@ -606,7 +1222,11 @@ static std::optional> magicFunctionSetMetaTable( } static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, predicates] = withPredicate; @@ -636,7 +1256,11 @@ static std::optional> magicFunctionAssert( } static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, _predicates] = withPredicate; @@ -714,6 +1338,58 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) return true; } +static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context) +{ + LUAU_ASSERT(FFlag::LuauTypestateBuiltins); + + TypeArena* arena = context.solver->arena; + const DataFlowGraph* dfg = context.solver->dfg.get(); + Scope* scope = context.constraint->scope.get(); + + const auto& [paramTypes, paramTail] = extendTypePack(*arena, context.solver->builtinTypes, context.arguments, 1); + LUAU_ASSERT(paramTypes.size() >= 1); + + TypeId inputType = follow(paramTypes.at(0)); + + // we'll check if it's a table first since this magic function also produces the error if it's not until we have bounded generics + if (!get(inputType)) + { + context.solver->reportError(TypeMismatch{context.solver->builtinTypes->tableType, inputType}, context.callSite->argLocation); + return false; + } + + AstExpr* targetExpr = context.callSite->args.data[0]; + std::optional resultDef = dfg->getDefOptional(targetExpr); + std::optional resultTy = resultDef ? scope->lookup(*resultDef) : std::nullopt; + + // Clone the input type, this will become our final result type after we mutate it. + CloneState cloneState{context.solver->builtinTypes}; + TypeId clonedType = shallowClone(inputType, *arena, cloneState); + auto tableTy = getMutable(clonedType); + // `clone` should not break this. + LUAU_ASSERT(tableTy); + tableTy->state = TableState::Sealed; + tableTy->syntheticName = std::nullopt; + + // We'll mutate the table to make every property type read-only. + for (auto iter = tableTy->props.begin(); iter != tableTy->props.end();) + { + if (iter->second.isWriteOnly()) + iter = tableTy->props.erase(iter); + else + { + iter->second.writeTy = std::nullopt; + iter++; + } + } + + if (resultTy) + asMutable(*resultTy)->ty.emplace(clonedType); + asMutable(context.result)->ty.emplace(arena->addTypePack({clonedType})); + + return true; +} + static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) { // require(foo.parent.bar) will technically work, but it depends on legacy goop that @@ -737,7 +1413,11 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) } static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { TypeArena& arena = typechecker.currentModule->internalTypes; @@ -750,7 +1430,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModule->name, expr)) return WithPredicate{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; @@ -801,4 +1481,52 @@ static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) return false; } +bool matchSetMetatable(const AstExprCall& call) +{ + const char* smt = "setmetatable"; + + if (call.args.size != 2) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != smt) + return false; + + return true; +} + +bool matchTableFreeze(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprIndexName* index = call.func->as(); + if (!index || index->index != "freeze") + return false; + + const AstExprGlobal* global = index->expr->as(); + if (!global || global->name != "table") + return false; + + return true; +} + +bool matchAssert(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != "assert") + return false; + + return true; +} + +bool shouldTypestateForFirstArgument(const AstExprCall& call) +{ + // TODO: magic function for setmetatable and assert and then add them + return matchTableFreeze(call); +} + } // namespace Luau diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ff8e0c3c2..745a03074 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -1,15 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Clone.h" -#include "Luau/RecursionCounter.h" -#include "Luau/TxnLog.h" +#include "Luau/NotNull.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" #include "Luau/Unifiable.h" -LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess) +LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +// For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. +LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) namespace Luau { @@ -17,354 +17,468 @@ namespace Luau namespace { -struct TypePackCloner; +using Kind = Variant; -/* - * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. - * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. - */ - -struct TypeCloner +template +const T* get(const Kind& kind) { - TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) - : dest(dest) - , typeId(typeId) - , seenTypes(cloneState.seenTypes) - , seenTypePacks(cloneState.seenTypePacks) - , cloneState(cloneState) - { - } - - TypeArena& dest; - TypeId typeId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - template - void defaultClone(const T& t); - - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); - void operator()(const BlockedType& t); - void operator()(const PendingExpansionType& t); - void operator()(const PrimitiveType& t); - void operator()(const SingletonType& t); - void operator()(const FunctionType& t); - void operator()(const TableType& t); - void operator()(const MetatableType& t); - void operator()(const ClassType& t); - void operator()(const AnyType& t); - void operator()(const UnionType& t); - void operator()(const IntersectionType& t); - void operator()(const LazyType& t); - void operator()(const UnknownType& t); - void operator()(const NeverType& t); - void operator()(const NegationType& t); -}; + return get_if(&kind); +} -struct TypePackCloner +class TypeCloner { - TypeArena& dest; - TypePackId typePackId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; + NotNull arena; + NotNull builtinTypes; + + // A queue of kinds where we cloned it, but whose interior types hasn't + // been updated to point to new clones. Once all of its interior types + // has been updated, it gets removed from the queue. + std::vector queue; + + NotNull types; + NotNull packs; + + int steps = 0; - TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) - : dest(dest) - , typePackId(typePackId) - , seenTypes(cloneState.seenTypes) - , seenTypePacks(cloneState.seenTypePacks) - , cloneState(cloneState) +public: + TypeCloner(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) + : arena(arena) + , builtinTypes(builtinTypes) + , types(types) + , packs(packs) { } - template - void defaultClone(const T& t) + TypeId clone(TypeId ty) { - TypePackId cloned = dest.addTypePack(TypePackVar{t}); - seenTypePacks[typePackId] = cloned; + shallowClone(ty); + run(); + + if (hasExceededIterationLimit()) + { + TypeId error = builtinTypes->errorRecoveryType(); + (*types)[ty] = error; + return error; + } + + return find(ty).value_or(builtinTypes->errorRecoveryType()); } - void operator()(const Unifiable::Free& t) + TypePackId clone(TypePackId tp) { - defaultClone(t); + shallowClone(tp); + run(); + + if (hasExceededIterationLimit()) + { + TypePackId error = builtinTypes->errorRecoveryTypePack(); + (*packs)[tp] = error; + return error; + } + + return find(tp).value_or(builtinTypes->errorRecoveryTypePack()); } - void operator()(const Unifiable::Generic& t) + +private: + bool hasExceededIterationLimit() const { - defaultClone(t); + if (FInt::LuauTypeCloneIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(FInt::LuauTypeCloneIterationLimit); } - void operator()(const Unifiable::Error& t) + + void run() { - defaultClone(t); + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit()) + break; + + Kind kind = queue.back(); + queue.pop_back(); + + if (find(kind)) + continue; + + cloneChildren(kind); + } } - void operator()(const BlockedTypePack& t) + std::optional find(TypeId ty) const { - defaultClone(t); + ty = follow(ty, FollowOption::DisableLazyTypeThunks); + if (auto it = types->find(ty); it != types->end()) + return it->second; + else if (ty->persistent) + return ty; + return std::nullopt; } - // While we are a-cloning, we can flatten out bound Types and make things a bit tighter. - // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. - void operator()(const Unifiable::Bound& t) + std::optional find(TypePackId tp) const { - TypePackId cloned = clone(t.boundTo, dest, cloneState); - if (FFlag::DebugLuauCopyBeforeNormalizing) - cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}}); - seenTypePacks[typePackId] = cloned; + tp = follow(tp); + if (auto it = packs->find(tp); it != packs->end()) + return it->second; + else if (tp->persistent) + return tp; + return std::nullopt; } - void operator()(const VariadicTypePack& t) + std::optional find(Kind kind) const { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); - seenTypePacks[typePackId] = cloned; + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind?"); + return std::nullopt; + } + } + +public: + TypeId shallowClone(TypeId ty) + { + // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. + ty = follow(ty, FollowOption::DisableLazyTypeThunks); + + if (auto clone = find(ty)) + return *clone; + else if (ty->persistent) + return ty; + + TypeId target = arena->addType(ty->ty); + asMutable(target)->documentationSymbol = ty->documentationSymbol; + + if (auto generic = getMutable(target)) + generic->scope = nullptr; + else if (auto free = getMutable(target)) + free->scope = nullptr; + else if (auto fn = getMutable(target)) + fn->scope = nullptr; + else if (auto table = getMutable(target)) + table->scope = nullptr; + + (*types)[ty] = target; + queue.push_back(target); + return target; } - void operator()(const TypePack& t) + TypePackId shallowClone(TypePackId tp) { - TypePackId cloned = dest.addTypePack(TypePack{}); - TypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp != nullptr); - seenTypePacks[typePackId] = cloned; + tp = follow(tp); + + if (auto clone = find(tp)) + return *clone; + else if (tp->persistent) + return tp; + + TypePackId target = arena->addTypePack(tp->ty); - for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, cloneState)); + if (auto generic = getMutable(target)) + generic->scope = nullptr; + else if (auto free = getMutable(target)) + free->scope = nullptr; - if (t.tail) - destTp->tail = clone(*t.tail, dest, cloneState); + (*packs)[tp] = target; + queue.push_back(target); + return target; } -}; -template -void TypeCloner::defaultClone(const T& t) -{ - TypeId cloned = dest.addType(t); - seenTypes[typeId] = cloned; -} +private: + Property shallowClone(const Property& p) + { + if (FFlag::LuauSolverV2) + { + std::optional cloneReadTy; + if (auto ty = p.readTy) + cloneReadTy = shallowClone(*ty); + + std::optional cloneWriteTy; + if (auto ty = p.writeTy) + cloneWriteTy = shallowClone(*ty); + + Property cloned = Property::create(cloneReadTy, cloneWriteTy); + cloned.deprecated = p.deprecated; + cloned.deprecatedSuggestion = p.deprecatedSuggestion; + cloned.location = p.location; + cloned.tags = p.tags; + cloned.documentationSymbol = p.documentationSymbol; + cloned.typeLocation = p.typeLocation; + return cloned; + } + else + { + return Property{ + shallowClone(p.type()), + p.deprecated, + p.deprecatedSuggestion, + p.location, + p.tags, + p.documentationSymbol, + p.typeLocation, + }; + } + } -void TypeCloner::operator()(const Unifiable::Free& t) -{ - defaultClone(t); -} + void cloneChildren(TypeId ty) + { + return visit( + [&](auto&& t) + { + return cloneChildren(&t); + }, + asMutable(ty)->ty + ); + } -void TypeCloner::operator()(const Unifiable::Generic& t) -{ - defaultClone(t); -} + void cloneChildren(TypePackId tp) + { + return visit( + [&](auto&& t) + { + return cloneChildren(&t); + }, + asMutable(tp)->ty + ); + } -void TypeCloner::operator()(const Unifiable::Bound& t) -{ - TypeId boundTo = clone(t.boundTo, dest, cloneState); - if (FFlag::DebugLuauCopyBeforeNormalizing) - boundTo = dest.addType(BoundType{boundTo}); - seenTypes[typeId] = boundTo; -} + void cloneChildren(Kind kind) + { + if (auto ty = get(kind)) + return cloneChildren(*ty); + else if (auto tp = get(kind)) + return cloneChildren(*tp); + else + LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?"); + } -void TypeCloner::operator()(const Unifiable::Error& t) -{ - defaultClone(t); -} + // ErrorType and ErrorTypePack is an alias to this type. + void cloneChildren(Unifiable::Error* t) + { + // noop. + } -void TypeCloner::operator()(const BlockedType& t) -{ - defaultClone(t); -} + void cloneChildren(BoundType* t) + { + t->boundTo = shallowClone(t->boundTo); + } -void TypeCloner::operator()(const PendingExpansionType& t) -{ - TypeId res = dest.addType(PendingExpansionType{t.prefix, t.name, t.typeArguments, t.packArguments}); - PendingExpansionType* petv = getMutable(res); - LUAU_ASSERT(petv); + void cloneChildren(FreeType* t) + { + if (t->lowerBound) + t->lowerBound = shallowClone(t->lowerBound); + if (t->upperBound) + t->upperBound = shallowClone(t->upperBound); + } - seenTypes[typeId] = res; + void cloneChildren(GenericType* t) + { + // TOOD: clone upper bounds. + } - std::vector typeArguments; - for (TypeId arg : t.typeArguments) - typeArguments.push_back(clone(arg, dest, cloneState)); + void cloneChildren(PrimitiveType* t) + { + // noop. + } - std::vector packArguments; - for (TypePackId arg : t.packArguments) - packArguments.push_back(clone(arg, dest, cloneState)); + void cloneChildren(BlockedType* t) + { + // TODO: In the new solver, we should ice. + } - petv->typeArguments = std::move(typeArguments); - petv->packArguments = std::move(packArguments); -} + void cloneChildren(PendingExpansionType* t) + { + // TODO: In the new solver, we should ice. + } -void TypeCloner::operator()(const PrimitiveType& t) -{ - defaultClone(t); -} + void cloneChildren(SingletonType* t) + { + // noop. + } -void TypeCloner::operator()(const SingletonType& t) -{ - defaultClone(t); -} + void cloneChildren(FunctionType* t) + { + for (TypeId& g : t->generics) + g = shallowClone(g); -void TypeCloner::operator()(const FunctionType& t) -{ - // FISHY: We always erase the scope when we clone things. clone() was - // originally written so that we could copy a module's type surface into an - // export arena. This probably dates to that. - TypeId result = dest.addType(FunctionType{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionType* ftv = getMutable(result); - LUAU_ASSERT(ftv != nullptr); - - seenTypes[typeId] = result; - - for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, cloneState)); - - for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); - - ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, cloneState); - ftv->argNames = t.argNames; - ftv->retTypes = clone(t.retTypes, dest, cloneState); - ftv->hasNoGenerics = t.hasNoGenerics; -} + for (TypePackId& gp : t->genericPacks) + gp = shallowClone(gp); -void TypeCloner::operator()(const TableType& t) -{ - // If table is now bound to another one, we ignore the content of the original - if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + t->argTypes = shallowClone(t->argTypes); + t->retTypes = shallowClone(t->retTypes); + } + + void cloneChildren(TableType* t) { - TypeId boundTo = clone(*t.boundTo, dest, cloneState); - seenTypes[typeId] = boundTo; - return; + if (t->indexer) + { + t->indexer->indexType = shallowClone(t->indexer->indexType); + t->indexer->indexResultType = shallowClone(t->indexer->indexResultType); + } + + for (auto& [_, p] : t->props) + p = shallowClone(p); + + for (TypeId& ty : t->instantiatedTypeParams) + ty = shallowClone(ty); + + for (TypePackId& tp : t->instantiatedTypePackParams) + tp = shallowClone(tp); } - TypeId result = dest.addType(TableType{}); - TableType* ttv = getMutable(result); - LUAU_ASSERT(ttv != nullptr); + void cloneChildren(MetatableType* t) + { + t->table = shallowClone(t->table); + t->metatable = shallowClone(t->metatable); + } - *ttv = t; + void cloneChildren(ClassType* t) + { + for (auto& [_, p] : t->props) + p = shallowClone(p); - seenTypes[typeId] = result; + if (t->parent) + t->parent = shallowClone(*t->parent); - ttv->level = TypeLevel{0, 0}; + if (t->metatable) + t->metatable = shallowClone(*t->metatable); - if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, cloneState); + if (t->indexer) + { + t->indexer->indexType = shallowClone(t->indexer->indexType); + t->indexer->indexResultType = shallowClone(t->indexer->indexResultType); + } + } - for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + void cloneChildren(AnyType* t) + { + // noop. + } - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; + void cloneChildren(NoRefineType* t) + { + // noop. + } - for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, cloneState); + void cloneChildren(UnionType* t) + { + for (TypeId& ty : t->options) + ty = shallowClone(ty); + } - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, cloneState); + void cloneChildren(IntersectionType* t) + { + for (TypeId& ty : t->parts) + ty = shallowClone(ty); + } - ttv->definitionModuleName = t.definitionModuleName; - ttv->definitionLocation = t.definitionLocation; - ttv->tags = t.tags; -} + void cloneChildren(LazyType* t) + { + if (auto unwrapped = t->unwrapped.load()) + t->unwrapped.store(shallowClone(unwrapped)); + } -void TypeCloner::operator()(const MetatableType& t) -{ - TypeId result = dest.addType(MetatableType{}); - MetatableType* mtv = getMutable(result); - seenTypes[typeId] = result; + void cloneChildren(UnknownType* t) + { + // noop. + } - mtv->table = clone(t.table, dest, cloneState); - mtv->metatable = clone(t.metatable, dest, cloneState); -} + void cloneChildren(NeverType* t) + { + // noop. + } -void TypeCloner::operator()(const ClassType& t) -{ - TypeId result = dest.addType(ClassType{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); - ClassType* ctv = getMutable(result); + void cloneChildren(NegationType* t) + { + t->ty = shallowClone(t->ty); + } - seenTypes[typeId] = result; + void cloneChildren(TypeFunctionInstanceType* t) + { + for (TypeId& ty : t->typeArguments) + ty = shallowClone(ty); - for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + for (TypePackId& tp : t->packArguments) + tp = shallowClone(tp); + } - if (t.parent) - ctv->parent = clone(*t.parent, dest, cloneState); + void cloneChildren(FreeTypePack* t) + { + // TODO: clone lower and upper bounds. + // TODO: In the new solver, we should ice. + } - if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, cloneState); -} + void cloneChildren(GenericTypePack* t) + { + // TOOD: clone upper bounds. + } -void TypeCloner::operator()(const AnyType& t) -{ - defaultClone(t); -} + void cloneChildren(BlockedTypePack* t) + { + // TODO: In the new solver, we should ice. + } -void TypeCloner::operator()(const UnionType& t) -{ - std::vector options; - options.reserve(t.options.size()); + void cloneChildren(BoundTypePack* t) + { + t->boundTo = shallowClone(t->boundTo); + } - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, cloneState)); + void cloneChildren(VariadicTypePack* t) + { + t->ty = shallowClone(t->ty); + } - TypeId result = dest.addType(UnionType{std::move(options)}); - seenTypes[typeId] = result; -} + void cloneChildren(TypePack* t) + { + for (TypeId& ty : t->head) + ty = shallowClone(ty); -void TypeCloner::operator()(const IntersectionType& t) -{ - TypeId result = dest.addType(IntersectionType{}); - seenTypes[typeId] = result; + if (t->tail) + t->tail = shallowClone(*t->tail); + } - IntersectionType* option = getMutable(result); - LUAU_ASSERT(option != nullptr); + void cloneChildren(TypeFunctionInstanceTypePack* t) + { + for (TypeId& ty : t->typeArguments) + ty = shallowClone(ty); - for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, cloneState)); -} + for (TypePackId& tp : t->packArguments) + tp = shallowClone(tp); + } +}; -void TypeCloner::operator()(const LazyType& t) -{ - defaultClone(t); -} +} // namespace -void TypeCloner::operator()(const UnknownType& t) +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { - defaultClone(t); -} + if (tp->persistent) + return tp; -void TypeCloner::operator()(const NeverType& t) -{ - defaultClone(t); + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.shallowClone(tp); } -void TypeCloner::operator()(const NegationType& t) +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState) { - TypeId result = dest.addType(AnyType{}); - seenTypes[typeId] = result; + if (typeId->persistent) + return typeId; - TypeId ty = clone(t.ty, dest, cloneState); - asMutable(result)->ty = NegationType{ty}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.shallowClone(typeId); } -} // anonymous namespace - TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypePackId& res = cloneState.seenTypePacks[tp]; - - if (res == nullptr) - { - TypePackCloner cloner{dest, tp, cloneState}; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. - } - - return res; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(tp); } TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) @@ -372,136 +486,35 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypeId& res = cloneState.seenTypes[typeId]; - - if (res == nullptr) - { - TypeCloner cloner{dest, typeId, cloneState}; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - { - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } - } - - return res; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(typeId); } TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { - TypeFun result; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + + TypeFun copy = typeFun; - for (auto param : typeFun.typeParams) + for (auto& param : copy.typeParams) { - TypeId ty = clone(param.ty, dest, cloneState); - std::optional defaultValue; + param.ty = cloner.clone(param.ty); if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); - - result.typeParams.push_back({ty, defaultValue}); + param.defaultValue = cloner.clone(*param.defaultValue); } - for (auto param : typeFun.typePackParams) + for (auto& param : copy.typePackParams) { - TypePackId tp = clone(param.tp, dest, cloneState); - std::optional defaultValue; + param.tp = cloner.clone(param.tp); if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); - - result.typePackParams.push_back({tp, defaultValue}); - } - - result.type = clone(typeFun.type, dest, cloneState); - - return result; -} - -TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) -{ - ty = log->follow(ty); - - TypeId result = ty; - - if (auto pty = log->pending(ty)) - ty = &pty->pending; - - if (const FunctionType* ftv = get(ty)) - { - FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = dest.addType(std::move(clone)); - } - else if (const TableType* ttv = get(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; - clone.definitionModuleName = ttv->definitionModuleName; - clone.definitionLocation = ttv->definitionLocation; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - clone.tags = ttv->tags; - result = dest.addType(std::move(clone)); - } - else if (const MetatableType* mtv = get(ty)) - { - MetatableType clone = MetatableType{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = dest.addType(std::move(clone)); - } - else if (const UnionType* utv = get(ty)) - { - UnionType clone; - clone.options = utv->options; - result = dest.addType(std::move(clone)); + param.defaultValue = cloner.clone(*param.defaultValue); } - else if (const IntersectionType* itv = get(ty)) - { - IntersectionType clone; - clone.parts = itv->parts; - result = dest.addType(std::move(clone)); - } - else if (const PendingExpansionType* petv = get(ty)) - { - PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; - result = dest.addType(std::move(clone)); - } - else if (const ClassType* ctv = get(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone) - { - ClassType clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, ctv->userData, ctv->definitionModuleName}; - result = dest.addType(std::move(clone)); - } - else if (FFlag::LuauClonePublicInterfaceLess && alwaysClone) - { - result = dest.addType(*ty); - } - else if (const NegationType* ntv = get(ty)) - { - result = dest.addType(NegationType{ntv->ty}); - } - else - return result; - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} + copy.type = cloner.clone(copy.type); -TypeId shallowClone(TypeId ty, NotNull dest) -{ - return shallowClone(ty, *dest, TxnLog::empty()); + return copy; } } // namespace Luau diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 3a6417dc1..a62879fae 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Constraint.h" +#include "Luau/VisitType.h" namespace Luau { @@ -12,4 +13,127 @@ Constraint::Constraint(NotNull scope, const Location& location, Constrain { } +struct ReferenceCountInitializer : TypeOnceVisitor +{ + + DenseHashSet* result; + + ReferenceCountInitializer(DenseHashSet* result) + : result(result) + { + } + + bool visit(TypeId ty, const FreeType&) override + { + result->insert(ty); + return false; + } + + bool visit(TypeId ty, const BlockedType&) override + { + result->insert(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + result->insert(ty); + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + // ClassTypes never contain free types. + return false; + } +}; + +bool isReferenceCountedType(const TypeId typ) +{ + // n.b. this should match whatever `ReferenceCountInitializer` includes. + return get(typ) || get(typ) || get(typ); +} + +DenseHashSet Constraint::getMaybeMutatedFreeTypes() const +{ + // For the purpose of this function and reference counting in general, we are only considering + // mutations that affect the _bounds_ of the free type, and not something that may bind the free + // type itself to a new type. As such, `ReduceConstraint` and `GeneralizationConstraint` have no + // contribution to the output set here. + + DenseHashSet types{{}}; + ReferenceCountInitializer rci{&types}; + + if (auto ec = get(*this)) + { + rci.traverse(ec->resultType); + // `EqualityConstraints` should not mutate `assignmentType`. + } + else if (auto sc = get(*this)) + { + rci.traverse(sc->subType); + rci.traverse(sc->superType); + } + else if (auto psc = get(*this)) + { + rci.traverse(psc->subPack); + rci.traverse(psc->superPack); + } + else if (auto itc = get(*this)) + { + for (TypeId ty : itc->variables) + rci.traverse(ty); + // `IterableConstraints` should not mutate `iterator`. + } + else if (auto nc = get(*this)) + { + rci.traverse(nc->namedType); + } + else if (auto taec = get(*this)) + { + rci.traverse(taec->target); + } + else if (auto fchc = get(*this)) + { + rci.traverse(fchc->argsPack); + } + else if (auto ptc = get(*this)) + { + rci.traverse(ptc->freeType); + } + else if (auto hpc = get(*this)) + { + rci.traverse(hpc->resultType); + // `HasPropConstraints` should not mutate `subjectType`. + } + else if (auto hic = get(*this)) + { + rci.traverse(hic->resultType); + // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. + } + else if (auto apc = get(*this)) + { + rci.traverse(apc->lhsType); + rci.traverse(apc->rhsType); + } + else if (auto aic = get(*this)) + { + rci.traverse(aic->lhsType); + rci.traverse(aic->indexType); + rci.traverse(aic->rhsType); + } + else if (auto uc = get(*this)) + { + for (TypeId ty : uc->resultPack) + rci.traverse(ty); + // `UnpackConstraint` should not mutate `sourcePack`. + } + else if (auto rpc = get(*this)) + { + rci.traverse(rpc->tp); + } + + return types; +} + } // namespace Luau diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp new file mode 100644 index 000000000..e242df8ec --- /dev/null +++ b/Analysis/src/ConstraintGenerator.cpp @@ -0,0 +1,3827 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintGenerator.h" + +#include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/Constraint.h" +#include "Luau/ControlFlow.h" +#include "Luau/DcrLogger.h" +#include "Luau/Def.h" +#include "Luau/DenseHash.h" +#include "Luau/ModuleResolver.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Refinement.h" +#include "Luau/Scope.h" +#include "Luau/Simplify.h" +#include "Luau/StringUtils.h" +#include "Luau/TableLiteralInference.h" +#include "Luau/TimeTrace.h" +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" +#include "Luau/VisitType.h" + +#include +#include + +LUAU_FASTINT(LuauCheckRecursionLimit) +LUAU_FASTFLAG(DebugLuauLogSolverToJson) +LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAG(LuauTypestateBuiltins) + +LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues, false) + +namespace Luau +{ + +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp +const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp + +static std::optional matchRequire(const AstExprCall& call) +{ + const char* require = "require"; + + if (call.args.size != 1) + return std::nullopt; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != require) + return std::nullopt; + + if (call.args.size != 1) + return std::nullopt; + + return call.args.data[0]; +} + +struct TypeGuard +{ + bool isTypeof; + AstExpr* target; + std::string type; +}; + +static std::optional matchTypeGuard(const AstExprBinary* binary) +{ + if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) + return std::nullopt; + + AstExpr* left = binary->left; + AstExpr* right = binary->right; + if (right->is()) + std::swap(left, right); + + if (!right->is()) + return std::nullopt; + + AstExprCall* call = left->as(); + AstExprConstantString* string = right->as(); + if (!call || !string) + return std::nullopt; + + AstExprGlobal* callee = call->func->as(); + if (!callee) + return std::nullopt; + + if (callee->name != "type" && callee->name != "typeof") + return std::nullopt; + + if (call->args.size != 1) + return std::nullopt; + + return TypeGuard{ + /*isTypeof*/ callee->name == "typeof", + /*target*/ call->args.data[0], + /*type*/ std::string(string->value.data, string->value.size), + }; +} + +namespace +{ + +struct Checkpoint +{ + size_t offset; +}; + +Checkpoint checkpoint(const ConstraintGenerator* cg) +{ + return Checkpoint{cg->constraints.size()}; +} + +template +void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGenerator* cg, F f) +{ + for (size_t i = start.offset; i < end.offset; ++i) + f(cg->constraints[i]); +} + +struct HasFreeType : TypeOnceVisitor +{ + bool result = false; + + HasFreeType() {} + + bool visit(TypeId ty) override + { + if (result || ty->persistent) + return false; + return true; + } + + bool visit(TypePackId tp) override + { + if (result) + return false; + return true; + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + + bool visit(TypeId ty, const FreeType&) override + { + result = true; + return false; + } + + bool visit(TypePackId ty, const FreeTypePack&) override + { + result = true; + return false; + } +}; + +bool hasFreeType(TypeId ty) +{ + HasFreeType hft{}; + hft.traverse(ty); + return hft.result; +} + +} // namespace + +ConstraintGenerator::ConstraintGenerator( + ModulePtr module, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull moduleResolver, + NotNull builtinTypes, + NotNull ice, + const ScopePtr& globalScope, + std::function prepareModuleScope, + DcrLogger* logger, + NotNull dfg, + std::vector requireCycles +) + : module(module) + , builtinTypes(builtinTypes) + , arena(normalizer->arena) + , rootScope(nullptr) + , dfg(dfg) + , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) + , moduleResolver(moduleResolver) + , ice(ice) + , globalScope(globalScope) + , prepareModuleScope(std::move(prepareModuleScope)) + , requireCycles(std::move(requireCycles)) + , logger(logger) +{ + LUAU_ASSERT(module); +} + +void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) +{ + LUAU_TIMETRACE_SCOPE("ConstraintGenerator::visitModuleRoot", "Typechecking"); + + LUAU_ASSERT(scopes.empty()); + LUAU_ASSERT(rootScope == nullptr); + ScopePtr scope = std::make_shared(globalScope); + rootScope = scope.get(); + scopes.emplace_back(block->location, scope); + rootScope->location = block->location; + module->astScopes[block] = NotNull{scope.get()}; + + rootScope->returnType = freshTypePack(scope); + + TypeId moduleFnTy = arena->addType(FunctionType{TypeLevel{}, rootScope, builtinTypes->anyTypePack, rootScope->returnType}); + interiorTypes.emplace_back(); + + prepopulateGlobalScope(scope, block); + + Checkpoint start = checkpoint(this); + + ControlFlow cf = + DFInt::LuauTypeSolverRelease >= 646 ? visitBlockWithoutChildScope(scope, block) : visitBlockWithoutChildScope_DEPRECATED(scope, block); + if (cf == ControlFlow::None) + addConstraint(scope, block->location, PackSubtypeConstraint{builtinTypes->emptyTypePack, rootScope->returnType}); + + Checkpoint end = checkpoint(this); + + TypeId result = arena->addType(BlockedType{}); + NotNull genConstraint = + addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())}); + getMutable(result)->setOwner(genConstraint); + forEachConstraint( + start, + end, + this, + [genConstraint](const ConstraintPtr& c) + { + genConstraint->dependencies.push_back(NotNull{c.get()}); + } + ); + + interiorTypes.pop_back(); + + fillInInferredBindings(scope, block); + + if (logger) + logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. + TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + { + d = follow(d); + if (d == ty) + continue; + domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + } + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } +} + +void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block) +{ + visitBlockWithoutChildScope(resumeScope, block); + fillInInferredBindings(resumeScope, block); + + if (logger) + logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. + TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + { + d = follow(d); + if (d == ty) + continue; + domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + } + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } +} + +TypeId ConstraintGenerator::freshType(const ScopePtr& scope) +{ + return Luau::freshType(arena, builtinTypes, scope.get()); +} + +TypePackId ConstraintGenerator::freshTypePack(const ScopePtr& scope) +{ + FreeTypePack f{scope.get()}; + return arena->addTypePack(TypePackVar{std::move(f)}); +} + +TypePackId ConstraintGenerator::addTypePack(std::vector head, std::optional tail) +{ + if (head.empty()) + { + if (tail) + return *tail; + else + return builtinTypes->emptyTypePack; + } + else + return arena->addTypePack(TypePack{std::move(head), tail}); +} + +ScopePtr ConstraintGenerator::childScope(AstNode* node, const ScopePtr& parent) +{ + auto scope = std::make_shared(parent); + scopes.emplace_back(node->location, scope); + scope->location = node->location; + + scope->returnType = parent->returnType; + scope->varargPack = parent->varargPack; + + parent->children.push_back(NotNull{scope.get()}); + module->astScopes[node] = scope.get(); + + return scope; +} + +std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Location location, DefId def, bool prototype) +{ + if (get(def)) + return scope->lookup(def); + if (auto phi = get(def)) + { + if (auto found = scope->lookup(def)) + return *found; + else if (!prototype && phi->operands.size() == 1) + return lookup(scope, location, phi->operands.at(0), prototype); + else if (!prototype) + return std::nullopt; + + TypeId res = builtinTypes->neverType; + + for (DefId operand : phi->operands) + { + // `scope->lookup(operand)` may return nothing because we only bind a type to that operand + // once we've seen that particular `DefId`. In this case, we need to prototype those types + // and use those at a later time. + std::optional ty = lookup(scope, location, operand, /*prototype*/ false); + if (!ty) + { + ty = arena->addType(BlockedType{}); + localTypes.try_insert(*ty, {}); + rootScope->lvalueTypes[operand] = *ty; + } + + res = makeUnion(scope, location, res, *ty); + } + + scope->lvalueTypes[def] = res; + return res; + } + else + ice->ice("ConstraintGenerator::lookup is inexhaustive?"); +} + +NotNull ConstraintGenerator::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) +{ + return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; +} + +NotNull ConstraintGenerator::addConstraint(const ScopePtr& scope, std::unique_ptr c) +{ + return NotNull{constraints.emplace_back(std::move(c)).get()}; +} + +void ConstraintGenerator::unionRefinements( + const ScopePtr& scope, + Location location, + const RefinementContext& lhs, + const RefinementContext& rhs, + RefinementContext& dest, + std::vector* constraints +) +{ + const auto intersect = [&](const std::vector& types) + { + if (1 == types.size()) + return types[0]; + else if (2 == types.size()) + return makeIntersect(scope, location, types[0], types[1]); + + return arena->addType(IntersectionType{types}); + }; + + for (auto& [def, partition] : lhs) + { + auto rhsIt = rhs.find(def); + if (rhsIt == rhs.end()) + continue; + + LUAU_ASSERT(!partition.discriminantTypes.empty()); + LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); + + TypeId leftDiscriminantTy = partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : intersect(partition.discriminantTypes); + + TypeId rightDiscriminantTy = + rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] : intersect(rhsIt->second.discriminantTypes); + + dest.insert(def, {}); + dest.get(def)->discriminantTypes.push_back(makeUnion(scope, location, leftDiscriminantTy, rightDiscriminantTy)); + dest.get(def)->shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; + } +} + +void ConstraintGenerator::computeRefinement( + const ScopePtr& scope, + Location location, + RefinementId refinement, + RefinementContext* refis, + bool sense, + bool eq, + std::vector* constraints +) +{ + if (!refinement) + return; + else if (auto variadic = get(refinement)) + { + for (RefinementId refi : variadic->refinements) + computeRefinement(scope, location, refi, refis, sense, eq, constraints); + } + else if (auto negation = get(refinement)) + return computeRefinement(scope, location, negation->refinement, refis, !sense, eq, constraints); + else if (auto conjunction = get(refinement)) + { + RefinementContext lhsRefis; + RefinementContext rhsRefis; + + computeRefinement(scope, location, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); + computeRefinement(scope, location, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); + + if (!sense) + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); + } + else if (auto disjunction = get(refinement)) + { + RefinementContext lhsRefis; + RefinementContext rhsRefis; + + computeRefinement(scope, location, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); + computeRefinement(scope, location, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); + + if (sense) + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); + } + else if (auto equivalence = get(refinement)) + { + computeRefinement(scope, location, equivalence->lhs, refis, sense, true, constraints); + computeRefinement(scope, location, equivalence->rhs, refis, sense, true, constraints); + } + else if (auto proposition = get(refinement)) + { + TypeId discriminantTy = proposition->discriminantTy; + + // if we have a negative sense, then we need to negate the discriminant + if (!sense) + discriminantTy = arena->addType(NegationType{discriminantTy}); + + if (eq) + discriminantTy = createTypeFunctionInstance(builtinTypeFunctions().singletonFunc, {discriminantTy}, {}, scope, location); + + for (const RefinementKey* key = proposition->key; key; key = key->parent) + { + refis->insert(key->def, {}); + refis->get(key->def)->discriminantTypes.push_back(discriminantTy); + + // Reached leaf node + if (!key->propName) + break; + + TypeId nextDiscriminantTy = arena->addType(TableType{}); + NotNull table{getMutable(nextDiscriminantTy)}; + // When we fully support read-write properties (i.e. when we allow properties with + // completely disparate read and write types), then the following property can be + // set to read-only since refinements only tell us about what we read. This cannot + // be allowed yet though because it causes read and write types to diverge. + table->props[*key->propName] = Property::rw(discriminantTy); + table->scope = scope.get(); + table->state = TableState::Sealed; + + discriminantTy = nextDiscriminantTy; + } + + // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. + LUAU_ASSERT(refis->get(proposition->key->def)); + refis->get(proposition->key->def)->shouldAppendNilType = (sense || !eq) && containsSubscriptedDefinition(proposition->key->def); + } +} + +namespace +{ + +/* + * Constraint generation may be called upon to simplify an intersection or union + * of types that are not sufficiently solved yet. We use + * FindSimplificationBlockers to recognize these types and defer the + * simplification until constraint solution. + */ +struct FindSimplificationBlockers : TypeOnceVisitor +{ + bool found = false; + + bool visit(TypeId) override + { + return !found; + } + + bool visit(TypeId, const BlockedType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const FreeType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + found = true; + return false; + } + + // We do not need to know anything at all about a function's argument or + // return types in order to simplify it in an intersection or union. + bool visit(TypeId, const FunctionType&) override + { + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } +}; + +bool mustDeferIntersection(TypeId ty) +{ + FindSimplificationBlockers bts; + bts.traverse(ty); + return bts.found; +} +} // namespace + +enum RefinementsOpKind +{ + Intersect, + Refine, + None +}; + +void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) +{ + if (!refinement) + return; + + RefinementContext refinements; + std::vector constraints; + computeRefinement(scope, location, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); + auto flushConstraints = [this, &scope, &location](RefinementsOpKind kind, TypeId ty, std::vector& discriminants) + { + if (discriminants.empty()) + return ty; + if (kind == RefinementsOpKind::None) + { + LUAU_ASSERT(false); + return ty; + } + std::vector args = {ty}; + const TypeFunction& func = kind == RefinementsOpKind::Intersect ? builtinTypeFunctions().intersectFunc : builtinTypeFunctions().refineFunc; + LUAU_ASSERT(!func.name.empty()); + args.insert(args.end(), discriminants.begin(), discriminants.end()); + TypeId resultType = createTypeFunctionInstance(func, args, {}, scope, location); + discriminants.clear(); + return resultType; + }; + + for (auto& [def, partition] : refinements) + { + if (std::optional defTy = lookup(scope, location, def)) + { + TypeId ty = *defTy; + if (partition.shouldAppendNilType) + ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); + // Intersect ty with every discriminant type. If either type is not + // sufficiently solved, we queue the intersection up via an + // IntersectConstraint. + // For each discriminant ty, we accumulated it onto ty, creating a longer and longer + // sequence of refine constraints. On every loop of this we called mustDeferIntersection. + // For sufficiently large types, we would blow the stack. + // Instead, we record all the discriminant types in sequence + // and then dispatch a single refine constraint with multiple arguments. This helps us avoid + // the potentially expensive check on mustDeferIntersection + std::vector discriminants; + RefinementsOpKind kind = RefinementsOpKind::None; + bool mustDefer = mustDeferIntersection(ty); + for (TypeId dt : partition.discriminantTypes) + { + mustDefer = mustDefer || mustDeferIntersection(dt); + if (mustDefer) + { + if (kind == RefinementsOpKind::Intersect) + ty = flushConstraints(kind, ty, discriminants); + kind = RefinementsOpKind::Refine; + + discriminants.push_back(dt); + } + else + { + ErrorSuppression status = shouldSuppressErrors(normalizer, ty); + if (status == ErrorSuppression::NormalizationFailed) + reportError(location, NormalizationTooComplex{}); + if (kind == RefinementsOpKind::Refine) + ty = flushConstraints(kind, ty, discriminants); + kind = RefinementsOpKind::Intersect; + + discriminants.push_back(dt); + + if (status == ErrorSuppression::Suppress) + { + ty = flushConstraints(kind, ty, discriminants); + ty = makeUnion(scope, location, ty, builtinTypes->errorType); + } + } + } + + // Finalize - if there are any discriminants left, make one big constraint for refining them + if (kind != RefinementsOpKind::None) + ty = flushConstraints(kind, ty, discriminants); + + scope->rvalueRefinements[def] = ty; + } + } + + for (auto& c : constraints) + addConstraint(scope, location, c); +} + +void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* block) +{ + std::unordered_map aliasDefinitionLocations; + + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + if (scope->exportedTypeBindings.count(alias->name.value) || scope->privateTypeBindings.count(alias->name.value)) + { + auto it = aliasDefinitionLocations.find(alias->name.value); + LUAU_ASSERT(it != aliasDefinitionLocations.end()); + reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second}); + continue; + } + + // A type alias might have no name if the code is syntactically + // illegal. We mustn't prepopulate anything in this case. + if (alias->name == kParseNameError || alias->name == "typeof") + continue; + + ScopePtr defnScope = childScope(alias, scope); + + TypeId initialType = arena->addType(BlockedType{}); + TypeFun initialFun{initialType}; + + for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) + { + initialFun.typeParams.push_back(gen); + } + + for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) + { + initialFun.typePackParams.push_back(genPack); + } + + if (alias->exported) + scope->exportedTypeBindings[alias->name.value] = std::move(initialFun); + else + scope->privateTypeBindings[alias->name.value] = std::move(initialFun); + + astTypeAliasDefiningScopes[alias] = defnScope; + aliasDefinitionLocations[alias->name.value] = alias->location; + } + else if (auto function = stat->as()) + { + // If a type function w/ same name has already been defined, error for having duplicates + if (scope->exportedTypeBindings.count(function->name.value) || scope->privateTypeBindings.count(function->name.value)) + { + auto it = aliasDefinitionLocations.find(function->name.value); + LUAU_ASSERT(it != aliasDefinitionLocations.end()); + reportError(function->location, DuplicateTypeDefinition{function->name.value, it->second}); + continue; + } + + if (scope->parent != globalScope) + { + reportError(function->location, GenericError{"Local user-defined functions are not supported yet"}); + continue; + } + + ScopePtr defnScope = childScope(function, scope); + + // Create TypeFunctionInstanceType + + std::vector typeParams; + typeParams.reserve(function->body->args.size); + + std::vector quantifiedTypeParams; + quantifiedTypeParams.reserve(function->body->args.size); + + for (size_t i = 0; i < function->body->args.size; i++) + { + std::string name = format("T%zu", i); + TypeId ty = arena->addType(GenericType{name}); + typeParams.push_back(ty); + + GenericTypeDefinition genericTy{ty}; + quantifiedTypeParams.push_back(genericTy); + } + + if (std::optional error = typeFunctionRuntime->registerFunction(function)) + reportError(function->location, GenericError{*error}); + + TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ + NotNull{&builtinTypeFunctions().userFunc}, + std::move(typeParams), + {}, + function->name, + }); + + TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; + + // Set type bindings and definition locations for this user-defined type function + scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + aliasDefinitionLocations[function->name.value] = function->location; + } + } +} + +ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return ControlFlow::None; + } + + checkAliases(scope, block); + + std::optional firstControlFlow; + for (AstStat* stat : block->body) + { + ControlFlow cf = visit(scope, stat); + if (cf != ControlFlow::None && !firstControlFlow) + firstControlFlow = cf; + } + + return firstControlFlow.value_or(ControlFlow::None); +} + +ControlFlow ConstraintGenerator::visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return ControlFlow::None; + } + + std::unordered_map aliasDefinitionLocations; + + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + if (scope->exportedTypeBindings.count(alias->name.value) || scope->privateTypeBindings.count(alias->name.value)) + { + auto it = aliasDefinitionLocations.find(alias->name.value); + LUAU_ASSERT(it != aliasDefinitionLocations.end()); + reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second}); + continue; + } + + // A type alias might have no name if the code is syntactically + // illegal. We mustn't prepopulate anything in this case. + if (alias->name == kParseNameError || alias->name == "typeof") + continue; + + ScopePtr defnScope = childScope(alias, scope); + + TypeId initialType = arena->addType(BlockedType{}); + TypeFun initialFun{initialType}; + + for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) + { + initialFun.typeParams.push_back(gen); + } + + for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) + { + initialFun.typePackParams.push_back(genPack); + } + + if (alias->exported) + scope->exportedTypeBindings[alias->name.value] = std::move(initialFun); + else + scope->privateTypeBindings[alias->name.value] = std::move(initialFun); + + astTypeAliasDefiningScopes[alias] = defnScope; + aliasDefinitionLocations[alias->name.value] = alias->location; + } + else if (auto function = stat->as()) + { + // If a type function w/ same name has already been defined, error for having duplicates + if (scope->exportedTypeBindings.count(function->name.value) || scope->privateTypeBindings.count(function->name.value)) + { + auto it = aliasDefinitionLocations.find(function->name.value); + LUAU_ASSERT(it != aliasDefinitionLocations.end()); + reportError(function->location, DuplicateTypeDefinition{function->name.value, it->second}); + continue; + } + + if (scope->parent != globalScope) + { + reportError(function->location, GenericError{"Local user-defined functions are not supported yet"}); + continue; + } + + ScopePtr defnScope = childScope(function, scope); + + // Create TypeFunctionInstanceType + + std::vector typeParams; + typeParams.reserve(function->body->args.size); + + std::vector quantifiedTypeParams; + quantifiedTypeParams.reserve(function->body->args.size); + + for (size_t i = 0; i < function->body->args.size; i++) + { + std::string name = format("T%zu", i); + TypeId ty = arena->addType(GenericType{name}); + typeParams.push_back(ty); + + GenericTypeDefinition genericTy{ty}; + quantifiedTypeParams.push_back(genericTy); + } + + if (std::optional error = typeFunctionRuntime->registerFunction(function)) + reportError(function->location, GenericError{*error}); + + TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ + NotNull{&builtinTypeFunctions().userFunc}, + std::move(typeParams), + {}, + function->name, + }); + + TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; + + // Set type bindings and definition locations for this user-defined type function + scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + aliasDefinitionLocations[function->name.value] = function->location; + } + } + + std::optional firstControlFlow; + for (AstStat* stat : block->body) + { + ControlFlow cf = visit(scope, stat); + if (cf != ControlFlow::None && !firstControlFlow) + firstControlFlow = cf; + } + + return firstControlFlow.value_or(ControlFlow::None); +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStat* stat) +{ + RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; + + if (auto s = stat->as()) + return visit(scope, s); + else if (auto i = stat->as()) + return visit(scope, i); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (stat->is()) + return ControlFlow::Breaks; + else if (stat->is()) + return ControlFlow::Continues; + else if (auto r = stat->as()) + return visit(scope, r); + else if (auto e = stat->as()) + { + checkPack(scope, e->expr); + + if (auto call = e->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + + return ControlFlow::None; + } + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto a = stat->as()) + return visit(scope, a); + else if (auto a = stat->as()) + return visit(scope, a); + else if (auto f = stat->as()) + return visit(scope, f); + else if (auto f = stat->as()) + return visit(scope, f); + else if (auto a = stat->as()) + return visit(scope, a); + else if (auto f = stat->as()) + return visit(scope, f); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else + { + LUAU_ASSERT(0 && "Internal error: Unknown AstStat type"); + return ControlFlow::None; + } +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* statLocal) +{ + std::vector annotatedTypes; + annotatedTypes.reserve(statLocal->vars.size); + bool hasAnnotation = false; + + std::vector> expectedTypes; + expectedTypes.reserve(statLocal->vars.size); + + std::vector assignees; + assignees.reserve(statLocal->vars.size); + + // Used to name the first value type, even if it's not placed in varTypes, + // for the purpose of synthetic name attribution. + std::optional firstValueType; + + for (AstLocal* local : statLocal->vars) + { + const Location location = local->location; + + TypeId assignee = arena->addType(BlockedType{}); + localTypes.try_insert(assignee, {}); + + assignees.push_back(assignee); + + if (!firstValueType) + firstValueType = assignee; + + if (local->annotation) + { + hasAnnotation = true; + TypeId annotationTy = resolveType(scope, local->annotation, /* inTypeArguments */ false); + annotatedTypes.push_back(annotationTy); + expectedTypes.push_back(annotationTy); + + scope->bindings[local] = Binding{annotationTy, location}; + } + else + { + // annotatedTypes must contain one type per local. If a particular + // local has no annotation at, assume the most conservative thing. + annotatedTypes.push_back(builtinTypes->unknownType); + + expectedTypes.push_back(std::nullopt); + scope->bindings[local] = Binding{builtinTypes->unknownType, location}; + + inferredBindings[local] = {scope.get(), location, {assignee}}; + } + + DefId def = dfg->getDef(local); + scope->lvalueTypes[def] = assignee; + } + + Checkpoint start = checkpoint(this); + TypePackId rvaluePack = checkPack(scope, statLocal->values, expectedTypes).tp; + Checkpoint end = checkpoint(this); + + if (hasAnnotation) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(annotatedTypes[i]); + } + + TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); + addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); + } + else + { + std::vector valueTypes; + valueTypes.reserve(statLocal->vars.size); + + auto [head, tail] = flatten(rvaluePack); + + if (head.size() >= statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + forEachConstraint( + start, + end, + this, + [&uc](const ConstraintPtr& runBefore) + { + uc->dependencies.push_back(NotNull{runBefore.get()}); + } + ); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); + } + + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(valueTypes[i]); + } + } + + if (statLocal->vars.size == 1 && statLocal->values.size == 1 && firstValueType && scope.get() == rootScope && !hasAnnotation) + { + AstLocal* var = statLocal->vars.data[0]; + AstExpr* value = statLocal->values.data[0]; + + if (value->is()) + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + else if (const AstExprCall* call = value->as()) + { + if (FFlag::LuauTypestateBuiltins) + { + if (matchSetMetatable(*call)) + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + } + else + { + if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") + { + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + } + } + } + } + + if (statLocal->values.size > 0) + { + // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. + for (size_t i = 0; i < statLocal->values.size && i < statLocal->vars.size; ++i) + { + const AstExprCall* call = statLocal->values.data[i]->as(); + if (!call) + continue; + + auto maybeRequire = matchRequire(*call); + if (!maybeRequire) + continue; + + AstExpr* require = *maybeRequire; + + auto moduleInfo = moduleResolver->resolveModuleInfo(module->name, *require); + if (!moduleInfo) + continue; + + ModulePtr module = moduleResolver->getModule(moduleInfo->name); + if (!module) + continue; + + const Name name{statLocal->vars.data[i]->name.value}; + scope->importedTypeBindings[name] = module->exportedTypeBindings; + scope->importedModules[name] = moduleInfo->name; + + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) + { + if (path.empty() || path.front() != moduleInfo->name) + continue; + + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, builtinTypes->anyType}; + } + } + } + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFor* for_) +{ + TypeId annotationTy = builtinTypes->numberType; + if (for_->var->annotation) + annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); + + auto inferNumber = [&](AstExpr* expr) + { + if (!expr) + return; + + TypeId t = check(scope, expr).ty; + addConstraint(scope, expr->location, SubtypeConstraint{t, builtinTypes->numberType}); + }; + + inferNumber(for_->from); + inferNumber(for_->to); + inferNumber(for_->step); + + ScopePtr forScope = childScope(for_, scope); + forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location}; + + DefId def = dfg->getDef(for_->var); + forScope->lvalueTypes[def] = annotationTy; + forScope->rvalueRefinements[def] = annotationTy; + + visit(forScope, for_->body); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forIn) +{ + ScopePtr loopScope = childScope(forIn, scope); + TypePackId iterator = checkPack(scope, forIn->values).tp; + + std::vector variableTypes; + variableTypes.reserve(forIn->vars.size); + + for (AstLocal* var : forIn->vars) + { + TypeId assignee = arena->addType(BlockedType{}); + variableTypes.push_back(assignee); + + TypeId loopVar = arena->addType(BlockedType{}); + localTypes[loopVar].insert(assignee); + + if (var->annotation) + { + TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); + loopScope->bindings[var] = Binding{annotationTy, var->location}; + addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy}); + } + else + loopScope->bindings[var] = Binding{loopVar, var->location}; + + DefId def = dfg->getDef(var); + loopScope->lvalueTypes[def] = loopVar; + } + + auto iterable = addConstraint( + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes} + ); + + for (TypeId var : variableTypes) + { + auto bt = getMutable(var); + LUAU_ASSERT(bt); + bt->setOwner(iterable); + } + + Checkpoint start = checkpoint(this); + visit(loopScope, forIn->body); + Checkpoint end = checkpoint(this); + + // This iter constraint must dispatch first. + forEachConstraint( + start, + end, + this, + [&iterable](const ConstraintPtr& runLater) + { + runLater->dependencies.push_back(iterable); + } + ); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatWhile* while_) +{ + RefinementId refinement = check(scope, while_->condition).refinement; + + ScopePtr whileScope = childScope(while_, scope); + applyRefinements(whileScope, while_->condition->location, refinement); + + visit(whileScope, while_->body); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatRepeat* repeat) +{ + ScopePtr repeatScope = childScope(repeat, scope); + + if (DFInt::LuauTypeSolverRelease >= 646) + visitBlockWithoutChildScope(repeatScope, repeat->body); + else + visitBlockWithoutChildScope_DEPRECATED(repeatScope, repeat->body); + + check(repeatScope, repeat->condition); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFunction* function) +{ + // Local + // Global + // Dotted path + // Self? + + TypeId functionType = nullptr; + auto ty = scope->lookup(function->name); + LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. + + functionType = arena->addType(BlockedType{}); + scope->bindings[function->name] = Binding{functionType, function->name->location}; + + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); + sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->name->location}; + + bool sigFullyDefined = !hasFreeType(sig.signature); + if (sigFullyDefined) + emplaceType(asMutable(functionType), sig.signature); + + DefId def = dfg->getDef(function->name); + scope->lvalueTypes[def] = functionType; + scope->rvalueRefinements[def] = functionType; + sig.bodyScope->lvalueTypes[def] = sig.signature; + sig.bodyScope->rvalueRefinements[def] = sig.signature; + + Checkpoint start = checkpoint(this); + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + if (!sigFullyDefined) + { + NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; + std::unique_ptr c = + std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); + + Constraint* previous = nullptr; + forEachConstraint( + start, + end, + this, + [&c, &previous](const ConstraintPtr& constraint) + { + c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + } + ); + + getMutable(functionType)->setOwner(addConstraint(scope, std::move(c))); + module->astTypes[function->func] = functionType; + } + else + module->astTypes[function->func] = sig.signature; + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* function) +{ + // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. + // With or without self + + Checkpoint start = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); + bool sigFullyDefined = !hasFreeType(sig.signature); + + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + TypeId generalizedType = arena->addType(BlockedType{}); + if (sigFullyDefined) + emplaceType(asMutable(generalizedType), sig.signature); + else + { + const ScopePtr& constraintScope = sig.signatureScope ? sig.signatureScope : sig.bodyScope; + + NotNull c = addConstraint(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); + getMutable(generalizedType)->setOwner(c); + + Constraint* previous = nullptr; + forEachConstraint( + start, + end, + this, + [&c, &previous](const ConstraintPtr& constraint) + { + c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + } + ); + } + + DefId def = dfg->getDef(function->name); + std::optional existingFunctionTy = follow(lookup(scope, function->name->location, def)); + + if (AstExprLocal* localName = function->name->as()) + { + visitLValue(scope, localName, generalizedType); + + scope->bindings[localName->local] = Binding{sig.signature, localName->location}; + scope->lvalueTypes[def] = sig.signature; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + if (!existingFunctionTy) + ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); + + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. + if (auto bt = get(*existingFunctionTy); bt && nullptr == bt->getOwner()) + emplaceType(asMutable(*existingFunctionTy), generalizedType); + + scope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; + scope->lvalueTypes[def] = sig.signature; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + visitLValue(scope, indexName, generalizedType); + } + else if (AstExprError* err = function->name->as()) + { + generalizedType = builtinTypes->errorRecoveryType(); + } + + if (generalizedType == nullptr) + ice->ice("generalizedType == nullptr", function->location); + + scope->rvalueRefinements[def] = generalizedType; + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatReturn* ret) +{ + // At this point, the only way scope->returnType should have anything + // interesting in it is if the function has an explicit return annotation. + // If this is the case, then we can expect that the return expression + // conforms to that. + std::vector> expectedTypes; + for (TypeId ty : scope->returnType) + expectedTypes.push_back(ty); + + TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; + addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType, /*returns*/ true}); + + return ControlFlow::Returns; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatBlock* block) +{ + ScopePtr innerScope = childScope(block, scope); + + ControlFlow flow = DFInt::LuauTypeSolverRelease >= 646 ? visitBlockWithoutChildScope(innerScope, block) + : visitBlockWithoutChildScope_DEPRECATED(innerScope, block); + + // An AstStatBlock has linear control flow, i.e. one entry and one exit, so we can inherit + // all the changes to the environment occurred by the statements in that block. + scope->inheritRefinements(innerScope); + scope->inheritAssignments(innerScope); + + return flow; +} + +// TODO Clip? +static void bindFreeType(TypeId a, TypeId b) +{ + FreeType* af = getMutable(a); + FreeType* bf = getMutable(b); + + LUAU_ASSERT(af || bf); + + if (!bf) + emplaceType(asMutable(a), b); + else if (!af) + emplaceType(asMutable(b), a); + else if (subsumes(bf->scope, af->scope)) + emplaceType(asMutable(a), b); + else if (subsumes(af->scope, bf->scope)) + emplaceType(asMutable(b), a); +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* assign) +{ + TypePackId resultPack = checkPack(scope, assign->values).tp; + + std::vector valueTypes; + valueTypes.reserve(assign->vars.size); + + auto [head, tail] = flatten(resultPack); + if (head.size() >= assign->vars.size) + { + // If the resultPack is definitely long enough for each variable, we can + // skip the UnpackConstraint and use the result types directly. + + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + // We're not sure how many types are produced by the right-side + // expressions. We'll use an UnpackConstraint to defer this until + // later. + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack}); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); + } + + for (size_t i = 0; i < assign->vars.size; ++i) + { + visitLValue(scope, assign->vars.data[i], valueTypes[i]); + } + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) +{ + AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; + TypeId resultTy = check(scope, &binop).ty; + module->astCompoundAssignResultTypes[assign] = resultTy; + + TypeId lhsType = check(scope, assign->var).ty; + visitLValue(scope, assign->var, lhsType); + + follow(lhsType); + follow(resultTy); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatIf* ifStatement) +{ + RefinementId refinement = [&]() + { + InConditionalContext flipper{&typeContext}; + return check(scope, ifStatement->condition, std::nullopt).refinement; + }(); + + ScopePtr thenScope = childScope(ifStatement->thenbody, scope); + applyRefinements(thenScope, ifStatement->condition->location, refinement); + + ScopePtr elseScope = childScope(ifStatement->elsebody ? ifStatement->elsebody : ifStatement, scope); + applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); + + ControlFlow thencf = visit(thenScope, ifStatement->thenbody); + ControlFlow elsecf = ControlFlow::None; + if (ifStatement->elsebody) + elsecf = visit(elseScope, ifStatement->elsebody); + + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) + scope->inheritRefinements(thenScope); + + if (thencf == ControlFlow::None) + scope->inheritAssignments(thenScope); + if (elsecf == ControlFlow::None) + scope->inheritAssignments(elseScope); + + if (thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias* alias) +{ + if (alias->name == kParseNameError) + return ControlFlow::None; + + if (alias->name == "typeof") + { + reportError(alias->location, GenericError{"Type aliases cannot be named typeof"}); + return ControlFlow::None; + } + + scope->typeAliasLocations[alias->name.value] = alias->location; + scope->typeAliasNameLocations[alias->name.value] = alias->nameLocation; + + ScopePtr* defnScope = astTypeAliasDefiningScopes.find(alias); + + std::unordered_map* typeBindings; + if (alias->exported) + typeBindings = &scope->exportedTypeBindings; + else + typeBindings = &scope->privateTypeBindings; + + // These will be undefined if the alias was a duplicate definition, in which + // case we just skip over it. + auto bindingIt = typeBindings->find(alias->name.value); + if (bindingIt == typeBindings->end() || defnScope == nullptr) + return ControlFlow::None; + + TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false, /* replaceErrorWithFresh */ false); + + TypeId aliasTy = bindingIt->second.type; + LUAU_ASSERT(get(aliasTy)); + if (occursCheck(aliasTy, ty)) + { + emplaceType(asMutable(aliasTy), builtinTypes->anyType); + reportError(alias->nameLocation, OccursCheckFailed{}); + } + else + emplaceType(asMutable(aliasTy), ty); + + std::vector typeParams; + for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false)) + typeParams.push_back(tyParam.second.ty); + + std::vector typePackParams; + for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) + typePackParams.push_back(tpParam.second.tp); + + addConstraint( + scope, + alias->type->location, + NameConstraint{ + ty, + alias->name.value, + /*synthetic=*/false, + std::move(typeParams), + std::move(typePackParams), + } + ); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function) +{ + // If a type function with the same name was already defined, we skip over + auto bindingIt = scope->privateTypeBindings.find(function->name.value); + if (bindingIt == scope->privateTypeBindings.end()) + return ControlFlow::None; + + TypeFun typeFunction = bindingIt->second; + + // Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver + if (auto typeFunctionTy = get(DFInt::LuauTypeSolverRelease >= 646 ? follow(typeFunction.type) : typeFunction.type)) + { + TypeId expansionTy = arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments}); + addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy}); + } + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) +{ + LUAU_ASSERT(global->type); + + TypeId globalTy = resolveType(scope, global->type, /* inTypeArguments */ false); + Name globalName(global->name.value); + + module->declaredGlobals[globalName] = globalTy; + rootScope->bindings[global->name] = Binding{globalTy, global->location}; + + DefId def = dfg->getDef(global); + rootScope->lvalueTypes[def] = globalTy; + rootScope->rvalueRefinements[def] = globalTy; + + return ControlFlow::None; +} + +static bool isMetamethod(const Name& name) +{ + return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || + name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" || + name == "__idiv"; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) +{ + std::optional superTy = std::make_optional(builtinTypes->classType); + if (declaredClass->superName) + { + Name superName = Name(declaredClass->superName->value); + std::optional lookupType = scope->lookupType(superName); + + if (!lookupType) + { + reportError(declaredClass->location, UnknownSymbol{superName, UnknownSymbol::Type}); + return ControlFlow::None; + } + + // We don't have generic classes, so this assertion _should_ never be hit. + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); + superTy = lookupType->type; + + if (!get(follow(*superTy))) + { + reportError( + declaredClass->location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)} + ); + + return ControlFlow::None; + } + } + + Name className(declaredClass->name.value); + + TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, module->name, declaredClass->location)); + ClassType* ctv = getMutable(classTy); + + TypeId metaTy = arena->addType(TableType{TableState::Sealed, scope->level, scope.get()}); + TableType* metatable = getMutable(metaTy); + + ctv->metatable = metaTy; + + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + if (declaredClass->indexer) + { + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(declaredClass->indexer->location); + } + else + { + ctv->indexer = TableIndexer{ + resolveType(scope, declaredClass->indexer->indexType, /* inTypeArguments */ false), + resolveType(scope, declaredClass->indexer->resultType, /* inTypeArguments */ false), + }; + } + } + + for (const AstDeclaredClassProp& prop : declaredClass->props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, prop.ty, /* inTypeArguments */ false); + + bool assignToMetatable = isMetamethod(propName); + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) + { + if (FunctionType* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack({classTy}, ftv->argTypes); + + ftv->hasSelf = true; + + FunctionDefinition defn; + + defn.definitionModuleName = module->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; + } + } + + TableType::Props& props = assignToMetatable ? metatable->props : ctv->props; + + if (props.count(propName) == 0) + { + props[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; + } + else + { + Luau::Property& prop = props[propName]; + TypeId currentTy = prop.type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = arena->addType(IntersectionType{std::move(options)}); + + prop.readTy = newItv; + prop.writeTy = newItv; + } + else if (get(currentTy)) + { + TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); + + prop.readTy = intersection; + prop.writeTy = intersection; + } + else + { + reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + } + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunction* global) +{ + std::vector> generics = createGenerics(scope, global->generics); + std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); + + std::vector genericTys; + genericTys.reserve(generics.size()); + for (auto& [name, generic] : generics) + { + genericTys.push_back(generic.ty); + } + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + for (auto& [name, generic] : genericPacks) + { + genericTps.push_back(generic.tp); + } + + ScopePtr funScope = scope; + if (!generics.empty() || !genericPacks.empty()) + funScope = childScope(global, scope); + + TypePackId paramPack = resolveTypePack(funScope, global->params, /* inTypeArguments */ false); + TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); + + FunctionDefinition defn; + + defn.definitionModuleName = module->name; + defn.definitionLocation = global->location; + defn.varargLocation = global->vararg ? std::make_optional(global->varargLocation) : std::nullopt; + defn.originalNameLocation = global->nameLocation; + + TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack, defn}); + FunctionType* ftv = getMutable(fnType); + ftv->isCheckedFunction = global->isCheckedFunction(); + + ftv->argNames.reserve(global->paramNames.size); + for (const auto& el : global->paramNames) + ftv->argNames.push_back(FunctionArgument{el.first.value, el.second}); + + Name fnName(global->name.value); + + module->declaredGlobals[fnName] = fnType; + scope->bindings[global->name] = Binding{fnType, global->location}; + + DefId def = dfg->getDef(global); + rootScope->lvalueTypes[def] = fnType; + rootScope->rvalueRefinements[def] = fnType; + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatError* error) +{ + for (AstStat* stat : error->statements) + visit(scope, stat); + for (AstExpr* expr : error->expressions) + check(scope, expr); + + return ControlFlow::None; +} + +InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes) +{ + std::vector head; + std::optional tail; + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + if (i < exprs.size - 1) + { + std::optional expectedType; + if (i < expectedTypes.size()) + expectedType = expectedTypes[i]; + head.push_back(check(scope, expr, expectedType).ty); + } + else + { + std::vector> expectedTailTypes; + if (i < expectedTypes.size()) + expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); + tail = checkPack(scope, expr, expectedTailTypes).tp; + } + } + + return InferencePack{addTypePack(std::move(head), tail)}; +} + +InferencePack ConstraintGenerator::checkPack( + const ScopePtr& scope, + AstExpr* expr, + const std::vector>& expectedTypes, + bool generalize +) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return InferencePack{builtinTypes->errorRecoveryTypePack()}; + } + + InferencePack result; + + if (AstExprCall* call = expr->as()) + result = checkPack(scope, call); + else if (AstExprVarargs* varargs = expr->as()) + { + if (scope->varargPack) + result = InferencePack{*scope->varargPack}; + else + result = InferencePack{builtinTypes->errorRecoveryTypePack()}; + } + else + { + std::optional expectedType; + if (!expectedTypes.empty()) + expectedType = expectedTypes[0]; + TypeId t = check(scope, expr, expectedType, /*forceSingletons*/ false, generalize).ty; + result = InferencePack{arena->addTypePack({t})}; + } + + LUAU_ASSERT(result.tp); + module->astTypePacks[expr] = result.tp; + return result; +} + +InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* call) +{ + std::vector exprArgs; + + std::vector returnRefinements; + std::vector> discriminantTypes; + + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice->ice("method call expression has no 'self'"); + + exprArgs.push_back(indexExpr->expr); + + if (auto key = dfg->getRefinementKey(indexExpr->expr)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(key, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); + } + + for (AstExpr* arg : call->args) + { + exprArgs.push_back(arg); + + if (auto key = dfg->getRefinementKey(arg)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(key, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); + } + + Checkpoint funcBeginCheckpoint = checkpoint(this); + + TypeId fnType = check(scope, call->func).ty; + + Checkpoint funcEndCheckpoint = checkpoint(this); + + std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); + + module->astOriginalCallTypes[call->func] = fnType; + + Checkpoint argBeginCheckpoint = checkpoint(this); + + std::vector args; + std::optional argTail; + std::vector argumentRefinements; + + for (size_t i = 0; i < exprArgs.size(); ++i) + { + AstExpr* arg = exprArgs[i]; + + if (i == 0 && call->self) + { + // The self type has already been computed as a side effect of + // computing fnType. If computing that did not cause us to exceed a + // recursion limit, we can fetch it from astTypes rather than + // recomputing it. + TypeId* selfTy = module->astTypes.find(exprArgs[0]); + if (selfTy) + args.push_back(*selfTy); + else + args.push_back(freshType(scope)); + } + else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) + { + auto [ty, refinement] = check(scope, arg, /*expectedType*/ std::nullopt, /*forceSingleton*/ false, /*generalize*/ false); + args.push_back(ty); + argumentRefinements.push_back(refinement); + } + else + { + auto [tp, refis] = checkPack(scope, arg, {}); + argTail = tp; + argumentRefinements.insert(argumentRefinements.end(), refis.begin(), refis.end()); + } + } + + Checkpoint argEndCheckpoint = checkpoint(this); + + if (matchSetMetatable(*call)) + { + TypePack argTailPack; + if (argTail && args.size() < 2) + argTailPack = extendTypePack(*arena, builtinTypes, *argTail, 2 - args.size()); + + TypeId target = nullptr; + TypeId mt = nullptr; + + if (args.size() + argTailPack.head.size() == 2) + { + target = args.size() > 0 ? args[0] : argTailPack.head[0]; + mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; + } + else + { + std::vector unpackedTypes; + if (args.size() > 0) + target = args[0]; + else + { + target = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(target); + } + + mt = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(mt); + + auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail}); + getMutable(mt)->setOwner(c); + if (auto b = getMutable(target); b && b->getOwner() == nullptr) + b->setOwner(c); + } + + LUAU_ASSERT(target); + LUAU_ASSERT(mt); + + target = follow(target); + + AstExpr* targetExpr = call->args.data[0]; + + TypeId resultTy = nullptr; + + if (isTableUnion(target)) + { + const UnionType* targetUnion = get(target); + std::vector newParts; + + for (TypeId ty : targetUnion) + newParts.push_back(arena->addType(MetatableType{ty, mt})); + + resultTy = arena->addType(UnionType{std::move(newParts)}); + } + else + resultTy = arena->addType(MetatableType{target, mt}); + + if (AstExprLocal* targetLocal = targetExpr->as()) + { + scope->bindings[targetLocal->local].typeId = resultTy; + + DefId def = dfg->getDef(targetLocal); + scope->lvalueTypes[def] = resultTy; // TODO: typestates: track this as an assignment + scope->rvalueRefinements[def] = resultTy; // TODO: typestates: track this as an assignment + + // HACK: If we have a targetLocal, it has already been added to the + // inferredBindings table. We want to replace it so that we don't + // infer a weird union like tbl | { @metatable something, tbl } + if (InferredBinding* ib = inferredBindings.find(targetLocal->local)) + ib->types.erase(target); + + recordInferredBinding(targetLocal->local, resultTy); + } + + return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; + } + + if (FFlag::LuauTypestateBuiltins && shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0])) + { + AstExpr* targetExpr = call->args.data[0]; + auto resultTy = arena->addType(BlockedType{}); + + if (auto def = dfg->getDefOptional(targetExpr)) + { + scope->lvalueTypes[*def] = resultTy; + scope->rvalueRefinements[*def] = resultTy; + } + } + + if (matchAssert(*call) && !argumentRefinements.empty()) + applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); + + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = addTypePack(std::move(args), argTail); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); + + /* + * To make bidirectional type checking work, we need to solve these constraints in a particular order: + * + * 1. Solve the function type + * 2. Propagate type information from the function type to the argument types + * 3. Solve the argument types + * 4. Solve the call + */ + + NotNull checkConstraint = addConstraint( + scope, call->func->location, FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}} + ); + + forEachConstraint( + funcBeginCheckpoint, + funcEndCheckpoint, + this, + [checkConstraint](const ConstraintPtr& constraint) + { + checkConstraint->dependencies.emplace_back(constraint.get()); + } + ); + + NotNull callConstraint = addConstraint( + scope, + call->func->location, + FunctionCallConstraint{ + fnType, + argPack, + rets, + call, + std::move(discriminantTypes), + &module->astOverloadResolvedTypes, + } + ); + + getMutable(rets)->owner = callConstraint.get(); + + callConstraint->dependencies.push_back(checkConstraint); + + forEachConstraint( + argBeginCheckpoint, + argEndCheckpoint, + this, + [checkConstraint, callConstraint](const ConstraintPtr& constraint) + { + constraint->dependencies.emplace_back(checkConstraint); + + callConstraint->dependencies.emplace_back(constraint.get()); + } + ); + + return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton, bool generalize) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return Inference{builtinTypes->errorRecoveryType()}; + } + + Inference result; + + if (auto group = expr->as()) + result = check(scope, group->expr, expectedType, forceSingleton, generalize); + else if (auto stringExpr = expr->as()) + result = check(scope, stringExpr, expectedType, forceSingleton); + else if (expr->is()) + result = Inference{builtinTypes->numberType}; + else if (auto boolExpr = expr->as()) + result = check(scope, boolExpr, expectedType, forceSingleton); + else if (expr->is()) + result = Inference{builtinTypes->nilType}; + else if (auto local = expr->as()) + result = check(scope, local); + else if (auto global = expr->as()) + result = check(scope, global); + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); + else if (auto call = expr->as()) + result = flattenPack(scope, expr->location, checkPack(scope, call)); // TODO: needs predicates too + else if (auto a = expr->as()) + result = check(scope, a, expectedType, generalize); + else if (auto indexName = expr->as()) + result = check(scope, indexName); + else if (auto indexExpr = expr->as()) + result = check(scope, indexExpr); + else if (auto table = expr->as()) + result = check(scope, table, expectedType); + else if (auto unary = expr->as()) + result = check(scope, unary); + else if (auto binary = expr->as()) + result = check(scope, binary, expectedType); + else if (auto ifElse = expr->as()) + result = check(scope, ifElse, expectedType); + else if (auto typeAssert = expr->as()) + result = check(scope, typeAssert); + else if (auto interpString = expr->as()) + result = check(scope, interpString); + else if (auto err = expr->as()) + { + // Open question: Should we traverse into this? + for (AstExpr* subExpr : err->expressions) + check(scope, subExpr); + + result = Inference{builtinTypes->errorRecoveryType()}; + } + else + { + LUAU_ASSERT(0); + result = Inference{freshType(scope)}; + } + + LUAU_ASSERT(result.ty); + module->astTypes[expr] = result.ty; + if (expectedType) + module->astExpectedTypes[expr] = *expectedType; + return result; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) +{ + if (forceSingleton) + return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); + ft.upperBound = builtinTypes->stringType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, string->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->stringType}); + return Inference{freeTy}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) +{ + const TypeId singletonType = boolExpr->value ? builtinTypes->trueType : builtinTypes->falseType; + if (forceSingleton) + return Inference{singletonType}; + + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = singletonType; + ft.upperBound = builtinTypes->booleanType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->booleanType}); + return Inference{freeTy}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprLocal* local) +{ + const RefinementKey* key = dfg->getRefinementKey(local); + std::optional rvalueDef = dfg->getRValueDefForCompoundAssign(local); + LUAU_ASSERT(key || rvalueDef); + + std::optional maybeTy; + + // if we have a refinement key, we can look up its type. + if (key) + maybeTy = lookup(scope, local->location, key->def); + + // if the current def doesn't have a type, we might be doing a compound assignment + // and therefore might need to look at the rvalue def instead. + if (!maybeTy && rvalueDef) + maybeTy = lookup(scope, local->location, *rvalueDef); + + if (maybeTy) + { + TypeId ty = follow(*maybeTy); + + recordInferredBinding(local->local, ty); + + return Inference{ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + } + else + ice->ice("CG: AstExprLocal came before its declaration?"); +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* global) +{ + const RefinementKey* key = dfg->getRefinementKey(global); + std::optional rvalueDef = dfg->getRValueDefForCompoundAssign(global); + LUAU_ASSERT(key || rvalueDef); + + // we'll use whichever of the two definitions we have here. + DefId def = key ? key->def : *rvalueDef; + + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + if (auto ty = lookup(scope, global->location, def, /*prototype=*/false)) + { + rootScope->lvalueTypes[def] = *ty; + return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + } + else + return Inference{builtinTypes->errorRecoveryType()}; +} + +Inference ConstraintGenerator::checkIndexName( + const ScopePtr& scope, + const RefinementKey* key, + AstExpr* indexee, + const std::string& index, + Location indexLocation +) +{ + TypeId obj = check(scope, indexee).ty; + TypeId result = nullptr; + + // We optimize away the HasProp constraint in simple cases so that we can + // reason about updates to unsealed tables more accurately. + + const TableType* tt = getTableType(obj); + + // This is a little bit iffy but I *believe* it is okay because, if the + // local's domain is going to be extended at all, it will be someplace after + // the current lexical position within the script. + if (!tt) + { + if (TypeIds* localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(*localDomain->begin()); + } + + if (tt) + { + auto it = tt->props.find(index); + if (it != tt->props.end() && it->second.readTy.has_value()) + result = *it->second.readTy; + } + + if (!result) + { + result = arena->addType(BlockedType{}); + + auto c = addConstraint( + scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)} + ); + getMutable(result)->setOwner(c); + } + + if (key) + { + if (auto ty = lookup(scope, indexLocation, key->def)) + return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + + scope->rvalueRefinements[key->def] = result; + } + + if (key) + return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; + else + return Inference{result}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexName* indexName) +{ + const RefinementKey* key = dfg->getRefinementKey(indexName); + return checkIndexName(scope, key, indexName->expr, indexName->index.value, indexName->indexLocation); +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +{ + if (auto constantString = indexExpr->index->as()) + { + module->astTypes[indexExpr->index] = builtinTypes->stringType; + const RefinementKey* key = dfg->getRefinementKey(indexExpr); + return checkIndexName(scope, key, indexExpr->expr, constantString->value.data, indexExpr->location); + } + + TypeId obj = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; + + TypeId result = arena->addType(BlockedType{}); + + const RefinementKey* key = dfg->getRefinementKey(indexExpr); + if (key) + { + if (auto ty = lookup(scope, indexExpr->location, key->def)) + return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + + scope->rvalueRefinements[key->def] = result; + } + + auto c = addConstraint(scope, indexExpr->expr->location, HasIndexerConstraint{result, obj, indexType}); + getMutable(result)->setOwner(c); + + if (key) + return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; + else + return Inference{result}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType, bool generalize) +{ + Checkpoint startCheckpoint = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, func, expectedType); + + interiorTypes.push_back(std::vector{}); + checkFunctionBody(sig.bodyScope, func); + Checkpoint endCheckpoint = checkpoint(this); + + TypeId generalizedTy = arena->addType(BlockedType{}); + NotNull gc = + addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature, std::move(interiorTypes.back())}); + getMutable(generalizedTy)->setOwner(gc); + interiorTypes.pop_back(); + + Constraint* previous = nullptr; + forEachConstraint( + startCheckpoint, + endCheckpoint, + this, + [gc, &previous](const ConstraintPtr& constraint) + { + gc->dependencies.emplace_back(constraint.get()); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + } + ); + + if (generalize && hasFreeType(sig.signature)) + { + return Inference{generalizedTy}; + } + else + { + return Inference{sig.signature}; + } +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprUnary* unary) +{ + auto [operandType, refinement] = check(scope, unary->expr); + + switch (unary->op) + { + case AstExprUnary::Op::Not: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().notFunc, {operandType}, {}, scope, unary->location); + return Inference{resultType, refinementArena.negation(refinement)}; + } + case AstExprUnary::Op::Len: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().lenFunc, {operandType}, {}, scope, unary->location); + return Inference{resultType, refinementArena.negation(refinement)}; + } + case AstExprUnary::Op::Minus: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().unmFunc, {operandType}, {}, scope, unary->location); + return Inference{resultType, refinementArena.negation(refinement)}; + } + default: // msvc can't prove that this is exhaustive. + LUAU_UNREACHABLE(); + } +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + auto [leftType, rightType, refinement] = checkBinary(scope, binary, expectedType); + + switch (binary->op) + { + case AstExprBinary::Op::Add: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().addFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Sub: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().subFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Mul: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().mulFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Div: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().divFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::FloorDiv: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().idivFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Pow: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().powFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Mod: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().modFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Concat: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().concatFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::And: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().andFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Or: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().orFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareLt: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().ltFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareGe: + { + TypeId resultType = createTypeFunctionInstance( + builtinTypeFunctions().ltFunc, + {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` + {}, + scope, + binary->location + ); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareLe: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().leFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareGt: + { + TypeId resultType = createTypeFunctionInstance( + builtinTypeFunctions().leFunc, + {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` + {}, + scope, + binary->location + ); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + { + DefId leftDef = dfg->getDef(binary->left); + DefId rightDef = dfg->getDef(binary->right); + bool leftSubscripted = containsSubscriptedDefinition(leftDef); + bool rightSubscripted = containsSubscriptedDefinition(rightDef); + + if (leftSubscripted && rightSubscripted) + { + // we cannot add nil in this case because then we will blindly accept comparisons that we should not. + } + else if (leftSubscripted) + leftType = makeUnion(scope, binary->location, leftType, builtinTypes->nilType); + else if (rightSubscripted) + rightType = makeUnion(scope, binary->location, rightType, builtinTypes->nilType); + + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().eqFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Op__Count: + ice->ice("Op__Count should never be generated in an AST."); + default: // msvc can't prove that this is exhaustive. + LUAU_UNREACHABLE(); + } +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) +{ + RefinementId refinement = [&]() + { + InConditionalContext flipper{&typeContext}; + ScopePtr condScope = childScope(ifElse->condition, scope); + return check(condScope, ifElse->condition).refinement; + }(); + + ScopePtr thenScope = childScope(ifElse->trueExpr, scope); + applyRefinements(thenScope, ifElse->trueExpr->location, refinement); + TypeId thenType = check(thenScope, ifElse->trueExpr, expectedType).ty; + + ScopePtr elseScope = childScope(ifElse->falseExpr, scope); + applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); + TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; + + return Inference{expectedType ? *expectedType : makeUnion(scope, ifElse->location, thenType, elseType)}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) +{ + check(scope, typeAssert->expr, std::nullopt); + return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprInterpString* interpString) +{ + for (AstExpr* expr : interpString->expressions) + check(scope, expr); + + return Inference{builtinTypes->stringType}; +} + +std::tuple ConstraintGenerator::checkBinary( + const ScopePtr& scope, + AstExprBinary* binary, + std::optional expectedType +) +{ + if (binary->op == AstExprBinary::And) + { + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, relaxedExpectedLhs); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, leftRefinement); + auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, refinementArena.conjunction(leftRefinement, rightRefinement)}; + } + else if (binary->op == AstExprBinary::Or) + { + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, relaxedExpectedLhs); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, refinementArena.negation(leftRefinement)); + auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, refinementArena.disjunction(leftRefinement, rightRefinement)}; + } + else if (auto typeguard = matchTypeGuard(binary)) + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + + const RefinementKey* key = dfg->getRefinementKey(typeguard->target); + if (!key) + return {leftType, rightType, nullptr}; + + TypeId discriminantTy = builtinTypes->neverType; + if (typeguard->type == "nil") + discriminantTy = builtinTypes->nilType; + else if (typeguard->type == "string") + discriminantTy = builtinTypes->stringType; + else if (typeguard->type == "number") + discriminantTy = builtinTypes->numberType; + else if (typeguard->type == "boolean") + discriminantTy = builtinTypes->booleanType; + else if (typeguard->type == "thread") + discriminantTy = builtinTypes->threadType; + else if (typeguard->type == "buffer") + discriminantTy = builtinTypes->bufferType; + else if (typeguard->type == "table") + discriminantTy = builtinTypes->tableType; + else if (typeguard->type == "function") + discriminantTy = builtinTypes->functionType; + else if (typeguard->type == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + discriminantTy = builtinTypes->classType; + } + else if (!typeguard->isTypeof && typeguard->type == "vector") + discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type + else if (!typeguard->isTypeof) + discriminantTy = builtinTypes->neverType; + else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) + { + TypeId ty = follow(typeFun->type); + + // We're only interested in the root class of any classes. + if (auto ctv = get(ty); ctv && (ctv->parent == builtinTypes->classType || hasTag(ty, kTypeofRootTag))) + discriminantTy = ty; + } + + RefinementId proposition = refinementArena.proposition(key, discriminantTy); + if (binary->op == AstExprBinary::CompareEq) + return {leftType, rightType, proposition}; + else if (binary->op == AstExprBinary::CompareNe) + return {leftType, rightType, refinementArena.negation(proposition)}; + else + ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); + } + else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + { + // We are checking a binary expression of the form a op b + // Just because a op b is epxected to return a bool, doesn't mean a, b are expected to be bools too + TypeId leftType = check(scope, binary->left, {}, true).ty; + TypeId rightType = check(scope, binary->right, {}, true).ty; + + RefinementId leftRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->left), rightType); + RefinementId rightRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->right), leftType); + + if (binary->op == AstExprBinary::CompareNe) + { + leftRefinement = refinementArena.negation(leftRefinement); + rightRefinement = refinementArena.negation(rightRefinement); + } + + return {leftType, rightType, refinementArena.equivalence(leftRefinement, rightRefinement)}; + } + else + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + return {leftType, rightType, nullptr}; + } +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType) +{ + if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + { + if (FFlag::LuauNewSolverVisitErrorExprLvalues) + { + // If we end up with some sort of error expression in an lvalue + // position, at least go and check the expressions so that when + // we visit them later, there aren't any invalid assumptions. + for (auto subExpr : e->expressions) + { + check(scope, subExpr); + } + } + } + else + ice->ice("Unexpected lvalue expression", expr->location); +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType) +{ + std::optional annotatedTy = scope->lookup(local->local); + LUAU_ASSERT(annotatedTy); + + const DefId defId = dfg->getDef(local); + std::optional ty = scope->lookupUnrefinedType(defId); + + if (ty) + { + TypeIds* localDomain = localTypes.find(*ty); + if (localDomain) + localDomain->insert(rhsType); + } + else + { + ty = arena->addType(BlockedType{}); + localTypes[*ty].insert(rhsType); + + if (annotatedTy) + { + switch (shouldSuppressErrors(normalizer, *annotatedTy)) + { + case ErrorSuppression::DoNotSuppress: + break; + case ErrorSuppression::Suppress: + ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result; + break; + case ErrorSuppression::NormalizationFailed: + reportError(local->local->annotation->location, NormalizationTooComplex{}); + break; + } + } + + scope->lvalueTypes[defId] = *ty; + } + + recordInferredBinding(local->local, *ty); + + if (annotatedTy) + addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); + + if (TypeIds* localDomain = localTypes.find(*ty)) + localDomain->insert(rhsType); +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) +{ + std::optional annotatedTy = scope->lookup(Symbol{global->name}); + if (annotatedTy) + { + DefId def = dfg->getDef(global); + rootScope->lvalueTypes[def] = rhsType; + + addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); + } +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* expr, TypeId rhsType) +{ + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + + bool incremented = recordPropertyAssignment(lhsTy); + + auto apc = + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, expr->indexLocation, propTy, incremented}); + getMutable(propTy)->setOwner(apc); +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) +{ + if (auto constantString = expr->index->as()) + { + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. + std::string propName{constantString->value.data, constantString->value.size}; + + bool incremented = recordPropertyAssignment(lhsTy); + + auto apc = addConstraint( + scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, expr->index->location, propTy, incremented} + ); + getMutable(propTy)->setOwner(apc); + + return; + } + + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId indexTy = check(scope, expr->index).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + auto aic = addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); + getMutable(propTy)->setOwner(aic); +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) +{ + TypeId ty = arena->addType(TableType{}); + TableType* ttv = getMutable(ty); + LUAU_ASSERT(ttv); + + ttv->state = TableState::Unsealed; + ttv->definitionModuleName = module->name; + ttv->scope = scope.get(); + + interiorTypes.back().push_back(ty); + + TypeIds indexKeyLowerBound; + TypeIds indexValueLowerBound; + + auto createIndexer = [&indexKeyLowerBound, &indexValueLowerBound](const Location& location, TypeId currentIndexType, TypeId currentResultType) + { + indexKeyLowerBound.insert(follow(currentIndexType)); + indexValueLowerBound.insert(follow(currentResultType)); + }; + + TypeIds valuesLowerBound; + + for (const AstExprTable::Item& item : expr->items) + { + // Expected types are threaded through table literals separately via the + // function matchLiteralType. + + TypeId itemTy = check(scope, item.value).ty; + + if (item.key) + { + // Even though we don't need to use the type of the item's key if + // it's a string constant, we still want to check it to populate + // astTypes. + TypeId keyTy = check(scope, item.key).ty; + + if (AstExprConstantString* key = item.key->as()) + { + std::string propName{key->value.data, key->value.size}; + ttv->props[propName] = {itemTy, /*deprecated*/ false, {}, key->location}; + } + else + { + createIndexer(item.key->location, keyTy, itemTy); + } + } + else + { + TypeId numberType = builtinTypes->numberType; + // FIXME? The location isn't quite right here. Not sure what is + // right. + createIndexer(item.value->location, numberType, itemTy); + } + } + + if (!indexKeyLowerBound.empty()) + { + LUAU_ASSERT(!indexValueLowerBound.empty()); + + TypeId indexKey = indexKeyLowerBound.size() == 1 + ? *indexKeyLowerBound.begin() + : arena->addType(UnionType{std::vector(indexKeyLowerBound.begin(), indexKeyLowerBound.end())}); + + TypeId indexValue = indexValueLowerBound.size() == 1 + ? *indexValueLowerBound.begin() + : arena->addType(UnionType{std::vector(indexValueLowerBound.begin(), indexValueLowerBound.end())}); + + ttv->indexer = TableIndexer{indexKey, indexValue}; + } + + if (expectedType) + { + Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; + std::vector toBlock; + if (DFInt::LuauTypeSolverRelease >= 648) + { + // This logic is incomplete as we want to re-run this + // _after_ blocked types have resolved, but this + // allows us to do some bidirectional inference. + toBlock = findBlockedTypesIn(expr, NotNull{&module->astTypes}); + if (toBlock.empty()) + { + matchLiteralType( + NotNull{&module->astTypes}, + NotNull{&module->astExpectedTypes}, + builtinTypes, + arena, + NotNull{&unifier}, + *expectedType, + ty, + expr, + toBlock + ); + // The visitor we ran prior should ensure that there are no + // blocked types that we would encounter while matching on + // this expression. + LUAU_ASSERT(toBlock.empty()); + } + } + else + { + matchLiteralType( + NotNull{&module->astTypes}, + NotNull{&module->astExpectedTypes}, + builtinTypes, + arena, + NotNull{&unifier}, + *expectedType, + ty, + expr, + toBlock + ); + } + } + + return Inference{ty}; +} + +ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignature( + const ScopePtr& parent, + AstExprFunction* fn, + std::optional expectedType, + std::optional originalName +) +{ + ScopePtr signatureScope = nullptr; + ScopePtr bodyScope = nullptr; + TypePackId returnType = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + if (expectedType) + expectedType = follow(*expectedType); + + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + + signatureScope = childScope(fn, parent); + + // We need to assign returnType before creating bodyScope so that the + // return type gets propogated to bodyScope. + returnType = freshTypePack(signatureScope); + signatureScope->returnType = returnType; + + bodyScope = childScope(fn->body, signatureScope); + + if (hasGenerics) + { + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); + + // We do not support default values on function generics, so we only + // care about the types involved. + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + } + + // Local variable works around an odd gcc 11.3 warning: may be used uninitialized + std::optional none = std::nullopt; + expectedType = none; + } + + std::vector argTypes; + std::vector> argNames; + TypePack expectedArgPack; + + const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; + // This check ensures that expectedType is precisely optional and not any (since any is also an optional type) + if (expectedType && isOptional(*expectedType) && !get(*expectedType)) + { + if (auto ut = get(*expectedType)) + { + for (auto u : ut) + { + if (get(u) && !isNil(u)) + { + expectedFunction = get(u); + break; + } + } + } + } + + if (expectedFunction) + { + expectedArgPack = extendTypePack(*arena, builtinTypes, expectedFunction->argTypes, fn->args.size); + + genericTypes = expectedFunction->generics; + genericTypePacks = expectedFunction->genericPacks; + } + + if (fn->self) + { + TypeId selfType = freshType(signatureScope); + argTypes.push_back(selfType); + argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); + signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; + + DefId def = dfg->getDef(fn->self); + signatureScope->lvalueTypes[def] = selfType; + signatureScope->rvalueRefinements[def] = selfType; + } + + for (size_t i = 0; i < fn->args.size; ++i) + { + AstLocal* local = fn->args.data[i]; + + TypeId argTy = nullptr; + if (local->annotation) + argTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); + else + { + if (i < expectedArgPack.head.size()) + argTy = expectedArgPack.head[i]; + else + argTy = freshType(signatureScope); + } + + argTypes.push_back(argTy); + argNames.emplace_back(FunctionArgument{local->name.value, local->location}); + + signatureScope->bindings[local] = Binding{argTy, local->location}; + + DefId def = dfg->getDef(local); + signatureScope->lvalueTypes[def] = argTy; + signatureScope->rvalueRefinements[def] = argTy; + } + + TypePackId varargPack = nullptr; + + if (fn->vararg) + { + if (fn->varargAnnotation) + { + TypePackId annotationType = + resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh */ true); + varargPack = annotationType; + } + else if (expectedArgPack.tail && get(*expectedArgPack.tail)) + varargPack = *expectedArgPack.tail; + else + varargPack = builtinTypes->anyTypePack; + + signatureScope->varargPack = varargPack; + bodyScope->varargPack = varargPack; + } + else + { + varargPack = arena->addTypePack(VariadicTypePack{builtinTypes->anyType, /*hidden*/ true}); + // We do not add to signatureScope->varargPack because ... is not valid + // in functions without an explicit ellipsis. + + signatureScope->varargPack = std::nullopt; + bodyScope->varargPack = std::nullopt; + } + + LUAU_ASSERT(nullptr != varargPack); + + // If there is both an annotation and an expected type, the annotation wins. + // Type checking will sort out any discrepancies later. + if (fn->returnAnnotation) + { + TypePackId annotatedRetType = + resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); + // We bind the annotated type directly here so that, when we need to + // generate constraints for return types, we have a guarantee that we + // know the annotated return type already, if one was provided. + LUAU_ASSERT(get(returnType)); + emplaceTypePack(asMutable(returnType), annotatedRetType); + } + else if (expectedFunction) + { + emplaceTypePack(asMutable(returnType), expectedFunction->retTypes); + } + + // TODO: Preserve argument names in the function's type. + + FunctionType actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; + actualFunction.generics = std::move(genericTypes); + actualFunction.genericPacks = std::move(genericTypePacks); + actualFunction.argNames = std::move(argNames); + actualFunction.hasSelf = fn->self != nullptr; + + FunctionDefinition defn; + defn.definitionModuleName = module->name; + defn.definitionLocation = fn->location; + defn.varargLocation = fn->vararg ? std::make_optional(fn->varargLocation) : std::nullopt; + defn.originalNameLocation = originalName.value_or(Location(fn->location.begin, 0)); + actualFunction.definition = defn; + + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + LUAU_ASSERT(actualFunctionType); + module->astTypes[fn] = actualFunctionType; + + if (expectedType && get(*expectedType)) + bindFreeType(*expectedType, actualFunctionType); + + return { + /* signature */ actualFunctionType, + /* signatureScope */ signatureScope, + /* bodyScope */ bodyScope, + }; +} + +void ConstraintGenerator::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn) +{ + // If it is possible for execution to reach the end of the function, the return type must be compatible with () + ControlFlow cf = + DFInt::LuauTypeSolverRelease >= 646 ? visitBlockWithoutChildScope(scope, fn->body) : visitBlockWithoutChildScope_DEPRECATED(scope, fn->body); + if (cf == ControlFlow::None) + addConstraint(scope, fn->location, PackSubtypeConstraint{builtinTypes->emptyTypePack, scope->returnType}); +} + +TypeId ConstraintGenerator::resolveReferenceType( + const ScopePtr& scope, + AstType* ty, + AstTypeReference* ref, + bool inTypeArguments, + bool replaceErrorWithFresh +) +{ + TypeId result = nullptr; + + if (FFlag::DebugLuauMagicTypes) + { + if (ref->name == "_luau_ice") + ice->ice("_luau_ice encountered", ty->location); + else if (ref->name == "_luau_print") + { + if (ref->parameters.size != 1 || !ref->parameters.data[0].type) + { + reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); + module->astResolvedTypes[ty] = builtinTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); + } + else + return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); + } + } + + std::optional alias; + + if (ref->prefix.has_value()) + { + alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); + } + else + { + alias = scope->lookupType(ref->name.value); + } + + if (alias.has_value()) + { + // If the alias is not generic, we don't need to set up a blocked + // type and an instantiation constraint. + if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) + { + result = alias->type; + } + else + { + std::vector parameters; + std::vector packParameters; + + for (const AstTypeOrPack& p : ref->parameters) + { + // We do not enforce the ordering of types vs. type packs here; + // that is done in the parser. + if (p.type) + { + parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); + } + else if (p.typePack) + { + TypePackId tp = resolveTypePack(scope, p.typePack, /*inTypeArguments*/ true); + + // If we need more regular types, we can use single element type packs to fill those in + if (parameters.size() < alias->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + parameters.push_back(*first(tp)); + else + packParameters.push_back(tp); + } + else + { + // This indicates a parser bug: one of these two pointers + // should be set. + LUAU_ASSERT(false); + } + } + + result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); + + // If we're not in a type argument context, we need to create a constraint that expands this. + // The dispatching of the above constraint will queue up additional constraints for nested + // type function applications. + if (!inTypeArguments) + addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); + } + } + else + { + result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); + } + + return result; +} + +TypeId ConstraintGenerator::resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh) +{ + TableType::Props props; + std::optional indexer; + + for (const AstTableProp& prop : tab->props) + { + TypeId propTy = resolveType(scope, prop.type, inTypeArguments); + + Property& p = props[prop.name.value]; + p.typeLocation = prop.location; + + switch (prop.access) + { + case AstTableAccess::ReadWrite: + p.readTy = propTy; + p.writeTy = propTy; + break; + case AstTableAccess::Read: + p.readTy = propTy; + break; + case AstTableAccess::Write: + reportError(*prop.accessLocation, GenericError{"write keyword is illegal here"}); + p.readTy = propTy; + p.writeTy = propTy; + break; + default: + ice->ice("Unexpected property access " + std::to_string(int(prop.access))); + break; + } + } + + if (AstTableIndexer* astIndexer = tab->indexer) + { + if (astIndexer->access == AstTableAccess::Read) + reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); + else if (astIndexer->access == AstTableAccess::Write) + reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); + else if (astIndexer->access == AstTableAccess::ReadWrite) + { + indexer = TableIndexer{ + resolveType(scope, astIndexer->indexType, inTypeArguments), + resolveType(scope, astIndexer->resultType, inTypeArguments), + }; + } + else + ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access))); + } + + return arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); +} + +TypeId ConstraintGenerator::resolveFunctionType( + const ScopePtr& scope, + AstType* ty, + AstTypeFunction* fn, + bool inTypeArguments, + bool replaceErrorWithFresh +) +{ + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + ScopePtr signatureScope = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + // If we don't have generics, we do not need to generate a child scope + // for the generic bindings to live on. + if (hasGenerics) + { + signatureScope = childScope(fn, scope); + + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); + + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + } + } + else + { + // To eliminate the need to branch on hasGenerics below, we say that + // the signature scope is the parent scope if we don't have + // generics. + signatureScope = scope; + } + + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); + + // TODO: FunctionType needs a pointer to the scope so that we know + // how to quantify/instantiate it. + FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; + ftv.isCheckedFunction = fn->isCheckedFunction(); + + // This replicates the behavior of the appropriate FunctionType + // constructors. + ftv.generics = std::move(genericTypes); + ftv.genericPacks = std::move(genericTypePacks); + + ftv.argNames.reserve(fn->argNames.size); + for (const auto& el : fn->argNames) + { + if (el) + { + const auto& [name, location] = *el; + ftv.argNames.push_back(FunctionArgument{name.value, location}); + } + else + { + ftv.argNames.push_back(std::nullopt); + } + } + + return arena->addType(std::move(ftv)); +} + +TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) +{ + TypeId result = nullptr; + + if (auto ref = ty->as()) + { + result = resolveReferenceType(scope, ty, ref, inTypeArguments, replaceErrorWithFresh); + } + else if (auto tab = ty->as()) + { + result = resolveTableType(scope, ty, tab, inTypeArguments, replaceErrorWithFresh); + } + else if (auto fn = ty->as()) + { + result = resolveFunctionType(scope, ty, fn, inTypeArguments, replaceErrorWithFresh); + } + else if (auto tof = ty->as()) + { + TypeId exprType = check(scope, tof->expr).ty; + result = exprType; + } + else if (auto unionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : unionAnnotation->types) + { + parts.push_back(resolveType(scope, part, inTypeArguments)); + } + + result = arena->addType(UnionType{parts}); + } + else if (auto intersectionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : intersectionAnnotation->types) + { + parts.push_back(resolveType(scope, part, inTypeArguments)); + } + + result = arena->addType(IntersectionType{parts}); + } + else if (auto boolAnnotation = ty->as()) + { + if (boolAnnotation->value) + result = builtinTypes->trueType; + else + result = builtinTypes->falseType; + } + else if (auto stringAnnotation = ty->as()) + { + result = arena->addType(SingletonType(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); + } + else if (ty->is()) + { + result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); + } + else + { + LUAU_ASSERT(0); + result = builtinTypes->errorRecoveryType(); + } + + module->astResolvedTypes[ty] = result; + return result; +} + +TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) +{ + TypePackId result; + if (auto expl = tp->as()) + { + result = resolveTypePack(scope, expl->typeList, inTypeArgument, replaceErrorWithFresh); + } + else if (auto var = tp->as()) + { + TypeId ty = resolveType(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); + result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); + } + else if (auto gen = tp->as()) + { + if (std::optional lookup = scope->lookupPack(gen->genericName.value)) + { + result = *lookup; + } + else + { + reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); + result = builtinTypes->errorRecoveryTypePack(); + } + } + else + { + LUAU_ASSERT(0); + result = builtinTypes->errorRecoveryTypePack(); + } + + module->astResolvedTypePacks[tp] = result; + return result; +} + +TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh) +{ + std::vector head; + + for (AstType* headTy : list.types) + { + head.push_back(resolveType(scope, headTy, inTypeArguments, replaceErrorWithFresh)); + } + + std::optional tail = std::nullopt; + if (list.tailType) + { + tail = resolveTypePack(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); + } + + return addTypePack(std::move(head), tail); +} + +std::vector> ConstraintGenerator::createGenerics( + const ScopePtr& scope, + AstArray generics, + bool useCache, + bool addTypes +) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypeId genericTy = nullptr; + + if (auto it = scope->parent->typeAliasTypeParameters.find(generic.name.value); useCache && it != scope->parent->typeAliasTypeParameters.end()) + genericTy = it->second; + else + { + genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); + scope->parent->typeAliasTypeParameters[generic.name.value] = genericTy; + } + + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); + + if (addTypes) + scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; + + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); + } + + return result; +} + +std::vector> ConstraintGenerator::createGenericPacks( + const ScopePtr& scope, + AstArray generics, + bool useCache, + bool addTypes +) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypePackId genericTy; + + if (auto it = scope->parent->typeAliasTypePackParameters.find(generic.name.value); + useCache && it != scope->parent->typeAliasTypePackParameters.end()) + genericTy = it->second; + else + { + genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); + scope->parent->typeAliasTypePackParameters[generic.name.value] = genericTy; + } + + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); + + if (addTypes) + scope->privateTypePackBindings[generic.name.value] = genericTy; + + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); + } + + return result; +} + +Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) +{ + const auto& [tp, refinements] = pack; + RefinementId refinement = nullptr; + if (!refinements.empty()) + refinement = refinements[0]; + + if (auto f = first(tp)) + return Inference{*f, refinement}; + + TypeId typeResult = arena->addType(BlockedType{}); + auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp}); + getMutable(typeResult)->setOwner(c); + + return Inference{typeResult, refinement}; +} + +void ConstraintGenerator::reportError(Location location, TypeErrorData err) +{ + errors.push_back(TypeError{location, module->name, std::move(err)}); + + if (logger) + logger->captureGenerationError(errors.back()); +} + +void ConstraintGenerator::reportCodeTooComplex(Location location) +{ + errors.push_back(TypeError{location, module->name, CodeTooComplex{}}); + + if (logger) + logger->captureGenerationError(errors.back()); +} + +TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + if (get(follow(lhs))) + return rhs; + if (get(follow(rhs))) + return lhs; + + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().unionFunc, {lhs, rhs}, {}, scope, location); + + return resultType; +} + +TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().intersectFunc, {lhs, rhs}, {}, scope, location); + + return resultType; +} + +struct GlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull arena; + const NotNull dfg; + + GlobalPrepopulator(NotNull globalScope, NotNull arena, NotNull dfg) + : globalScope(globalScope) + , arena(arena) + , dfg(dfg) + { + } + + bool visit(AstExprGlobal* global) override + { + if (auto ty = globalScope->lookup(global->name)) + { + DefId def = dfg->getDef(global); + globalScope->lvalueTypes[def] = *ty; + } + + return true; + } + + bool visit(AstStatFunction* function) override + { + if (AstExprGlobal* g = function->name->as()) + { + TypeId bt = arena->addType(BlockedType{}); + globalScope->bindings[g->name] = Binding{bt}; + } + + return true; + } + + bool visit(AstType*) override + { + return true; + } + + bool visit(class AstTypePack* node) override + { + return true; + } +}; + +void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) +{ + GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg}; + + if (prepareModuleScope) + prepareModuleScope(module->name, globalScope); + + program->visit(&gp); +} + +bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) +{ + DenseHashSet seen{nullptr}; + VecDeque queue; + + queue.push_back(ty); + + bool incremented = false; + + while (!queue.empty()) + { + const TypeId t = follow(queue.front()); + queue.pop_front(); + + if (seen.find(t)) + continue; + seen.insert(t); + + if (auto tt = getMutable(t); tt && tt->state == TableState::Unsealed) + { + tt->remainingProps += 1; + incremented = true; + } + else if (auto mt = get(t)) + queue.push_back(mt->table); + else if (TypeIds* localDomain = localTypes.find(t)) + { + for (TypeId domainTy : *localDomain) + queue.push_back(domainTy); + } + else if (auto ut = get(t)) + { + for (TypeId part : ut) + queue.push_back(part); + } + } + + return incremented; +} + +void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty) +{ + if (InferredBinding* ib = inferredBindings.find(local)) + ib->types.insert(ty); +} + +void ConstraintGenerator::fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block) +{ + for (const auto& [symbol, p] : inferredBindings) + { + const auto& [scope, location, types] = p; + + std::vector tys(types.begin(), types.end()); + if (tys.size() == 1) + scope->bindings[symbol] = Binding{tys.front(), location}; + else + { + TypeId ty = createTypeFunctionInstance(builtinTypeFunctions().unionFunc, std::move(tys), {}, globalScope, location); + + scope->bindings[symbol] = Binding{ty, location}; + } + } +} + +std::vector> ConstraintGenerator::getExpectedCallTypesForFunctionOverloads(const TypeId fnType) +{ + std::vector funTys; + if (auto it = get(follow(fnType))) + { + for (TypeId intersectionComponent : it) + { + funTys.push_back(intersectionComponent); + } + } + + std::vector> expectedTypes; + // For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n, + // emit a list of arguments that the function could take at each position + // by unioning the arguments at each place + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) + { + if (index == expectedTypes.size()) + expectedTypes.push_back(ty); + else if (ty) + { + auto& el = expectedTypes[index]; + + if (!el) + el = ty; + else + { + std::vector result = reduceUnion({*el, ty}); + if (result.empty()) + el = builtinTypes->neverType; + else if (result.size() == 1) + el = result[0]; + else + el = module->internalTypes.addType(UnionType{std::move(result)}); + } + } + }; + + for (const TypeId overload : funTys) + { + if (const FunctionType* ftv = get(follow(overload))) + { + auto [argsHead, argsTail] = flatten(ftv->argTypes); + size_t start = ftv->hasSelf ? 1 : 0; + size_t index = 0; + for (size_t i = start; i < argsHead.size(); ++i) + assignOption(index++, argsHead[i]); + if (argsTail) + { + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) + { + while (index < funTys.size()) + assignOption(index++, vtp->ty); + } + } + } + } + + // TODO vvijay Feb 24, 2023 apparently we have to demote the types here? + + return expectedTypes; +} + +TypeId ConstraintGenerator::createTypeFunctionInstance( + const TypeFunction& function, + std::vector typeArguments, + std::vector packArguments, + const ScopePtr& scope, + Location location +) +{ + TypeId result = arena->addTypeFunction(function, typeArguments, packArguments); + addConstraint(scope, location, ReduceConstraint{result}); + return result; +} + +std::vector> borrowConstraints(const std::vector& constraints) +{ + std::vector> result; + result.reserve(constraints.size()); + + for (const auto& c : constraints) + result.emplace_back(c.get()); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp deleted file mode 100644 index 9ee2b0882..000000000 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ /dev/null @@ -1,2630 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/ConstraintGraphBuilder.h" - -#include "Luau/Ast.h" -#include "Luau/Breadcrumb.h" -#include "Luau/Common.h" -#include "Luau/Constraint.h" -#include "Luau/DcrLogger.h" -#include "Luau/ModuleResolver.h" -#include "Luau/RecursionCounter.h" -#include "Luau/Refinement.h" -#include "Luau/Scope.h" -#include "Luau/TypeUtils.h" -#include "Luau/Type.h" - -#include - -LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTFLAG(DebugLuauMagicTypes); -LUAU_FASTFLAG(LuauNegatedClassTypes); -LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); - -namespace Luau -{ - -const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp - -static std::optional matchRequire(const AstExprCall& call) -{ - const char* require = "require"; - - if (call.args.size != 1) - return std::nullopt; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != require) - return std::nullopt; - - if (call.args.size != 1) - return std::nullopt; - - return call.args.data[0]; -} - -static bool matchSetmetatable(const AstExprCall& call) -{ - const char* smt = "setmetatable"; - - if (call.args.size != 2) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != smt) - return false; - - return true; -} - -struct TypeGuard -{ - bool isTypeof; - AstExpr* target; - std::string type; -}; - -static std::optional matchTypeGuard(const AstExprBinary* binary) -{ - if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) - return std::nullopt; - - AstExpr* left = binary->left; - AstExpr* right = binary->right; - if (right->is()) - std::swap(left, right); - - if (!right->is()) - return std::nullopt; - - AstExprCall* call = left->as(); - AstExprConstantString* string = right->as(); - if (!call || !string) - return std::nullopt; - - AstExprGlobal* callee = call->func->as(); - if (!callee) - return std::nullopt; - - if (callee->name != "type" && callee->name != "typeof") - return std::nullopt; - - if (call->args.size != 1) - return std::nullopt; - - return TypeGuard{ - /*isTypeof*/ callee->name == "typeof", - /*target*/ call->args.data[0], - /*type*/ std::string(string->value.data, string->value.size), - }; -} - -static bool matchAssert(const AstExprCall& call) -{ - if (call.args.size < 1) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != "assert") - return false; - - return true; -} - -namespace -{ - -struct Checkpoint -{ - size_t offset; -}; - -Checkpoint checkpoint(const ConstraintGraphBuilder* cgb) -{ - return Checkpoint{cgb->constraints.size()}; -} - -template -void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGraphBuilder* cgb, F f) -{ - for (size_t i = start.offset; i < end.offset; ++i) - f(cgb->constraints[i]); -} - -} // namespace - -ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, - NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, - DcrLogger* logger, NotNull dfg) - : moduleName(moduleName) - , module(module) - , builtinTypes(builtinTypes) - , arena(arena) - , rootScope(nullptr) - , dfg(dfg) - , moduleResolver(moduleResolver) - , ice(ice) - , globalScope(globalScope) - , logger(logger) -{ - LUAU_ASSERT(module); -} - -TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) -{ - return arena->addType(FreeType{scope.get()}); -} - -TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope) -{ - FreeTypePack f{scope.get()}; - return arena->addTypePack(TypePackVar{std::move(f)}); -} - -ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& parent) -{ - auto scope = std::make_shared(parent); - scopes.emplace_back(node->location, scope); - - scope->returnType = parent->returnType; - scope->varargPack = parent->varargPack; - - parent->children.push_back(NotNull{scope.get()}); - module->astScopes[node] = scope.get(); - - return scope; -} - -NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) -{ - return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; -} - -NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) -{ - return NotNull{constraints.emplace_back(std::move(c)).get()}; -} - -struct RefinementPartition -{ - // Types that we want to intersect against the type of the expression. - std::vector discriminantTypes; - - // Sometimes the type we're discriminating against is implicitly nil. - bool shouldAppendNilType = false; -}; - -using RefinementContext = std::unordered_map; - -static void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, NotNull arena) -{ - for (auto& [def, partition] : lhs) - { - auto rhsIt = rhs.find(def); - if (rhsIt == rhs.end()) - continue; - - LUAU_ASSERT(!partition.discriminantTypes.empty()); - LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); - - TypeId leftDiscriminantTy = - partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : arena->addType(IntersectionType{partition.discriminantTypes}); - - TypeId rightDiscriminantTy = rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] - : arena->addType(IntersectionType{rhsIt->second.discriminantTypes}); - - dest[def].discriminantTypes.push_back(arena->addType(UnionType{{leftDiscriminantTy, rightDiscriminantTy}})); - dest[def].shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; - } -} - -static void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, NotNull arena, bool eq, - std::vector* constraints) -{ - if (!refinement) - return; - else if (auto variadic = get(refinement)) - { - for (RefinementId refi : variadic->refinements) - computeRefinement(scope, refi, refis, sense, arena, eq, constraints); - } - else if (auto negation = get(refinement)) - return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); - else if (auto conjunction = get(refinement)) - { - RefinementContext lhsRefis; - RefinementContext rhsRefis; - - computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); - computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); - - if (!sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); - } - else if (auto disjunction = get(refinement)) - { - RefinementContext lhsRefis; - RefinementContext rhsRefis; - - computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); - computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); - - if (sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); - } - else if (auto equivalence = get(refinement)) - { - computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); - computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); - } - else if (auto proposition = get(refinement)) - { - TypeId discriminantTy = proposition->discriminantTy; - if (!sense && !eq) - discriminantTy = arena->addType(NegationType{proposition->discriminantTy}); - else if (eq) - { - discriminantTy = arena->addType(BlockedType{}); - constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); - } - - RefinementContext uncommittedRefis; - uncommittedRefis[proposition->breadcrumb->def].discriminantTypes.push_back(discriminantTy); - - // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. - if ((sense || !eq) && getMetadata(proposition->breadcrumb)) - uncommittedRefis[proposition->breadcrumb->def].shouldAppendNilType = true; - - for (NullableBreadcrumbId current = proposition->breadcrumb; current && current->previous; current = current->previous) - { - LUAU_ASSERT(get(current->def)); - - // If this current breadcrumb has no metadata, it's no-op for the purpose of building a discriminant type. - if (!current->metadata) - continue; - else if (auto field = getMetadata(current)) - { - TableType::Props props{{field->prop, Property{discriminantTy}}}; - discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); - uncommittedRefis[current->previous->def].discriminantTypes.push_back(discriminantTy); - } - } - - // And now it's time to commit it. - for (auto& [def, partition] : uncommittedRefis) - { - for (TypeId discriminantTy : partition.discriminantTypes) - (*refis)[def].discriminantTypes.push_back(discriminantTy); - - (*refis)[def].shouldAppendNilType |= partition.shouldAppendNilType; - } - } -} - -void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) -{ - if (!refinement) - return; - - RefinementContext refinements; - std::vector constraints; - computeRefinement(scope, refinement, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); - - for (auto& [def, partition] : refinements) - { - if (std::optional defTy = scope->lookup(def)) - { - TypeId ty = *defTy; - if (partition.shouldAppendNilType) - ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); - - partition.discriminantTypes.push_back(ty); - scope->dcrRefinements[def] = arena->addType(IntersectionType{std::move(partition.discriminantTypes)}); - } - } - - for (auto& c : constraints) - addConstraint(scope, location, c); -} - -void ConstraintGraphBuilder::visit(AstStatBlock* block) -{ - LUAU_ASSERT(scopes.empty()); - LUAU_ASSERT(rootScope == nullptr); - ScopePtr scope = std::make_shared(globalScope); - rootScope = scope.get(); - scopes.emplace_back(block->location, scope); - module->astScopes[block] = NotNull{scope.get()}; - - rootScope->returnType = freshTypePack(scope); - - prepopulateGlobalScope(scope, block); - - visitBlockWithoutChildScope(scope, block); - - if (logger) - logger->captureGenerationModule(module); -} - -void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) -{ - RecursionCounter counter{&recursionCount}; - - if (recursionCount >= FInt::LuauCheckRecursionLimit) - { - reportCodeTooComplex(block->location); - return; - } - - std::unordered_map aliasDefinitionLocations; - - // In order to enable mutually-recursive type aliases, we need to - // populate the type bindings before we actually check any of the - // alias statements. - for (AstStat* stat : block->body) - { - if (auto alias = stat->as()) - { - if (scope->exportedTypeBindings.count(alias->name.value) || scope->privateTypeBindings.count(alias->name.value)) - { - auto it = aliasDefinitionLocations.find(alias->name.value); - LUAU_ASSERT(it != aliasDefinitionLocations.end()); - reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second}); - continue; - } - - ScopePtr defnScope = childScope(alias, scope); - - TypeId initialType = arena->addType(BlockedType{}); - TypeFun initialFun{initialType}; - - for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) - { - initialFun.typeParams.push_back(gen); - } - - for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) - { - initialFun.typePackParams.push_back(genPack); - } - - if (alias->exported) - scope->exportedTypeBindings[alias->name.value] = std::move(initialFun); - else - scope->privateTypeBindings[alias->name.value] = std::move(initialFun); - - astTypeAliasDefiningScopes[alias] = defnScope; - aliasDefinitionLocations[alias->name.value] = alias->location; - } - } - - for (AstStat* stat : block->body) - visit(scope, stat); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) -{ - RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; - - if (auto s = stat->as()) - visit(scope, s); - else if (auto i = stat->as()) - visit(scope, i); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else if (stat->is() || stat->is()) - { - // Nothing - } - else if (auto r = stat->as()) - visit(scope, r); - else if (auto e = stat->as()) - checkPack(scope, e->expr); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto a = stat->as()) - visit(scope, a); - else if (auto a = stat->as()) - visit(scope, a); - else if (auto f = stat->as()) - visit(scope, f); - else if (auto f = stat->as()) - visit(scope, f); - else if (auto a = stat->as()) - visit(scope, a); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else if (auto s = stat->as()) - visit(scope, s); - else - LUAU_ASSERT(0 && "Internal error: Unknown AstStat type"); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) -{ - std::vector varTypes; - varTypes.reserve(local->vars.size); - - // Used to name the first value type, even if it's not placed in varTypes, - // for the purpose of synthetic name attribution. - std::optional firstValueType; - - for (AstLocal* local : local->vars) - { - TypeId ty = nullptr; - - if (local->annotation) - ty = resolveType(scope, local->annotation, /* inTypeArguments */ false); - - varTypes.push_back(ty); - } - - for (size_t i = 0; i < local->values.size; ++i) - { - AstExpr* value = local->values.data[i]; - const bool hasAnnotation = i < local->vars.size && nullptr != local->vars.data[i]->annotation; - - if (value->is()) - { - // HACK: we leave nil-initialized things floating under the - // assumption that they will later be populated. - // - // See the test TypeInfer/infer_locals_with_nil_value. Better flow - // awareness should make this obsolete. - - if (!varTypes[i]) - varTypes[i] = freshType(scope); - } - // Only function calls and vararg expressions can produce packs. All - // other expressions produce exactly one value. - else if (i != local->values.size - 1 || (!value->is() && !value->is())) - { - std::optional expectedType; - if (hasAnnotation) - expectedType = varTypes.at(i); - - TypeId exprType = check(scope, value, expectedType).ty; - if (i < varTypes.size()) - { - if (varTypes[i]) - addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]}); - else - varTypes[i] = exprType; - } - - if (i == 0) - firstValueType = exprType; - } - else - { - std::vector> expectedTypes; - if (hasAnnotation) - expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); - - TypePackId exprPack = checkPack(scope, value, expectedTypes).tp; - - if (i < local->vars.size) - { - TypePack packTypes = extendTypePack(*arena, builtinTypes, exprPack, varTypes.size() - i); - - // fill out missing values in varTypes with values from exprPack - for (size_t j = i; j < varTypes.size(); ++j) - { - if (!varTypes[j]) - { - if (j - i < packTypes.head.size()) - varTypes[j] = packTypes.head[j - i]; - else - varTypes[j] = arena->addType(BlockedType{}); - } - } - - std::vector tailValues{varTypes.begin() + i, varTypes.end()}; - TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, local->location, UnpackConstraint{tailPack, exprPack}); - } - } - } - - if (local->vars.size == 1 && local->values.size == 1 && firstValueType) - { - AstLocal* var = local->vars.data[0]; - AstExpr* value = local->values.data[0]; - - if (value->is()) - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - else if (const AstExprCall* call = value->as()) - { - if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") - { - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - } - } - } - - for (size_t i = 0; i < local->vars.size; ++i) - { - AstLocal* l = local->vars.data[i]; - Location location = l->location; - - if (!varTypes[i]) - varTypes[i] = freshType(scope); - - scope->bindings[l] = Binding{varTypes[i], location}; - - // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but - // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. - BreadcrumbId bc = dfg->getBreadcrumb(l); - scope->dcrRefinements[bc->def] = varTypes[i]; - } - - if (local->values.size > 0) - { - // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. - for (size_t i = 0; i < local->values.size && i < local->vars.size; ++i) - { - const AstExprCall* call = local->values.data[i]->as(); - if (!call) - continue; - - if (auto maybeRequire = matchRequire(*call)) - { - AstExpr* require = *maybeRequire; - - if (auto moduleInfo = moduleResolver->resolveModuleInfo(moduleName, *require)) - { - const Name name{local->vars.data[i]->name.value}; - - if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) - { - scope->importedTypeBindings[name] = module->exportedTypeBindings; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->importedModules[name] = moduleName; - } - } - } - } - } -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) -{ - TypeId annotationTy = builtinTypes->numberType; - if (for_->var->annotation) - annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); - - auto inferNumber = [&](AstExpr* expr) { - if (!expr) - return; - - TypeId t = check(scope, expr).ty; - addConstraint(scope, expr->location, SubtypeConstraint{t, builtinTypes->numberType}); - }; - - inferNumber(for_->from); - inferNumber(for_->to); - inferNumber(for_->step); - - ScopePtr forScope = childScope(for_, scope); - forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(for_->var); - forScope->dcrRefinements[bc->def] = annotationTy; - - visit(forScope, for_->body); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) -{ - ScopePtr loopScope = childScope(forIn, scope); - - TypePackId iterator = checkPack(scope, forIn->values).tp; - - std::vector variableTypes; - variableTypes.reserve(forIn->vars.size); - for (AstLocal* var : forIn->vars) - { - TypeId ty = freshType(loopScope); - loopScope->bindings[var] = Binding{ty, var->location}; - variableTypes.push_back(ty); - - BreadcrumbId bc = dfg->getBreadcrumb(var); - loopScope->dcrRefinements[bc->def] = ty; - } - - // It is always ok to provide too few variables, so we give this pack a free tail. - TypePackId variablePack = arena->addTypePack(std::move(variableTypes), arena->addTypePack(FreeTypePack{loopScope.get()})); - - addConstraint(loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack}); - - visit(loopScope, forIn->body); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) -{ - check(scope, while_->condition); - - ScopePtr whileScope = childScope(while_, scope); - - visit(whileScope, while_->body); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) -{ - ScopePtr repeatScope = childScope(repeat, scope); - - visitBlockWithoutChildScope(repeatScope, repeat->body); - - check(repeatScope, repeat->condition); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) -{ - // Local - // Global - // Dotted path - // Self? - - TypeId functionType = nullptr; - auto ty = scope->lookup(function->name); - LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - - functionType = arena->addType(BlockedType{}); - scope->bindings[function->name] = Binding{functionType, function->name->location}; - - FunctionSignature sig = checkFunctionSignature(scope, function->func); - sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(function->name); - scope->dcrRefinements[bc->def] = functionType; - sig.bodyScope->dcrRefinements[bc->def] = sig.signature; - - Checkpoint start = checkpoint(this); - checkFunctionBody(sig.bodyScope, function->func); - Checkpoint end = checkpoint(this); - - NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = - std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - - forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { - c->dependencies.push_back(NotNull{constraint.get()}); - }); - - addConstraint(scope, std::move(c)); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) -{ - // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. - // With or without self - - TypeId generalizedType = arena->addType(BlockedType{}); - - Checkpoint start = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, function->func); - - std::unordered_set excludeList; - - if (AstExprLocal* localName = function->name->as()) - { - std::optional existingFunctionTy = scope->lookup(localName->local); - if (existingFunctionTy) - { - addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); - - Symbol sym{localName->local}; - scope->bindings[sym].typeId = generalizedType; - } - else - scope->bindings[localName->local] = Binding{generalizedType, localName->location}; - - sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; - } - else if (AstExprGlobal* globalName = function->name->as()) - { - std::optional existingFunctionTy = scope->lookup(globalName->name); - if (!existingFunctionTy) - ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); - - generalizedType = *existingFunctionTy; - - sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; - } - else if (AstExprIndexName* indexName = function->name->as()) - { - Checkpoint check1 = checkpoint(this); - TypeId lvalueType = checkLValue(scope, indexName); - Checkpoint check2 = checkpoint(this); - - forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { - excludeList.insert(c.get()); - }); - - // TODO figure out how to populate the location field of the table Property. - - if (get(lvalueType)) - asMutable(lvalueType)->ty.emplace(generalizedType); - else - addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); - } - else if (AstExprError* err = function->name->as()) - { - generalizedType = builtinTypes->errorRecoveryType(); - } - - if (generalizedType == nullptr) - ice->ice("generalizedType == nullptr", function->location); - - if (NullableBreadcrumbId bc = dfg->getBreadcrumb(function->name)) - scope->dcrRefinements[bc->def] = generalizedType; - - checkFunctionBody(sig.bodyScope, function->func); - Checkpoint end = checkpoint(this); - - NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = - std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - - forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) { - if (!excludeList.count(constraint.get())) - c->dependencies.push_back(NotNull{constraint.get()}); - }); - - addConstraint(scope, std::move(c)); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) -{ - // At this point, the only way scope->returnType should have anything - // interesting in it is if the function has an explicit return annotation. - // If this is the case, then we can expect that the return expression - // conforms to that. - std::vector> expectedTypes; - for (TypeId ty : scope->returnType) - expectedTypes.push_back(ty); - - TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; - addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) -{ - ScopePtr innerScope = childScope(block, scope); - - visitBlockWithoutChildScope(innerScope, block); -} - -static void bindFreeType(TypeId a, TypeId b) -{ - FreeType* af = getMutable(a); - FreeType* bf = getMutable(b); - - LUAU_ASSERT(af || bf); - - if (!bf) - asMutable(a)->ty.emplace(b); - else if (!af) - asMutable(b)->ty.emplace(a); - else if (subsumes(bf->scope, af->scope)) - asMutable(a)->ty.emplace(b); - else if (subsumes(af->scope, bf->scope)) - asMutable(b)->ty.emplace(a); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) -{ - std::vector varTypes = checkLValues(scope, assign->vars); - - std::vector> expectedTypes; - expectedTypes.reserve(varTypes.size()); - - for (TypeId ty : varTypes) - { - ty = follow(ty); - if (get(ty)) - expectedTypes.push_back(std::nullopt); - else - expectedTypes.push_back(ty); - } - - TypePackId exprPack = checkPack(scope, assign->values, expectedTypes).tp; - TypePackId varPack = arena->addTypePack({varTypes}); - - addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack}); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) -{ - // We need to tweak the BinaryConstraint that we emit, so we cannot use the - // strategy of falsifying an AST fragment. - TypeId varTy = checkLValue(scope, assign->var); - TypeId valueTy = check(scope, assign->value).ty; - - TypeId resultType = arena->addType(BlockedType{}); - addConstraint(scope, assign->location, - BinaryConstraint{assign->op, varTy, valueTy, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); - addConstraint(scope, assign->location, SubtypeConstraint{resultType, varTy}); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) -{ - ScopePtr condScope = childScope(ifStatement->condition, scope); - RefinementId refinement = check(condScope, ifStatement->condition, std::nullopt).refinement; - - ScopePtr thenScope = childScope(ifStatement->thenbody, scope); - applyRefinements(thenScope, ifStatement->condition->location, refinement); - visit(thenScope, ifStatement->thenbody); - - if (ifStatement->elsebody) - { - ScopePtr elseScope = childScope(ifStatement->elsebody, scope); - applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); - visit(elseScope, ifStatement->elsebody); - } -} - -static bool occursCheck(TypeId needle, TypeId haystack) -{ - LUAU_ASSERT(get(needle)); - haystack = follow(haystack); - - auto checkHaystack = [needle](TypeId haystack) { - return occursCheck(needle, haystack); - }; - - if (needle == haystack) - return true; - else if (auto ut = get(haystack)) - return std::any_of(begin(ut), end(ut), checkHaystack); - else if (auto it = get(haystack)) - return std::any_of(begin(it), end(it), checkHaystack); - - return false; -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) -{ - ScopePtr* defnScope = astTypeAliasDefiningScopes.find(alias); - - std::unordered_map* typeBindings; - if (alias->exported) - typeBindings = &scope->exportedTypeBindings; - else - typeBindings = &scope->privateTypeBindings; - - // These will be undefined if the alias was a duplicate definition, in which - // case we just skip over it. - auto bindingIt = typeBindings->find(alias->name.value); - if (bindingIt == typeBindings->end() || defnScope == nullptr) - return; - - TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false); - - TypeId aliasTy = bindingIt->second.type; - LUAU_ASSERT(get(aliasTy)); - - if (occursCheck(aliasTy, ty)) - { - asMutable(aliasTy)->ty.emplace(builtinTypes->anyType); - reportError(alias->nameLocation, OccursCheckFailed{}); - } - else - asMutable(aliasTy)->ty.emplace(ty); - - std::vector typeParams; - for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false)) - typeParams.push_back(tyParam.second.ty); - - std::vector typePackParams; - for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) - typePackParams.push_back(tpParam.second.tp); - - addConstraint(scope, alias->type->location, - NameConstraint{ - ty, - alias->name.value, - /*synthetic=*/false, - std::move(typeParams), - std::move(typePackParams), - }); -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) -{ - LUAU_ASSERT(global->type); - - TypeId globalTy = resolveType(scope, global->type, /* inTypeArguments */ false); - Name globalName(global->name.value); - - module->declaredGlobals[globalName] = globalTy; - rootScope->bindings[global->name] = Binding{globalTy, global->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(global); - rootScope->dcrRefinements[bc->def] = globalTy; -} - -static bool isMetamethod(const Name& name) -{ - return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || - name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) -{ - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; - if (declaredClass->superName) - { - Name superName = Name(declaredClass->superName->value); - std::optional lookupType = scope->lookupType(superName); - - if (!lookupType) - { - reportError(declaredClass->location, UnknownSymbol{superName, UnknownSymbol::Type}); - return; - } - - // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); - superTy = lookupType->type; - - if (!get(follow(*superTy))) - { - reportError(declaredClass->location, - GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}); - - return; - } - } - - Name className(declaredClass->name.value); - - TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, moduleName)); - ClassType* ctv = getMutable(classTy); - - TypeId metaTy = arena->addType(TableType{TableState::Sealed, scope->level, scope.get()}); - TableType* metatable = getMutable(metaTy); - - ctv->metatable = metaTy; - - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; - - for (const AstDeclaredClassProp& prop : declaredClass->props) - { - Name propName(prop.name.value); - TypeId propTy = resolveType(scope, prop.ty, /* inTypeArguments */ false); - - bool assignToMetatable = isMetamethod(propName); - - // Function types always take 'self', but this isn't reflected in the - // parsed annotation. Add it here. - if (prop.isMethod) - { - if (FunctionType* ftv = getMutable(propTy)) - { - ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); - ftv->argTypes = arena->addTypePack(TypePack{{classTy}, ftv->argTypes}); - - ftv->hasSelf = true; - } - } - - if (ctv->props.count(propName) == 0) - { - if (assignToMetatable) - metatable->props[propName] = {propTy}; - else - ctv->props[propName] = {propTy}; - } - else - { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; - - // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. - if (const IntersectionType* itv = get(currentTy)) - { - std::vector options = itv->parts; - options.push_back(propTy); - TypeId newItv = arena->addType(IntersectionType{std::move(options)}); - - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; - } - else if (get(currentTy)) - { - TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); - - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; - } - else - { - reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); - } - } - } -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) -{ - std::vector> generics = createGenerics(scope, global->generics); - std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); - - std::vector genericTys; - genericTys.reserve(generics.size()); - for (auto& [name, generic] : generics) - { - genericTys.push_back(generic.ty); - } - - std::vector genericTps; - genericTps.reserve(genericPacks.size()); - for (auto& [name, generic] : genericPacks) - { - genericTps.push_back(generic.tp); - } - - ScopePtr funScope = scope; - if (!generics.empty() || !genericPacks.empty()) - funScope = childScope(global, scope); - - TypePackId paramPack = resolveTypePack(funScope, global->params, /* inTypeArguments */ false); - TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); - TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); - FunctionType* ftv = getMutable(fnType); - - ftv->argNames.reserve(global->paramNames.size); - for (const auto& el : global->paramNames) - ftv->argNames.push_back(FunctionArgument{el.first.value, el.second}); - - Name fnName(global->name.value); - - module->declaredGlobals[fnName] = fnType; - scope->bindings[global->name] = Binding{fnType, global->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(global); - rootScope->dcrRefinements[bc->def] = fnType; -} - -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) -{ - for (AstStat* stat : error->statements) - visit(scope, stat); - for (AstExpr* expr : error->expressions) - check(scope, expr); -} - -InferencePack ConstraintGraphBuilder::checkPack( - const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes) -{ - std::vector head; - std::optional tail; - - for (size_t i = 0; i < exprs.size; ++i) - { - AstExpr* expr = exprs.data[i]; - if (i < exprs.size - 1) - { - std::optional expectedType; - if (i < expectedTypes.size()) - expectedType = expectedTypes[i]; - head.push_back(check(scope, expr, expectedType).ty); - } - else - { - std::vector> expectedTailTypes; - if (i < expectedTypes.size()) - expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); - tail = checkPack(scope, expr, expectedTailTypes).tp; - } - } - - if (head.empty() && tail) - return InferencePack{*tail}; - else - return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})}; -} - -InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes) -{ - RecursionCounter counter{&recursionCount}; - - if (recursionCount >= FInt::LuauCheckRecursionLimit) - { - reportCodeTooComplex(expr->location); - return InferencePack{builtinTypes->errorRecoveryTypePack()}; - } - - InferencePack result; - - if (AstExprCall* call = expr->as()) - result = checkPack(scope, call); - else if (AstExprVarargs* varargs = expr->as()) - { - if (scope->varargPack) - result = InferencePack{*scope->varargPack}; - else - result = InferencePack{builtinTypes->errorRecoveryTypePack()}; - } - else - { - std::optional expectedType; - if (!expectedTypes.empty()) - expectedType = expectedTypes[0]; - TypeId t = check(scope, expr, expectedType).ty; - result = InferencePack{arena->addTypePack({t})}; - } - - LUAU_ASSERT(result.tp); - module->astTypePacks[expr] = result.tp; - return result; -} - -InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call) -{ - std::vector exprArgs; - - std::vector returnRefinements; - std::vector> discriminantTypes; - - if (call->self) - { - AstExprIndexName* indexExpr = call->func->as(); - if (!indexExpr) - ice->ice("method call expression has no 'self'"); - - exprArgs.push_back(indexExpr->expr); - - if (auto bc = dfg->getBreadcrumb(indexExpr->expr)) - { - TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); - discriminantTypes.push_back(discriminantTy); - } - else - discriminantTypes.push_back(std::nullopt); - } - - for (AstExpr* arg : call->args) - { - exprArgs.push_back(arg); - - if (auto bc = dfg->getBreadcrumb(arg)) - { - TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); - discriminantTypes.push_back(discriminantTy); - } - else - discriminantTypes.push_back(std::nullopt); - } - - Checkpoint startCheckpoint = checkpoint(this); - TypeId fnType = check(scope, call->func).ty; - Checkpoint fnEndCheckpoint = checkpoint(this); - - std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); - - module->astOriginalCallTypes[call->func] = fnType; - - TypePackId expectedArgPack = arena->freshTypePack(scope.get()); - TypePackId expectedRetPack = arena->freshTypePack(scope.get()); - TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack, std::nullopt, call->self}); - - TypeId instantiatedFnType = arena->addType(BlockedType{}); - addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); - - NotNull extractArgsConstraint = addConstraint(scope, call->location, SubtypeConstraint{instantiatedFnType, expectedFunctionType}); - - // Fully solve fnType, then extract its argument list as expectedArgPack. - forEachConstraint(startCheckpoint, fnEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { - extractArgsConstraint->dependencies.emplace_back(constraint.get()); - }); - - const AstExpr* lastArg = exprArgs.size() ? exprArgs[exprArgs.size() - 1] : nullptr; - const bool needTail = lastArg && (lastArg->is() || lastArg->is()); - - TypePack expectedArgs; - - if (!needTail) - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size(), expectedTypesForCall); - else - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size() - 1, expectedTypesForCall); - - std::vector args; - std::optional argTail; - std::vector argumentRefinements; - - Checkpoint argCheckpoint = checkpoint(this); - - for (size_t i = 0; i < exprArgs.size(); ++i) - { - AstExpr* arg = exprArgs[i]; - std::optional expectedType; - if (i < expectedArgs.head.size()) - expectedType = expectedArgs.head[i]; - - if (i == 0 && call->self) - { - // The self type has already been computed as a side effect of - // computing fnType. If computing that did not cause us to exceed a - // recursion limit, we can fetch it from astTypes rather than - // recomputing it. - TypeId* selfTy = module->astTypes.find(exprArgs[0]); - if (selfTy) - args.push_back(*selfTy); - else - args.push_back(arena->freshType(scope.get())); - } - else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) - { - auto [ty, refinement] = check(scope, arg, expectedType); - args.push_back(ty); - argumentRefinements.push_back(refinement); - } - else - { - auto [tp, refis] = checkPack(scope, arg, {}); - argTail = tp; - argumentRefinements.insert(argumentRefinements.end(), refis.begin(), refis.end()); - } - } - - Checkpoint argEndCheckpoint = checkpoint(this); - - // Do not solve argument constraints until after we have extracted the - // expected types from the callable. - forEachConstraint(argCheckpoint, argEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { - constraint->dependencies.push_back(extractArgsConstraint); - }); - - if (matchSetmetatable(*call)) - { - TypePack argTailPack; - if (argTail && args.size() < 2) - argTailPack = extendTypePack(*arena, builtinTypes, *argTail, 2 - args.size()); - - LUAU_ASSERT(args.size() + argTailPack.head.size() == 2); - - TypeId target = args.size() > 0 ? args[0] : argTailPack.head[0]; - TypeId mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; - - AstExpr* targetExpr = call->args.data[0]; - - MetatableType mtv{target, mt}; - TypeId resultTy = arena->addType(mtv); - - if (AstExprLocal* targetLocal = targetExpr->as()) - { - scope->bindings[targetLocal->local].typeId = resultTy; - - BreadcrumbId bc = dfg->getBreadcrumb(targetLocal); - scope->dcrRefinements[bc->def] = resultTy; // TODO: typestates: track this as an assignment - } - - return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; - } - else - { - if (matchAssert(*call) && !argumentRefinements.empty()) - applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); - - // TODO: How do expectedTypes play into this? Do they? - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); - FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); - - NotNull fcc = addConstraint(scope, call->func->location, - FunctionCallConstraint{ - fnType, - argPack, - rets, - call, - std::move(discriminantTypes), - }); - - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - forEachConstraint(fnEndCheckpoint, argEndCheckpoint, this, [fcc](const ConstraintPtr& constraint) { - fcc->dependencies.emplace_back(constraint.get()); - }); - - return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; - } -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) -{ - RecursionCounter counter{&recursionCount}; - - if (recursionCount >= FInt::LuauCheckRecursionLimit) - { - reportCodeTooComplex(expr->location); - return Inference{builtinTypes->errorRecoveryType()}; - } - - Inference result; - - if (auto group = expr->as()) - result = check(scope, group->expr, expectedType, forceSingleton); - else if (auto stringExpr = expr->as()) - result = check(scope, stringExpr, expectedType, forceSingleton); - else if (expr->is()) - result = Inference{builtinTypes->numberType}; - else if (auto boolExpr = expr->as()) - result = check(scope, boolExpr, expectedType, forceSingleton); - else if (expr->is()) - result = Inference{builtinTypes->nilType}; - else if (auto local = expr->as()) - result = check(scope, local); - else if (auto global = expr->as()) - result = check(scope, global); - else if (expr->is()) - result = flattenPack(scope, expr->location, checkPack(scope, expr)); - else if (auto call = expr->as()) - result = flattenPack(scope, expr->location, checkPack(scope, call)); // TODO: needs predicates too - else if (auto a = expr->as()) - { - Checkpoint startCheckpoint = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, a, expectedType); - checkFunctionBody(sig.bodyScope, a); - Checkpoint endCheckpoint = checkpoint(this); - - TypeId generalizedTy = arena->addType(BlockedType{}); - NotNull gc = addConstraint(scope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); - - forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { - gc->dependencies.emplace_back(constraint.get()); - }); - - result = Inference{generalizedTy}; - } - else if (auto indexName = expr->as()) - result = check(scope, indexName); - else if (auto indexExpr = expr->as()) - result = check(scope, indexExpr); - else if (auto table = expr->as()) - result = check(scope, table, expectedType); - else if (auto unary = expr->as()) - result = check(scope, unary); - else if (auto binary = expr->as()) - result = check(scope, binary, expectedType); - else if (auto ifElse = expr->as()) - result = check(scope, ifElse, expectedType); - else if (auto typeAssert = expr->as()) - result = check(scope, typeAssert); - else if (auto interpString = expr->as()) - result = check(scope, interpString); - else if (auto err = expr->as()) - { - // Open question: Should we traverse into this? - for (AstExpr* subExpr : err->expressions) - check(scope, subExpr); - - result = Inference{builtinTypes->errorRecoveryType()}; - } - else - { - LUAU_ASSERT(0); - result = Inference{freshType(scope)}; - } - - LUAU_ASSERT(result.ty); - module->astTypes[expr] = result.ty; - if (expectedType) - module->astExpectedTypes[expr] = *expectedType; - return result; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) -{ - if (forceSingleton) - return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - TypeId singletonType = arena->addType(SingletonType(StringSingleton{std::string(string->value.data, string->value.size)})); - addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->stringType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - - return Inference{builtinTypes->stringType}; - } - - return Inference{builtinTypes->stringType}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) -{ - const TypeId singletonType = boolExpr->value ? builtinTypes->trueType : builtinTypes->falseType; - if (forceSingleton) - return Inference{singletonType}; - - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - - if (get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->booleanType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{singletonType}; - - return Inference{builtinTypes->booleanType}; - } - - return Inference{builtinTypes->booleanType}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) -{ - BreadcrumbId bc = dfg->getBreadcrumb(local); - - if (auto ty = scope->lookup(bc->def)) - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - else if (auto ty = scope->lookup(local->local)) - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - else - ice->ice("AstExprLocal came before its declaration?"); -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) -{ - BreadcrumbId bc = dfg->getBreadcrumb(global); - - /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any - * global that is not already in-scope is definitely an unknown symbol. - */ - if (auto ty = scope->lookup(bc->def)) - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - else if (auto ty = scope->lookup(global->name)) - { - rootScope->dcrRefinements[bc->def] = *ty; - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - } - else - { - reportError(global->location, UnknownSymbol{global->name.value}); - return Inference{builtinTypes->errorRecoveryType()}; - } -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) -{ - TypeId obj = check(scope, indexName->expr).ty; - TypeId result = arena->addType(BlockedType{}); - - NullableBreadcrumbId bc = dfg->getBreadcrumb(indexName); - if (bc) - { - if (auto ty = scope->lookup(bc->def)) - return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - - scope->dcrRefinements[bc->def] = result; - } - - addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); - - if (bc) - return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - else - return Inference{result}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) -{ - TypeId obj = check(scope, indexExpr->expr).ty; - TypeId indexType = check(scope, indexExpr->index).ty; - TypeId result = freshType(scope); - - NullableBreadcrumbId bc = dfg->getBreadcrumb(indexExpr); - if (bc) - { - if (auto ty = scope->lookup(bc->def)) - return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - - scope->dcrRefinements[bc->def] = result; - } - - TableIndexer indexer{indexType, result}; - TypeId tableType = arena->addType(TableType{TableType::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); - - addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); - - if (bc) - return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - else - return Inference{result}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) -{ - auto [operandType, refinement] = check(scope, unary->expr); - TypeId resultType = arena->addType(BlockedType{}); - addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - - if (unary->op == AstExprUnary::Not) - return Inference{resultType, refinementArena.negation(refinement)}; - else - return Inference{resultType}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) -{ - auto [leftType, rightType, refinement] = checkBinary(scope, binary, expectedType); - - TypeId resultType = arena->addType(BlockedType{}); - addConstraint(scope, binary->location, - BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); - return Inference{resultType, std::move(refinement)}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) -{ - ScopePtr condScope = childScope(ifElse->condition, scope); - RefinementId refinement = check(condScope, ifElse->condition).refinement; - - ScopePtr thenScope = childScope(ifElse->trueExpr, scope); - applyRefinements(thenScope, ifElse->trueExpr->location, refinement); - TypeId thenType = check(thenScope, ifElse->trueExpr, expectedType).ty; - - ScopePtr elseScope = childScope(ifElse->falseExpr, scope); - applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); - TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; - - return Inference{expectedType ? *expectedType : arena->addType(UnionType{{thenType, elseType}})}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) -{ - check(scope, typeAssert->expr, std::nullopt); - return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprInterpString* interpString) -{ - for (AstExpr* expr : interpString->expressions) - check(scope, expr); - - return Inference{builtinTypes->stringType}; -} - -std::tuple ConstraintGraphBuilder::checkBinary( - const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) -{ - if (binary->op == AstExprBinary::And) - { - auto [leftType, leftRefinement] = check(scope, binary->left, expectedType); - - ScopePtr rightScope = childScope(binary->right, scope); - applyRefinements(rightScope, binary->right->location, leftRefinement); - auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); - - return {leftType, rightType, refinementArena.conjunction(leftRefinement, rightRefinement)}; - } - else if (binary->op == AstExprBinary::Or) - { - auto [leftType, leftRefinement] = check(scope, binary->left, expectedType); - - ScopePtr rightScope = childScope(binary->right, scope); - applyRefinements(rightScope, binary->right->location, refinementArena.negation(leftRefinement)); - auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); - - return {leftType, rightType, refinementArena.disjunction(leftRefinement, rightRefinement)}; - } - else if (auto typeguard = matchTypeGuard(binary)) - { - TypeId leftType = check(scope, binary->left).ty; - TypeId rightType = check(scope, binary->right).ty; - - NullableBreadcrumbId bc = dfg->getBreadcrumb(typeguard->target); - if (!bc) - return {leftType, rightType, nullptr}; - - TypeId discriminantTy = builtinTypes->neverType; - if (typeguard->type == "nil") - discriminantTy = builtinTypes->nilType; - else if (typeguard->type == "string") - discriminantTy = builtinTypes->stringType; - else if (typeguard->type == "number") - discriminantTy = builtinTypes->numberType; - else if (typeguard->type == "boolean") - discriminantTy = builtinTypes->booleanType; - else if (typeguard->type == "thread") - discriminantTy = builtinTypes->threadType; - else if (typeguard->type == "table") - discriminantTy = builtinTypes->tableType; - else if (typeguard->type == "function") - discriminantTy = builtinTypes->functionType; - else if (typeguard->type == "userdata") - { - // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. - discriminantTy = builtinTypes->classType; - } - else if (!typeguard->isTypeof && typeguard->type == "vector") - discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type - else if (!typeguard->isTypeof) - discriminantTy = builtinTypes->neverType; - else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) - { - TypeId ty = follow(typeFun->type); - - // We're only interested in the root class of any classes. - if (auto ctv = get(ty); !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent == builtinTypes->classType) : !ctv->parent)) - discriminantTy = ty; - } - - RefinementId proposition = refinementArena.proposition(NotNull{bc}, discriminantTy); - if (binary->op == AstExprBinary::CompareEq) - return {leftType, rightType, proposition}; - else if (binary->op == AstExprBinary::CompareNe) - return {leftType, rightType, refinementArena.negation(proposition)}; - else - ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); - } - else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) - { - TypeId leftType = check(scope, binary->left, expectedType, true).ty; - TypeId rightType = check(scope, binary->right, expectedType, true).ty; - - RefinementId leftRefinement = nullptr; - if (auto bc = dfg->getBreadcrumb(binary->left)) - leftRefinement = refinementArena.proposition(NotNull{bc}, rightType); - - RefinementId rightRefinement = nullptr; - if (auto bc = dfg->getBreadcrumb(binary->right)) - rightRefinement = refinementArena.proposition(NotNull{bc}, leftType); - - if (binary->op == AstExprBinary::CompareNe) - { - leftRefinement = refinementArena.negation(leftRefinement); - rightRefinement = refinementArena.negation(rightRefinement); - } - - return {leftType, rightType, refinementArena.equivalence(leftRefinement, rightRefinement)}; - } - else - { - TypeId leftType = check(scope, binary->left, expectedType).ty; - TypeId rightType = check(scope, binary->right, expectedType).ty; - return {leftType, rightType, nullptr}; - } -} - -std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) -{ - std::vector types; - types.reserve(exprs.size); - - for (AstExpr* expr : exprs) - types.push_back(checkLValue(scope, expr)); - - return types; -} - -static bool isIndexNameEquivalent(AstExpr* expr) -{ - if (expr->is()) - return true; - - AstExprIndexExpr* e = expr->as(); - if (e == nullptr) - return false; - - if (!e->index->is()) - return false; - - return true; -} - -/** - * This function is mostly about identifying properties that are being inserted into unsealed tables. - * - * If expr has the form name.a.b.c - */ -TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) -{ - if (auto indexExpr = expr->as(); indexExpr && !indexExpr->index->is()) - { - // An indexer is only interesting in an lvalue-ey way if it is at the - // tail of an expression. - // - // If the indexer is not at the tail, then we are not interested in - // augmenting the lhs data structure with a new indexer. Constraint - // generation can treat it as an ordinary lvalue. - // - // eg - // - // a.b.c[1] = 44 -- lvalue - // a.b[4].c = 2 -- rvalue - - TypeId resultType = arena->addType(BlockedType{}); - TypeId subjectType = check(scope, indexExpr->expr).ty; - TypeId indexType = check(scope, indexExpr->index).ty; - TypeId propType = arena->addType(BlockedType{}); - addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, propType}); - - module->astTypes[expr] = propType; - - return propType; - } - else if (!isIndexNameEquivalent(expr)) - return check(scope, expr).ty; - - Symbol sym; - std::vector segments; - std::vector exprs; - - AstExpr* e = expr; - while (e) - { - if (auto global = e->as()) - { - sym = global->name; - break; - } - else if (auto local = e->as()) - { - sym = local->local; - break; - } - else if (auto indexName = e->as()) - { - segments.push_back(indexName->index.value); - exprs.push_back(e); - e = indexName->expr; - } - else if (auto indexExpr = e->as()) - { - if (auto strIndex = indexExpr->index->as()) - { - segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); - exprs.push_back(e); - e = indexExpr->expr; - } - else - { - return check(scope, expr).ty; - } - } - else - return check(scope, expr).ty; - } - - LUAU_ASSERT(!segments.empty()); - - std::reverse(begin(segments), end(segments)); - std::reverse(begin(exprs), end(exprs)); - - auto lookupResult = scope->lookupEx(sym); - if (!lookupResult) - return check(scope, expr).ty; - const auto [subjectBinding, symbolScope] = std::move(*lookupResult); - TypeId subjectType = subjectBinding->typeId; - - TypeId propTy = freshType(scope); - - std::vector segmentStrings(begin(segments), end(segments)); - - TypeId updatedType = arena->addType(BlockedType{}); - addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), propTy}); - - TypeId prevSegmentTy = updatedType; - for (size_t i = 0; i < segments.size(); ++i) - { - TypeId segmentTy = arena->addType(BlockedType{}); - module->astTypes[exprs[i]] = segmentTy; - addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i]}); - prevSegmentTy = segmentTy; - } - - module->astTypes[expr] = prevSegmentTy; - module->astTypes[e] = updatedType; - - if (!subjectType->persistent) - { - symbolScope->bindings[sym].typeId = updatedType; - - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. - if (auto bc = dfg->getBreadcrumb(e)) - symbolScope->dcrRefinements[bc->def] = updatedType; - } - - return propTy; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) -{ - TypeId ty = arena->addType(TableType{}); - TableType* ttv = getMutable(ty); - LUAU_ASSERT(ttv); - - ttv->state = TableState::Unsealed; - ttv->scope = scope.get(); - - auto createIndexer = [this, scope, ttv](const Location& location, TypeId currentIndexType, TypeId currentResultType) { - if (!ttv->indexer) - { - TypeId indexType = this->freshType(scope); - TypeId resultType = this->freshType(scope); - ttv->indexer = TableIndexer{indexType, resultType}; - } - - addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); - addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); - }; - - std::optional annotatedKeyType; - std::optional annotatedIndexResultType; - - if (expectedType) - { - if (const TableType* ttv = get(follow(*expectedType))) - { - if (ttv->indexer) - { - annotatedKeyType.emplace(follow(ttv->indexer->indexType)); - annotatedIndexResultType.emplace(ttv->indexer->indexResultType); - } - } - } - - bool isIndexedResultType = false; - std::optional pinnedIndexResultType; - - - for (const AstExprTable::Item& item : expr->items) - { - std::optional expectedValueType; - if (item.kind == AstExprTable::Item::Kind::General || item.kind == AstExprTable::Item::Kind::List) - isIndexedResultType = true; - - if (item.key && expectedType) - { - if (auto stringKey = item.key->as()) - { - ErrorVec errorVec; - std::optional propTy = - findTablePropertyRespectingMeta(builtinTypes, errorVec, follow(*expectedType), stringKey->value.data, item.value->location); - if (propTy) - expectedValueType = propTy; - else - { - expectedValueType = arena->addType(BlockedType{}); - addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); - } - } - } - - - // We'll resolve the expected index result type here with the following priority: - // 1. Record table types - in which key, value pairs must be handled on a k,v pair basis. - // In this case, the above if-statement will populate expectedValueType - // 2. Someone places an annotation on a General or List table - // Trust the annotation and have the solver inform them if they get it wrong - // 3. Someone omits the annotation on a general or List table - // Use the type of the first indexResultType as the expected type - std::optional checkExpectedIndexResultType; - if (expectedValueType) - { - checkExpectedIndexResultType = expectedValueType; - } - else if (annotatedIndexResultType) - { - checkExpectedIndexResultType = annotatedIndexResultType; - } - else if (pinnedIndexResultType) - { - checkExpectedIndexResultType = pinnedIndexResultType; - } - - TypeId itemTy = check(scope, item.value, checkExpectedIndexResultType).ty; - - if (isIndexedResultType && !pinnedIndexResultType) - pinnedIndexResultType = itemTy; - - if (item.key) - { - // Even though we don't need to use the type of the item's key if - // it's a string constant, we still want to check it to populate - // astTypes. - TypeId keyTy = check(scope, item.key, annotatedKeyType).ty; - - if (AstExprConstantString* key = item.key->as()) - { - ttv->props[key->value.begin()] = {itemTy}; - } - else - { - createIndexer(item.key->location, keyTy, itemTy); - } - } - else - { - TypeId numberType = builtinTypes->numberType; - // FIXME? The location isn't quite right here. Not sure what is - // right. - createIndexer(item.value->location, numberType, itemTy); - } - } - - return Inference{ty}; -} - -ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature( - const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType) -{ - ScopePtr signatureScope = nullptr; - ScopePtr bodyScope = nullptr; - TypePackId returnType = nullptr; - - std::vector genericTypes; - std::vector genericTypePacks; - - if (expectedType) - expectedType = follow(*expectedType); - - bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; - - signatureScope = childScope(fn, parent); - - // We need to assign returnType before creating bodyScope so that the - // return type gets propogated to bodyScope. - returnType = freshTypePack(signatureScope); - signatureScope->returnType = returnType; - - bodyScope = childScope(fn->body, signatureScope); - - if (hasGenerics) - { - std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); - std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); - - // We do not support default values on function generics, so we only - // care about the types involved. - for (const auto& [name, g] : genericDefinitions) - { - genericTypes.push_back(g.ty); - } - - for (const auto& [name, g] : genericPackDefinitions) - { - genericTypePacks.push_back(g.tp); - } - - // Local variable works around an odd gcc 11.3 warning: may be used uninitialized - std::optional none = std::nullopt; - expectedType = none; - } - - std::vector argTypes; - std::vector> argNames; - TypePack expectedArgPack; - - const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; - - if (expectedFunction) - { - expectedArgPack = extendTypePack(*arena, builtinTypes, expectedFunction->argTypes, fn->args.size); - - genericTypes = expectedFunction->generics; - genericTypePacks = expectedFunction->genericPacks; - } - - if (fn->self) - { - TypeId selfType = freshType(signatureScope); - argTypes.push_back(selfType); - argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); - signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(fn->self); - signatureScope->dcrRefinements[bc->def] = selfType; - } - - for (size_t i = 0; i < fn->args.size; ++i) - { - AstLocal* local = fn->args.data[i]; - - TypeId argTy = nullptr; - if (local->annotation) - argTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); - else - { - argTy = freshType(signatureScope); - - if (i < expectedArgPack.head.size()) - addConstraint(signatureScope, local->location, SubtypeConstraint{argTy, expectedArgPack.head[i]}); - } - - argTypes.push_back(argTy); - argNames.emplace_back(FunctionArgument{local->name.value, local->location}); - signatureScope->bindings[local] = Binding{argTy, local->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(local); - signatureScope->dcrRefinements[bc->def] = argTy; - } - - TypePackId varargPack = nullptr; - - if (fn->vararg) - { - if (fn->varargAnnotation) - { - TypePackId annotationType = - resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh */ true); - varargPack = annotationType; - } - else if (expectedArgPack.tail && get(*expectedArgPack.tail)) - varargPack = *expectedArgPack.tail; - else - varargPack = builtinTypes->anyTypePack; - - signatureScope->varargPack = varargPack; - bodyScope->varargPack = varargPack; - } - else - { - varargPack = arena->addTypePack(VariadicTypePack{builtinTypes->anyType, /*hidden*/ true}); - // We do not add to signatureScope->varargPack because ... is not valid - // in functions without an explicit ellipsis. - - signatureScope->varargPack = std::nullopt; - bodyScope->varargPack = std::nullopt; - } - - LUAU_ASSERT(nullptr != varargPack); - - // If there is both an annotation and an expected type, the annotation wins. - // Type checking will sort out any discrepancies later. - if (fn->returnAnnotation) - { - TypePackId annotatedRetType = - resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); - // We bind the annotated type directly here so that, when we need to - // generate constraints for return types, we have a guarantee that we - // know the annotated return type already, if one was provided. - LUAU_ASSERT(get(returnType)); - asMutable(returnType)->ty.emplace(annotatedRetType); - } - else if (expectedFunction) - { - asMutable(returnType)->ty.emplace(expectedFunction->retTypes); - } - - // TODO: Preserve argument names in the function's type. - - FunctionType actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; - actualFunction.hasNoGenerics = !hasGenerics; - actualFunction.generics = std::move(genericTypes); - actualFunction.genericPacks = std::move(genericTypePacks); - actualFunction.argNames = std::move(argNames); - actualFunction.hasSelf = fn->self != nullptr; - - TypeId actualFunctionType = arena->addType(std::move(actualFunction)); - LUAU_ASSERT(actualFunctionType); - module->astTypes[fn] = actualFunctionType; - - if (expectedType && get(*expectedType)) - bindFreeType(*expectedType, actualFunctionType); - - return { - /* signature */ actualFunctionType, - /* signatureScope */ signatureScope, - /* bodyScope */ bodyScope, - }; -} - -void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn) -{ - visitBlockWithoutChildScope(scope, fn->body); - - // If it is possible for execution to reach the end of the function, the return type must be compatible with () - - if (nullptr != getFallthrough(fn->body)) - { - TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever - addConstraint(scope, fn->location, PackSubtypeConstraint{scope->returnType, empty}); - } -} - -TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) -{ - TypeId result = nullptr; - - if (auto ref = ty->as()) - { - if (FFlag::DebugLuauMagicTypes) - { - if (ref->name == "_luau_ice") - ice->ice("_luau_ice encountered", ty->location); - else if (ref->name == "_luau_print") - { - if (ref->parameters.size != 1 || !ref->parameters.data[0].type) - { - reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); - return builtinTypes->errorRecoveryType(); - } - else - return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); - } - } - - std::optional alias; - - if (ref->prefix.has_value()) - { - alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); - } - else - { - alias = scope->lookupType(ref->name.value); - } - - if (alias.has_value()) - { - // If the alias is not generic, we don't need to set up a blocked - // type and an instantiation constraint. - if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) - { - result = alias->type; - } - else - { - std::vector parameters; - std::vector packParameters; - - for (const AstTypeOrPack& p : ref->parameters) - { - // We do not enforce the ordering of types vs. type packs here; - // that is done in the parser. - if (p.type) - { - parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); - } - else if (p.typePack) - { - packParameters.push_back(resolveTypePack(scope, p.typePack, /* inTypeArguments */ true)); - } - else - { - // This indicates a parser bug: one of these two pointers - // should be set. - LUAU_ASSERT(false); - } - } - - result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); - - // If we're not in a type argument context, we need to create a constraint that expands this. - // The dispatching of the above constraint will queue up additional constraints for nested - // type function applications. - if (!inTypeArguments) - addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); - } - } - else - { - result = builtinTypes->errorRecoveryType(); - if (replaceErrorWithFresh) - result = freshType(scope); - } - } - else if (auto tab = ty->as()) - { - TableType::Props props; - std::optional indexer; - - for (const AstTableProp& prop : tab->props) - { - std::string name = prop.name.value; - // TODO: Recursion limit. - TypeId propTy = resolveType(scope, prop.type, inTypeArguments); - // TODO: Fill in location. - props[name] = {propTy}; - } - - if (tab->indexer) - { - // TODO: Recursion limit. - indexer = TableIndexer{ - resolveType(scope, tab->indexer->indexType, inTypeArguments), - resolveType(scope, tab->indexer->resultType, inTypeArguments), - }; - } - - result = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); - } - else if (auto fn = ty->as()) - { - // TODO: Recursion limit. - bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; - ScopePtr signatureScope = nullptr; - - std::vector genericTypes; - std::vector genericTypePacks; - - // If we don't have generics, we do not need to generate a child scope - // for the generic bindings to live on. - if (hasGenerics) - { - signatureScope = childScope(fn, scope); - - std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); - std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); - - for (const auto& [name, g] : genericDefinitions) - { - genericTypes.push_back(g.ty); - } - - for (const auto& [name, g] : genericPackDefinitions) - { - genericTypePacks.push_back(g.tp); - } - } - else - { - // To eliminate the need to branch on hasGenerics below, we say that - // the signature scope is the parent scope if we don't have - // generics. - signatureScope = scope; - } - - TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); - TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); - - // TODO: FunctionType needs a pointer to the scope so that we know - // how to quantify/instantiate it. - FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - - // This replicates the behavior of the appropriate FunctionType - // constructors. - ftv.hasNoGenerics = !hasGenerics; - ftv.generics = std::move(genericTypes); - ftv.genericPacks = std::move(genericTypePacks); - - ftv.argNames.reserve(fn->argNames.size); - for (const auto& el : fn->argNames) - { - if (el) - { - const auto& [name, location] = *el; - ftv.argNames.push_back(FunctionArgument{name.value, location}); - } - else - { - ftv.argNames.push_back(std::nullopt); - } - } - - result = arena->addType(std::move(ftv)); - } - else if (auto tof = ty->as()) - { - // TODO: Recursion limit. - TypeId exprType = check(scope, tof->expr).ty; - result = exprType; - } - else if (auto unionAnnotation = ty->as()) - { - std::vector parts; - for (AstType* part : unionAnnotation->types) - { - // TODO: Recursion limit. - parts.push_back(resolveType(scope, part, inTypeArguments)); - } - - result = arena->addType(UnionType{parts}); - } - else if (auto intersectionAnnotation = ty->as()) - { - std::vector parts; - for (AstType* part : intersectionAnnotation->types) - { - // TODO: Recursion limit. - parts.push_back(resolveType(scope, part, inTypeArguments)); - } - - result = arena->addType(IntersectionType{parts}); - } - else if (auto boolAnnotation = ty->as()) - { - result = arena->addType(SingletonType(BooleanSingleton{boolAnnotation->value})); - } - else if (auto stringAnnotation = ty->as()) - { - result = arena->addType(SingletonType(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); - } - else if (ty->is()) - { - result = builtinTypes->errorRecoveryType(); - if (replaceErrorWithFresh) - result = freshType(scope); - } - else - { - LUAU_ASSERT(0); - result = builtinTypes->errorRecoveryType(); - } - - module->astResolvedTypes[ty] = result; - return result; -} - -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) -{ - TypePackId result; - if (auto expl = tp->as()) - { - result = resolveTypePack(scope, expl->typeList, inTypeArgument, replaceErrorWithFresh); - } - else if (auto var = tp->as()) - { - TypeId ty = resolveType(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); - result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); - } - else if (auto gen = tp->as()) - { - if (std::optional lookup = scope->lookupPack(gen->genericName.value)) - { - result = *lookup; - } - else - { - reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); - result = builtinTypes->errorRecoveryTypePack(); - } - } - else - { - LUAU_ASSERT(0); - result = builtinTypes->errorRecoveryTypePack(); - } - - module->astResolvedTypePacks[tp] = result; - return result; -} - -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh) -{ - std::vector head; - - for (AstType* headTy : list.types) - { - head.push_back(resolveType(scope, headTy, inTypeArguments, replaceErrorWithFresh)); - } - - std::optional tail = std::nullopt; - if (list.tailType) - { - tail = resolveTypePack(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); - } - - return arena->addTypePack(TypePack{head, tail}); -} - -std::vector> ConstraintGraphBuilder::createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) -{ - std::vector> result; - for (const auto& generic : generics) - { - TypeId genericTy = nullptr; - - if (auto it = scope->parent->typeAliasTypeParameters.find(generic.name.value); useCache && it != scope->parent->typeAliasTypeParameters.end()) - genericTy = it->second; - else - { - genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); - scope->parent->typeAliasTypeParameters[generic.name.value] = genericTy; - } - - std::optional defaultTy = std::nullopt; - - if (generic.defaultValue) - defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); - - if (addTypes) - scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; - - result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); - } - - return result; -} - -std::vector> ConstraintGraphBuilder::createGenericPacks( - const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) -{ - std::vector> result; - for (const auto& generic : generics) - { - TypePackId genericTy; - - if (auto it = scope->parent->typeAliasTypePackParameters.find(generic.name.value); - useCache && it != scope->parent->typeAliasTypePackParameters.end()) - genericTy = it->second; - else - { - genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); - scope->parent->typeAliasTypePackParameters[generic.name.value] = genericTy; - } - - std::optional defaultTy = std::nullopt; - - if (generic.defaultValue) - defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); - - if (addTypes) - scope->privateTypePackBindings[generic.name.value] = genericTy; - - result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); - } - - return result; -} - -Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) -{ - const auto& [tp, refinements] = pack; - RefinementId refinement = nullptr; - if (!refinements.empty()) - refinement = refinements[0]; - - if (auto f = first(tp)) - return Inference{*f, refinement}; - - TypeId typeResult = arena->addType(BlockedType{}); - TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); - addConstraint(scope, location, UnpackConstraint{resultPack, tp}); - - return Inference{typeResult, refinement}; -} - -void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) -{ - errors.push_back(TypeError{location, moduleName, std::move(err)}); - - if (logger) - logger->captureGenerationError(errors.back()); -} - -void ConstraintGraphBuilder::reportCodeTooComplex(Location location) -{ - errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); - - if (logger) - logger->captureGenerationError(errors.back()); -} - -struct GlobalPrepopulator : AstVisitor -{ - const NotNull globalScope; - const NotNull arena; - - GlobalPrepopulator(NotNull globalScope, NotNull arena) - : globalScope(globalScope) - , arena(arena) - { - } - - bool visit(AstStatFunction* function) override - { - if (AstExprGlobal* g = function->name->as()) - globalScope->bindings[g->name] = Binding{arena->addType(BlockedType{})}; - - return true; - } -}; - -void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) -{ - GlobalPrepopulator gp{NotNull{globalScope.get()}, arena}; - - program->visit(&gp); -} - -std::vector> ConstraintGraphBuilder::getExpectedCallTypesForFunctionOverloads(const TypeId fnType) -{ - std::vector funTys; - if (auto it = get(follow(fnType))) - { - for (TypeId intersectionComponent : it) - { - funTys.push_back(intersectionComponent); - } - } - - std::vector> expectedTypes; - // For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n, - // emit a list of arguments that the function could take at each position - // by unioning the arguments at each place - auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { - if (index == expectedTypes.size()) - expectedTypes.push_back(ty); - else if (ty) - { - auto& el = expectedTypes[index]; - - if (!el) - el = ty; - else - { - std::vector result = reduceUnion({*el, ty}); - if (result.empty()) - el = builtinTypes->neverType; - else if (result.size() == 1) - el = result[0]; - else - el = module->internalTypes.addType(UnionType{std::move(result)}); - } - } - }; - - for (const TypeId overload : funTys) - { - if (const FunctionType* ftv = get(follow(overload))) - { - auto [argsHead, argsTail] = flatten(ftv->argTypes); - size_t start = ftv->hasSelf ? 1 : 0; - size_t index = 0; - for (size_t i = start; i < argsHead.size(); ++i) - assignOption(index++, argsHead[i]); - if (argsTail) - { - argsTail = follow(*argsTail); - if (const VariadicTypePack* vtp = get(*argsTail)) - { - while (index < funTys.size()) - assignOption(index++, vtp->ty); - } - } - } - } - - // TODO vvijay Feb 24, 2023 apparently we have to demote the types here? - - return expectedTypes; -} - -std::vector> borrowConstraints(const std::vector& constraints) -{ - std::vector> result; - result.reserve(constraints.size()); - - for (const auto& c : constraints) - result.emplace_back(c.get()); - - return result; -} - -} // namespace Luau diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 3cb4e4e7e..31afabb23 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1,22 +1,38 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintSolver.h" #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" -#include "Luau/Clone.h" -#include "Luau/ConstraintSolver.h" +#include "Luau/Common.h" #include "Luau/DcrLogger.h" +#include "Luau/Generalization.h" #include "Luau/Instantiation.h" +#include "Luau/Instantiation2.h" #include "Luau/Location.h" -#include "Luau/Metamethods.h" #include "Luau/ModuleResolver.h" +#include "Luau/OverloadResolution.h" #include "Luau/Quantify.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Simplify.h" +#include "Luau/TableLiteralInference.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" -#include "Luau/TypeUtils.h" #include "Luau/Type.h" -#include "Luau/Unifier.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" #include "Luau/VisitType.h" -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +#include +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false) +LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false) +LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack, false) namespace Luau { @@ -49,8 +65,39 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const dumpBindings(child, opts); } -static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull builtinTypes, - const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) +// used only in asserts +[[maybe_unused]] static bool canMutate(TypeId ty, NotNull constraint) +{ + if (auto blocked = get(ty)) + { + Constraint* owner = blocked->getOwner(); + LUAU_ASSERT(owner); + return owner == constraint; + } + + return true; +} + +// used only in asserts +[[maybe_unused]] static bool canMutate(TypePackId tp, NotNull constraint) +{ + if (auto blocked = get(tp)) + { + Constraint* owner = blocked->owner; + LUAU_ASSERT(owner); + return owner == constraint; + } + + return true; +} + +static std::pair, std::vector> saturateArguments( + TypeArena* arena, + NotNull builtinTypes, + const TypeFun& fn, + const std::vector& rawTypeArguments, + const std::vector& rawPackArguments +) { std::vector saturatedTypeArguments; std::vector extraTypes; @@ -70,7 +117,7 @@ static std::pair, std::vector> saturateArguments // mutually exclusive with the type pack -> type conversion we do below: // extraTypes will only have elements in it if we have more types than we // have parameter slots for them to go into. - if (!extraTypes.empty()) + if (!extraTypes.empty() && !fn.typePackParams.empty()) { saturatedPackArguments.push_back(arena->addTypePack(extraTypes)); } @@ -86,7 +133,7 @@ static std::pair, std::vector> saturateArguments { saturatedTypeArguments.push_back(*first(tp)); } - else + else if (saturatedPackArguments.size() < fn.typePackParams.size()) { saturatedPackArguments.push_back(tp); } @@ -174,6 +221,12 @@ static std::pair, std::vector> saturateArguments saturatedPackArguments.push_back(builtinTypes->errorRecoveryTypePack()); } + for (TypeId& arg : saturatedTypeArguments) + arg = follow(arg); + + for (TypePackId& pack : saturatedPackArguments) + pack = follow(pack); + // At this point, these two conditions should be true. If they aren't we // will run into access violations. LUAU_ASSERT(saturatedTypeArguments.size() == fn.typeParams.size()); @@ -222,38 +275,86 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); - for (NotNull dep : c->dependencies) + if (FFlag::DebugLuauLogSolverIncludeDependencies) { - auto unsolvedIter = std::find(begin(cs->unsolvedConstraints), end(cs->unsolvedConstraints), dep); - if (unsolvedIter == cs->unsolvedConstraints.end()) - continue; - - auto it = cs->blockedConstraints.find(dep); - int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); - printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); + for (NotNull dep : c->dependencies) + { + if (std::find(cs->unsolvedConstraints.begin(), cs->unsolvedConstraints.end(), dep) != cs->unsolvedConstraints.end()) + printf("\t\t|\t%s\n", toString(*dep, opts).c_str()); + } } } } -ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, - DcrLogger* logger) +struct InstantiationQueuer : TypeOnceVisitor +{ + ConstraintSolver* solver; + NotNull scope; + Location location; + + explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver) + : solver(solver) + , scope(scope) + , location(location) + { + } + + bool visit(TypeId ty, const PendingExpansionType& petv) override + { + solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); + return false; + } + + bool visit(TypeId ty, const TypeFunctionInstanceType&) override + { + solver->pushConstraint(scope, location, ReduceConstraint{ty}); + return true; + } + + bool visit(TypeId ty, const ClassType& ctv) override + { + return false; + } +}; + +ConstraintSolver::ConstraintSolver( + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull rootScope, + std::vector> constraints, + ModuleName moduleName, + NotNull moduleResolver, + std::vector requireCycles, + DcrLogger* logger, + NotNull dfg, + TypeCheckLimits limits +) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) - , reducer(reducer) + , typeFunctionRuntime(typeFunctionRuntime) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) + , dfg(dfg) , moduleResolver(moduleResolver) - , requireCycles(requireCycles) + , requireCycles(std::move(requireCycles)) , logger(logger) + , limits(std::move(limits)) { opts.exhaustive = true; for (NotNull c : this->constraints) { - unsolvedConstraints.push_back(c); + unsolvedConstraints.emplace_back(c); + + // initialize the reference counts for the free types in this constraint. + for (auto ty : c->getMaybeMutatedFreeTypes()) + { + // increment the reference count for `ty` + auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); + refCount += 1; + } for (NotNull dep : c->dependencies) { @@ -284,12 +385,16 @@ void ConstraintSolver::randomize(unsigned seed) void ConstraintSolver::run() { + LUAU_TIMETRACE_SCOPE("ConstraintSolver::run", "Typechecking"); + if (isDone()) return; if (FFlag::DebugLuauLogSolver) { - printf("Starting solver\n"); + printf( + "Starting solver for module %s (%s)\n", moduleResolver->getHumanReadableModuleName(currentModuleName).c_str(), currentModuleName.c_str() + ); dump(this, opts); printf("Bindings:\n"); dumpBindings(rootScope, opts); @@ -300,7 +405,8 @@ void ConstraintSolver::run() logger->captureInitialSolverState(rootScope, unsolvedConstraints); } - auto runSolverPass = [&](bool force) { + auto runSolverPass = [&](bool force) + { bool progress = false; size_t i = 0; @@ -313,6 +419,11 @@ void ConstraintSolver::run() continue; } + if (limits.finishTime && TimeTrace::getClock() > *limits.finishTime) + throwTimeLimitError(); + if (limits.cancellationToken && limits.cancellationToken->requested()) + throwUserCancelError(); + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; StepSnapshot snapshot; @@ -330,6 +441,22 @@ void ConstraintSolver::run() unblock(c); unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + // decrement the referenced free types for this constraint if we dispatched successfully! + for (auto ty : c->getMaybeMutatedFreeTypes()) + { + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; + + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); + } + if (logger) { logger->commitStepSnapshot(snapshot); @@ -382,12 +509,17 @@ void ConstraintSolver::run() progress |= runSolverPass(true); } while (progress); - finalizeModule(); + if (!unsolvedConstraints.empty()) + reportError(ConstraintSolvingIncompleteError{}, Location{}); - if (FFlag::DebugLuauLogSolver) - { + // After we have run all the constraints, type functions should be generalized + // At this point, we can try to perform one final simplification to suss out + // whether type functions are truly uninhabited or if they can reduce + + finalizeTypeFunctions(); + + if (FFlag::DebugLuauLogSolver || FFlag::DebugLuauLogBindings) dumpBindings(rootScope, opts); - } if (logger) { @@ -395,22 +527,89 @@ void ConstraintSolver::run() } } +void ConstraintSolver::finalizeTypeFunctions() +{ + // At this point, we've generalized. Let's try to finish reducing as much as we can, we'll leave warning to the typechecker + for (auto [t, constraint] : typeFunctionsToFinalize) + { + TypeId ty = follow(t); + if (get(ty)) + { + FunctionGraphReductionResult result = + reduceTypeFunctions(t, constraint->location, TypeFunctionContext{NotNull{this}, constraint->scope, NotNull{constraint}}, true); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + } + } +} + bool ConstraintSolver::isDone() { return unsolvedConstraints.empty(); } -void ConstraintSolver::finalizeModule() +namespace { - Anyification a{arena, rootScope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; - std::optional returnType = a.substitute(rootScope->returnType); - if (!returnType) - { - reportError(CodeTooComplex{}, Location{}); - rootScope->returnType = builtinTypes->errorTypePack; - } - else - rootScope->returnType = *returnType; + +struct TypeAndLocation +{ + TypeId typeId; + Location location; +}; + +} // namespace + +void ConstraintSolver::bind(NotNull constraint, TypeId ty, TypeId boundTo) +{ + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + boundTo = follow(boundTo); + if (get(ty) && ty == boundTo) + return emplace(constraint, ty, constraint->scope, builtinTypes->neverType, builtinTypes->unknownType); + + shiftReferences(ty, boundTo); + emplaceType(asMutable(ty), boundTo); + unblock(ty, constraint->location); +} + +void ConstraintSolver::bind(NotNull constraint, TypePackId tp, TypePackId boundTo) +{ + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + boundTo = follow(boundTo); + LUAU_ASSERT(tp != boundTo); + + emplaceTypePack(asMutable(tp), boundTo); + unblock(tp, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypeId ty, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`"); + + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + emplaceType(asMutable(ty), std::forward(args)...); + unblock(ty, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypePackId tp, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`"); + + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + emplaceTypePack(asMutable(tp), std::forward(args)...); + unblock(tp, constraint->location); } bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) @@ -421,17 +620,11 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo bool success = false; if (auto sc = get(*constraint)) - success = tryDispatch(*sc, constraint, force); + success = tryDispatch(*sc, constraint); else if (auto psc = get(*constraint)) - success = tryDispatch(*psc, constraint, force); + success = tryDispatch(*psc, constraint); else if (auto gc = get(*constraint)) - success = tryDispatch(*gc, constraint, force); - else if (auto ic = get(*constraint)) - success = tryDispatch(*ic, constraint, force); - else if (auto uc = get(*constraint)) - success = tryDispatch(*uc, constraint, force); - else if (auto bc = get(*constraint)) - success = tryDispatch(*bc, constraint, force); + success = tryDispatch(*gc, constraint); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); else if (auto nc = get(*constraint)) @@ -440,384 +633,94 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*taec, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint, force); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint, force); - else if (auto sottc = get(*constraint)) - success = tryDispatch(*sottc, constraint); + else if (auto spc = get(*constraint)) + success = tryDispatch(*spc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); + else if (auto rc = get(*constraint)) + success = tryDispatch(*rc, constraint, force); + else if (auto rpc = get(*constraint)) + success = tryDispatch(*rpc, constraint, force); + else if (auto eqc = get(*constraint)) + success = tryDispatch(*eqc, constraint); else LUAU_ASSERT(false); - if (success) - unblock(constraint); - return success; } -bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint) { if (isBlocked(c.subType)) return block(c.subType, constraint); else if (isBlocked(c.superType)) return block(c.superType, constraint); - Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; - u.useScopes = true; - - u.tryUnify(c.subType, c.superType); - - if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) - { - for (TypeId bt : u.blockedTypes) - block(bt, constraint); - for (TypePackId btp : u.blockedTypePacks) - block(btp, constraint); - return false; - } - - if (const auto& e = hasUnificationTooComplex(u.errors)) - reportError(*e); - - if (!u.errors.empty()) - { - TypeId errorType = errorRecoveryType(); - u.tryUnify(c.subType, errorType); - u.tryUnify(c.superType, errorType); - } - - const auto [changedTypes, changedPacks] = u.log.getChanges(); - - u.log.commit(); - - unblock(changedTypes); - unblock(changedPacks); - - // unify(c.subType, c.superType, constraint->scope); + unify(constraint, c.subType, c.superType); return true; } -bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint) { if (isBlocked(c.subPack)) return block(c.subPack, constraint); else if (isBlocked(c.superPack)) return block(c.superPack, constraint); - unify(c.subPack, c.superPack, constraint->scope); + unify(constraint, c.subPack, c.superPack); return true; } -bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint) { + TypeId generalizedType = follow(c.generalizedType); + if (isBlocked(c.sourceType)) return block(c.sourceType, constraint); + else if (get(generalizedType)) + return block(generalizedType, constraint); - TypeId generalized = quantify(arena, c.sourceType, constraint->scope); - - if (isBlocked(c.generalizedType)) - asMutable(c.generalizedType)->ty.emplace(generalized); - else - unify(c.generalizedType, generalized, constraint->scope); - - unblock(c.generalizedType); - unblock(c.sourceType); - - return true; -} - -bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force) -{ - if (isBlocked(c.superType)) - return block(c.superType, constraint); - - Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + std::optional generalized; - std::optional instantiated = inst.substitute(c.superType); - LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - - if (isBlocked(c.subType)) - asMutable(c.subType)->ty.emplace(*instantiated); + std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, c.sourceType); + if (generalizedTy) + generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and insertedGenericPacks else - unify(c.subType, *instantiated, constraint->scope); - - unblock(c.subType); - - return true; -} - -bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force) -{ - TypeId operandType = follow(c.operandType); + reportError(CodeTooComplex{}, constraint->location); - if (isBlocked(operandType)) - return block(operandType, constraint); - - if (get(operandType)) - return block(operandType, constraint); - - LUAU_ASSERT(get(c.resultType)); - - switch (c.op) - { - case AstExprUnary::Not: + if (generalized) { - asMutable(c.resultType)->ty.emplace(builtinTypes->booleanType); - return true; - } - case AstExprUnary::Len: - { - // __len must return a number. - asMutable(c.resultType)->ty.emplace(builtinTypes->numberType); - return true; - } - case AstExprUnary::Minus: - { - if (isNumber(operandType) || get(operandType) || get(operandType) || get(operandType)) - { - asMutable(c.resultType)->ty.emplace(c.operandType); - } - else if (std::optional mm = findMetatableEntry(builtinTypes, errors, operandType, "__unm", constraint->location)) - { - TypeId mmTy = follow(*mm); - - if (get(mmTy) && !force) - return block(mmTy, constraint); - - TypePackId argPack = arena->addTypePack(TypePack{{operandType}, {}}); - TypePackId retPack = arena->addTypePack(BlockedTypePack{}); - - asMutable(c.resultType)->ty.emplace(constraint->scope); - - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{retPack, arena->addTypePack(TypePack{{c.resultType}})}); - - pushConstraint(constraint->scope, constraint->location, FunctionCallConstraint{mmTy, argPack, retPack, nullptr}); - } + if (get(generalizedType)) + bind(constraint, generalizedType, generalized->result); else - { - asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - } - - return true; - } - } - - LUAU_ASSERT(false); - return false; -} - -bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force) -{ - TypeId leftType = follow(c.leftType); - TypeId rightType = follow(c.rightType); - TypeId resultType = follow(c.resultType); - - bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or; - - /* Compound assignments create constraints of the form - * - * A <: Binary - * - * This constraint is the one that is meant to unblock A, so it doesn't - * make any sense to stop and wait for someone else to do it. - */ - - if (isBlocked(leftType) && leftType != resultType) - return block(c.leftType, constraint); - - if (isBlocked(rightType) && rightType != resultType) - return block(c.rightType, constraint); - - if (!force) - { - // Logical expressions may proceed if the LHS is free. - if (get(leftType) && !isLogical) - return block(leftType, constraint); - } - - // Logical expressions may proceed if the LHS is free. - if (isBlocked(leftType) || (get(leftType) && !isLogical)) - { - asMutable(resultType)->ty.emplace(errorRecoveryType()); - unblock(resultType); - return true; - } - - // Metatables go first, even if there is primitive behavior. - if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end()) - { - // Metatables are not the same. The metamethod will not be invoked. - if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) && - getMetatable(leftType, builtinTypes) != getMetatable(rightType, builtinTypes)) - { - // TODO: Boolean singleton false? The result is _always_ boolean false. - asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); - return true; - } - - std::optional mm; - - // The LHS metatable takes priority over the RHS metatable, where - // present. - if (std::optional leftMm = findMetatableEntry(builtinTypes, errors, leftType, it->second, constraint->location)) - mm = leftMm; - else if (std::optional rightMm = findMetatableEntry(builtinTypes, errors, rightType, it->second, constraint->location)) - mm = rightMm; - - if (mm) - { - Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, constraint->scope}; - std::optional instantiatedMm = instantiation.substitute(*mm); - if (!instantiatedMm) - { - reportError(CodeTooComplex{}, constraint->location); - return true; - } - - // TODO: Is a table with __call legal here? - // TODO: Overloads - if (const FunctionType* ftv = get(follow(*instantiatedMm))) - { - TypePackId inferredArgs; - // For >= and > we invoke __lt and __le respectively with - // swapped argument ordering. - if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt) - { - inferredArgs = arena->addTypePack({rightType, leftType}); - } - else - { - inferredArgs = arena->addTypePack({leftType, rightType}); - } - - unify(inferredArgs, ftv->argTypes, constraint->scope); - - TypeId mmResult; - - // Comparison operations always evaluate to a boolean, - // regardless of what the metamethod returns. - switch (c.op) - { - case AstExprBinary::Op::CompareEq: - case AstExprBinary::Op::CompareNe: - case AstExprBinary::Op::CompareGe: - case AstExprBinary::Op::CompareGt: - case AstExprBinary::Op::CompareLe: - case AstExprBinary::Op::CompareLt: - mmResult = builtinTypes->booleanType; - break; - default: - mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); - } - - asMutable(resultType)->ty.emplace(mmResult); - unblock(resultType); - - (*c.astOriginalCallTypes)[c.astFragment] = *mm; - (*c.astOverloadResolvedTypes)[c.astFragment] = *instantiatedMm; - return true; - } - } - - // If there's no metamethod available, fall back to primitive behavior. - } - - // If any is present, the expression must evaluate to any as well. - bool leftAny = get(leftType) || get(leftType); - bool rightAny = get(rightType) || get(rightType); - bool anyPresent = leftAny || rightAny; - - switch (c.op) - { - // For arithmetic operators, if the LHS is a number, the RHS must be a - // number as well. The result will also be a number. - case AstExprBinary::Op::Add: - case AstExprBinary::Op::Sub: - case AstExprBinary::Op::Mul: - case AstExprBinary::Op::Div: - case AstExprBinary::Op::Pow: - case AstExprBinary::Op::Mod: - if (isNumber(leftType)) - { - unify(leftType, rightType, constraint->scope); - asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); - unblock(resultType); - return true; - } - - break; - // For concatenation, if the LHS is a string, the RHS must be a string as - // well. The result will also be a string. - case AstExprBinary::Op::Concat: - if (isString(leftType)) - { - unify(leftType, rightType, constraint->scope); - asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); - unblock(resultType); - return true; - } - - break; - // Inexact comparisons require that the types be both numbers or both - // strings, and evaluate to a boolean. - case AstExprBinary::Op::CompareGe: - case AstExprBinary::Op::CompareGt: - case AstExprBinary::Op::CompareLe: - case AstExprBinary::Op::CompareLt: - if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType))) - { - asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); - return true; - } + unify(constraint, generalizedType, generalized->result); - break; - // == and ~= always evaluate to a boolean, and impose no other constraints - // on their parameters. - case AstExprBinary::Op::CompareEq: - case AstExprBinary::Op::CompareNe: - asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); - return true; - // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is - // truthy. - case AstExprBinary::Op::And: - { - TypeId leftFilteredTy = arena->addType(IntersectionType{{builtinTypes->falsyType, leftType}}); + for (auto [free, gen] : generalized->insertedGenerics.pairings) + unify(constraint, free, gen); - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); - unblock(resultType); - return true; + for (auto [free, gen] : generalized->insertedGenericPacks.pairings) + unify(constraint, free, gen); } - // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if - // LHS is falsey. - case AstExprBinary::Op::Or: + else { - TypeId leftFilteredTy = arena->addType(IntersectionType{{builtinTypes->truthyType, leftType}}); - - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); - unblock(resultType); - return true; - } - default: - iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); - break; + reportError(CodeTooComplex{}, constraint->location); + bind(constraint, c.generalizedType, builtinTypes->errorRecoveryType()); } - // We failed to either evaluate a metamethod or invoke primitive behavior. - unify(leftType, errorRecoveryType(), constraint->scope); - unify(rightType, errorRecoveryType(), constraint->scope); - asMutable(resultType)->ty.emplace(errorRecoveryType()); - unblock(resultType); + for (TypeId ty : c.interiorTypes) + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); return true; } @@ -841,14 +744,15 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; - std::optional anyified = anyify.substitute(c.variables); - LUAU_ASSERT(anyified); - unify(*anyified, c.variables, constraint->scope); + for (TypeId ty : c.variables) + unify(constraint, builtinTypes->errorRecoveryType(), ty); return true; } - TypeId nextTy = follow(iteratorTypes[0]); + TypeId nextTy = follow(iterator.head[0]); if (get(nextTy)) - return block_(nextTy); + { + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = + arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free}); + + unify(constraint, nextTy, tableTy); + + auto it = begin(c.variables); + auto endIt = end(c.variables); + + if (it != endIt) + { + bind(constraint, *it, keyTy); + ++it; + } + if (it != endIt) + { + bind(constraint, *it, valueTy); + ++it; + } + + while (it != endIt) + { + bind(constraint, *it, builtinTypes->nilType); + ++it; + } + + return true; + } if (get(nextTy)) { TypeId tableTy = builtinTypes->nilType; - if (iteratorTypes.size() >= 2) - tableTy = iteratorTypes[1]; - - TypeId firstIndexTy = builtinTypes->nilType; - if (iteratorTypes.size() >= 3) - firstIndexTy = iteratorTypes[2]; + if (iterator.head.size() >= 2) + tableTy = iterator.head[1]; - return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint); } else - return tryDispatchIterableTable(iteratorTypes[0], c, constraint, force); + return tryDispatchIterableTable(iterator.head[0], c, constraint, force); return true; } @@ -939,8 +866,6 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull scope; - Location location; - - explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver) - : solver(solver) - , scope(scope) - , location(location) - { - } - - bool visit(TypeId ty, const PendingExpansionType& petv) override - { - solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); - return false; - } -}; - bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint) { const PendingExpansionType* petv = get(follow(c.target)); if (!petv) { - unblock(c.target); + unblock(c.target, constraint->location); // TODO: do we need this? any re-entrancy? return true; } - auto bindResult = [this, &c](TypeId result) { - asMutable(c.target)->ty.emplace(result); - unblock(c.target); + auto bindResult = [this, &c, constraint](TypeId result) + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + auto cTarget = follow(c.target); + LUAU_ASSERT(get(cTarget)); + shiftReferences(cTarget, result); + bind(constraint, cTarget, result); + } + else + { + LUAU_ASSERT(get(c.target)); + shiftReferences(c.target, result); + bind(constraint, c.target, result); + } }; std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) @@ -1023,6 +940,10 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } + // Adding ReduceConstraint on type function for the constraint solver + if (auto typeFn = get(follow(tf->type))) + pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type}); + // If there are no parameters to the type function we can just use the type // directly. if (tf->typeParams.empty() && tf->typePackParams.empty()) @@ -1031,16 +952,41 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } - auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); + // Due to how pending expansion types and TypeFun's are created + // If this check passes, we have created a cyclic / corecursive type alias + // of size 0 + TypeId lhs = DFInt::LuauTypeSolverRelease >= 646 ? follow(c.target) : c.target; + TypeId rhs = tf->type; + if (occursCheck(lhs, rhs)) + { + reportError(OccursCheckFailed{}, constraint->location); + bindResult(errorRecoveryType()); + return true; + } - bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { - return itp == p.ty; - }); + auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); - bool samePacks = - std::equal(packArguments.begin(), packArguments.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itp, auto&& p) { + bool sameTypes = std::equal( + typeArguments.begin(), + typeArguments.end(), + tf->typeParams.begin(), + tf->typeParams.end(), + [](auto&& itp, auto&& p) + { + return itp == p.ty; + } + ); + + bool samePacks = std::equal( + packArguments.begin(), + packArguments.end(), + tf->typePackParams.begin(), + tf->typePackParams.end(), + [](auto&& itp, auto&& p) + { return itp == p.tp; - }); + } + ); // If we're instantiating the type with its generic saturatedTypeArguments we are // performing the identity substitution. We can just short-circuit and bind @@ -1069,9 +1015,9 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // In order to prevent infinite types from being expanded and causing us to // cycle infinitely, we need to scan the type function for cases where we // expand the same alias with different type saturatedTypeArguments. See - // https://github.com/Roblox/luau/pull/68 for the RFC responsible for this. - // This is a little nicer than using a recursion limit because we can catch - // the infinite expansion before actually trying to expand it. + // https://github.com/luau-lang/luau/pull/68 for the RFC responsible for + // this. This is a little nicer than using a recursion limit because we can + // catch the infinite expansion before actually trying to expand it. InfiniteTypeFinder itf{this, signature, constraint->scope}; itf.traverse(tf->type); @@ -1115,7 +1061,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul InstantiationQueuer queuer{constraint->scope, constraint->location, this}; queuer.traverse(target); - if (target->persistent) + if (target->persistent || target->owningArena != arena) { bindResult(target); return true; @@ -1123,7 +1069,18 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // Type function application will happily give us the exact same type if // there are e.g. generic saturatedTypeArguments that go unused. - bool needsClone = follow(tf->type) == target; + const TableType* tfTable = getTableType(tf->type); + + bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)) || + std::any_of( + typeArguments.begin(), + typeArguments.end(), + [&](const auto& other) + { + return other == target; + } + ); + // Only tables have the properties we're trying to set. TableType* ttv = getMutableTableType(target); @@ -1170,12 +1127,54 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull std::optional { + if (get(fn)) + { + emplaceTypePack(asMutable(c.result), builtinTypes->anyTypePack); + unblock(c.result, constraint->location); + return true; + } + + // if we're calling an error type, the result is an error type, and that's that. + if (get(fn)) + { + bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); + return true; + } + + if (get(fn)) + { + bind(constraint, c.result, builtinTypes->neverTypePack); + return true; + } + + auto [argsHead, argsTail] = flatten(argsPack); + + bool blocked = false; + for (TypeId t : argsHead) + { + if (isBlocked(t)) + { + block(t, constraint); + blocked = true; + } + } + + if (argsTail && isBlocked(*argsTail)) + { + block(*argsTail, constraint); + blocked = true; + } + + if (blocked) + return false; + + auto collapse = [](const auto* t) -> std::optional + { auto it = begin(t); auto endIt = end(t); @@ -1200,14 +1199,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { - std::vector args{fn}; + if (isBlocked(*callMm)) + return block(*callMm, constraint); + + argsHead.insert(argsHead.begin(), fn); - for (TypeId arg : c.argsPack) - args.push_back(arg); + if (argsTail && isBlocked(*argsTail)) + return block(*argsTail, constraint); - argsPack = arena->addTypePack(TypePack{args, {}}); - fn = *callMm; - asMutable(c.result)->ty.emplace(constraint->scope); + argsPack = arena->addTypePack(TypePack{std::move(argsHead), argsTail}); + fn = follow(*callMm); + emplace(constraint, c.result, constraint->scope); } else { @@ -1217,104 +1219,309 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction) - usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); + usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); if (ftv->dcrMagicRefinement) ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); } if (!usedMagic) - asMutable(c.result)->ty.emplace(constraint->scope); + emplace(constraint, c.result, constraint->scope); } for (std::optional ty : c.discriminantTypes) { - if (!ty || !isBlocked(*ty)) + if (!ty) continue; - // We use `any` here because the discriminant type may be pointed at by both branches, - // where the discriminant type is not negated, and the other where it is negated, i.e. - // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` - // v.s. - // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` - // - // In practice, users cannot negate `any`, so this is an implementation detail we can always change. - *asMutable(follow(*ty)) = BoundType{builtinTypes->anyType}; + // If the discriminant type has been transmuted, we need to unblock them. + if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } + + if (FFlag::LuauRemoveNotAnyHack) + { + // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. + emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); + } + else + { + // We use `any` here because the discriminant type may be pointed at by both branches, + // where the discriminant type is not negated, and the other where it is negated, i.e. + // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` + // v.s. + // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` + // + // In practice, users cannot negate `any`, so this is an implementation detail we can always change. + emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); + } } - TypeId instantiatedTy = arena->addType(BlockedType{}); + OverloadResolver resolver{ + builtinTypes, + NotNull{arena}, + normalizer, + typeFunctionRuntime, + constraint->scope, + NotNull{&iceReporter}, + NotNull{&limits}, + constraint->location + }; + auto [status, overload] = resolver.selectOverload(fn, argsPack); + TypeId overloadToUse = fn; + if (status == OverloadResolver::Analysis::Ok) + overloadToUse = overload; + TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); + Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; - auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* { - std::unique_ptr c = std::make_unique(constraint->scope, constraint->location, std::move(cv)); - NotNull borrow{c.get()}; + const bool occursCheckPassed = u2.unify(overloadToUse, inferredTy); - bool ok = tryDispatch(borrow, false); - if (ok) - return nullptr; + if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) + { + std::optional subst = instantiate2(arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions), result); + if (!subst) + { + reportError(CodeTooComplex{}, constraint->location); + result = builtinTypes->errorTypePack; + } + else + result = *subst; - solverConstraints.push_back(std::move(c)); - unsolvedConstraints.push_back(borrow); + if (c.result != result) + emplaceTypePack(asMutable(c.result), result); + } - return borrow; - }; + for (const auto& [expanded, additions] : u2.expandedFreeTypes) + { + for (TypeId addition : additions) + upperBoundContributors[expanded].emplace_back(constraint->location, addition); + } - // HACK: We don't want other constraints to act on the free type pack - // created above until after these two constraints are solved, so we try to - // dispatch them directly. + if (occursCheckPassed && c.callSite) + (*c.astOverloadResolvedTypes)[c.callSite] = inferredTy; + else if (!occursCheckPassed) + reportError(OccursCheckFailed{}, constraint->location); - auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn}); - auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy}); + InstantiationQueuer queuer{constraint->scope, constraint->location, this}; + queuer.traverse(overloadToUse); + queuer.traverse(inferredTy); - // Anything that is blocked on this constraint must also be blocked on our - // synthesized constraints. - auto blockedIt = blocked.find(constraint.get()); - if (blockedIt != blocked.end()) + unblock(c.result, constraint->location); + + return true; +} + +static AstExpr* unwrapGroup(AstExpr* expr) +{ + while (auto group = expr->as()) + expr = group->expr; + + return expr; +} + +bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull constraint) +{ + TypeId fn = follow(c.fn); + const TypePackId argsPack = follow(c.argsPack); + + if (isBlocked(fn)) + return block(fn, constraint); + + if (isBlocked(argsPack)) + return true; + + if (DFInt::LuauTypeSolverRelease >= 648) + { + // This is expensive as we need to traverse a (potentially large) + // literal up front in order to determine if there are any blocked + // types, otherwise we may run `matchTypeLiteral` multiple times, + // which right now may fail due to being non-idempotent (it + // destructively updates the underlying literal type). + auto blockedTypes = findBlockedArgTypesIn(c.callSite, c.astTypes); + for (const auto ty : blockedTypes) + { + block(ty, constraint); + } + if (!blockedTypes.empty()) + return false; + } + + // We know the type of the function and the arguments it expects to receive. + // We also know the TypeIds of the actual arguments that will be passed. + // + // Bidirectional type checking: Force those TypeIds to be the expected + // arguments. If something is incoherent, we'll spot it in type checking. + // + // Most important detail: If a function argument is a lambda, we also want + // to force unannotated argument types of that lambda to be the expected + // types. + + // FIXME: Bidirectional type checking of overloaded functions is not yet supported. + const FunctionType* ftv = get(fn); + if (!ftv) + return true; + + DenseHashMap replacements{nullptr}; + DenseHashMap replacementPacks{nullptr}; + + for (auto generic : ftv->generics) + replacements[generic] = builtinTypes->unknownType; + + for (auto genericPack : ftv->genericPacks) + replacementPacks[genericPack] = builtinTypes->unknownTypePack; + + // If the type of the function has generics, we don't actually want to push any of the generics themselves + // into the argument types as expected types because this creates an unnecessary loop. Instead, we want to + // replace these types with `unknown` (and `...unknown`) to keep any structure but not create the cycle. + if (!replacements.empty() || !replacementPacks.empty()) { - for (const auto& blockedConstraint : blockedIt->second) + Replacer replacer{arena, std::move(replacements), std::move(replacementPacks)}; + + std::optional res = replacer.substitute(fn); + if (res) + { + if (*res != fn) + { + FunctionType* ftvMut = getMutable(*res); + LUAU_ASSERT(ftvMut); + ftvMut->generics.clear(); + ftvMut->genericPacks.clear(); + } + + fn = *res; + ftv = get(*res); + LUAU_ASSERT(ftv); + + // we've potentially copied type functions here, so we need to reproduce their reduce constraint. + reproduceConstraints(constraint->scope, constraint->location, replacer); + } + } + + const std::vector expectedArgs = flatten(ftv->argTypes).first; + const std::vector argPackHead = flatten(argsPack).first; + + // If this is a self call, the types will have more elements than the AST call. + // We don't attempt to perform bidirectional inference on the self type. + const size_t typeOffset = c.callSite->self ? 1 : 0; + + for (size_t i = 0; i < c.callSite->args.size && i + typeOffset < expectedArgs.size() && i + typeOffset < argPackHead.size(); ++i) + { + const TypeId expectedArgTy = follow(expectedArgs[i + typeOffset]); + const TypeId actualArgTy = follow(argPackHead[i + typeOffset]); + AstExpr* expr = unwrapGroup(c.callSite->args.data[i]); + + (*c.astExpectedTypes)[expr] = expectedArgTy; + + const FunctionType* expectedLambdaTy = get(expectedArgTy); + const FunctionType* lambdaTy = get(actualArgTy); + const AstExprFunction* lambdaExpr = expr->as(); + + if (expectedLambdaTy && lambdaTy && lambdaExpr) + { + const std::vector expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first; + const std::vector lambdaArgTys = flatten(lambdaTy->argTypes).first; + + for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j) + { + if (!lambdaExpr->args.data[j]->annotation && get(follow(lambdaArgTys[j]))) + { + shiftReferences(lambdaArgTys[j], expectedLambdaArgTys[j]); + bind(constraint, lambdaArgTys[j], expectedLambdaArgTys[j]); + } + } + } + else if (expr->is() || expr->is() || expr->is() || expr->is()) + { + Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; + u2.unify(actualArgTy, expectedArgTy); + } + else if (expr->is()) { - if (ic) - block(NotNull{ic}, blockedConstraint); - if (sc) - block(NotNull{sc}, blockedConstraint); + Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; + std::vector toBlock; + (void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock); + if (DFInt::LuauTypeSolverRelease >= 648) + { + LUAU_ASSERT(toBlock.empty()); + } + else + { + for (auto t : toBlock) + block(t, constraint); + if (!toBlock.empty()) + return false; + } } } - unblock(c.result); return true; } bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint) { - TypeId expectedType = follow(c.expectedType); - if (isBlocked(expectedType) || get(expectedType)) - return block(expectedType, constraint); + std::optional expectedType = c.expectedType ? std::make_optional(follow(*c.expectedType)) : std::nullopt; + if (expectedType && (isBlocked(*expectedType) || get(*expectedType))) + return block(*expectedType, constraint); - TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; - asMutable(c.resultType)->ty.emplace(bindTo); + const FreeType* freeType = get(follow(c.freeType)); + + // if this is no longer a free type, then we're done. + if (!freeType) + return true; + + // We will wait if there are any other references to the free type mentioned here. + // This is probably the only thing that makes this not insane to do. + if (auto refCount = unresolvedConstraints.find(c.freeType); refCount && *refCount > 1) + { + block(c.freeType, constraint); + return false; + } + + TypeId bindTo = c.primitiveType; + + if (freeType->upperBound != c.primitiveType && maybeSingleton(freeType->upperBound)) + bindTo = freeType->lowerBound; + else if (expectedType && maybeSingleton(*expectedType)) + bindTo = freeType->lowerBound; + + if (DFInt::LuauTypeSolverRelease >= 645) + { + auto ty = follow(c.freeType); + shiftReferences(ty, bindTo); + bind(constraint, ty, bindTo); + } + else + { + shiftReferences(c.freeType, bindTo); + bind(constraint, c.freeType, bindTo); + } return true; } bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) { - TypeId subjectType = follow(c.subjectType); + const TypeId subjectType = follow(c.subjectType); + const TypeId resultType = follow(c.resultType); - if (isBlocked(subjectType) || get(subjectType)) + LUAU_ASSERT(get(resultType)); + LUAU_ASSERT(canMutate(resultType, constraint)); + + if (isBlocked(subjectType) || get(subjectType) || get(subjectType)) return block(subjectType, constraint); - if (get(subjectType)) + if (const TableType* subjectTable = getTableType(subjectType)) { - TableType& ttv = asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, constraint->scope); - ttv.props[c.prop] = Property{c.resultType}; - asMutable(c.resultType)->ty.emplace(constraint->scope); - unblock(c.resultType); - return true; + if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0) + { + return block(subjectType, constraint); + } } - subjectType = reducer->reduce(subjectType).value_or(subjectType); - - auto [blocked, result] = lookupTableProp(subjectType, c.prop); + auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification); if (!blocked.empty()) { for (TypeId blocked : blocked) @@ -1323,261 +1530,491 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(result.value_or(builtinTypes->errorRecoveryType())); - unblock(c.resultType); + bind(constraint, resultType, result.value_or(builtinTypes->anyType)); return true; } -static bool isUnsealedTable(TypeId ty) +bool ConstraintSolver::tryDispatchHasIndexer( + int& recursionDepth, + NotNull constraint, + TypeId subjectType, + TypeId indexType, + TypeId resultType, + Set& seen +) { - ty = follow(ty); - const TableType* ttv = get(ty); - return ttv && ttv->state == TableState::Unsealed; -} + RecursionLimiter _rl{&recursionDepth, FInt::LuauSolverRecursionLimit}; -/** - * Create a shallow copy of `ty` and its properties along `path`. Insert a new - * property (the last segment of `path`) into the tail table with the value `t`. - * - * On success, returns the new outermost table type. If the root table or any - * of its subkeys are not unsealed tables, the function fails and returns - * std::nullopt. - * - * TODO: Prove that we completely give up in the face of indexers and - * metatables. - */ -static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) -{ - if (path.empty()) - return std::nullopt; + subjectType = follow(subjectType); + indexType = follow(indexType); - // First walk the path and ensure that it's unsealed tables all the way - // to the end. + if (seen.contains(subjectType)) + return false; + seen.insert(subjectType); + + LUAU_ASSERT(get(resultType)); + LUAU_ASSERT(canMutate(resultType, constraint)); + + if (get(subjectType)) { - TypeId t = ty; - for (size_t i = 0; i < path.size() - 1; ++i) + bind(constraint, resultType, builtinTypes->anyType); + return true; + } + + if (auto ft = get(subjectType)) + { + if (auto tbl = get(follow(ft->upperBound)); tbl && tbl->indexer) { - if (!isUnsealedTable(t)) - return std::nullopt; + unify(constraint, indexType, tbl->indexer->indexType); + bind(constraint, resultType, tbl->indexer->indexResultType); + return true; + } + else if (auto mt = get(follow(ft->upperBound))) + return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); - const TableType* tbl = get(t); - auto it = tbl->props.find(path[i]); - if (it == tbl->props.end()) - return std::nullopt; + FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; + emplace(constraint, resultType, freeResult); + + TypeId upperBound = + arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, ft->scope, TableState::Unsealed}); + + unify(constraint, subjectType, upperBound); - t = it->second.type; + return true; + } + else if (auto tt = getMutable(subjectType)) + { + if (auto indexer = tt->indexer) + { + unify(constraint, indexType, indexer->indexType); + bind(constraint, resultType, indexer->indexResultType); + return true; } - // The last path segment should not be a property of the table at all. - // We are not changing property types. We are only admitting this one - // new property to be appended. - if (!isUnsealedTable(t)) - return std::nullopt; - const TableType* tbl = get(t); - if (0 != tbl->props.count(path.back())) - return std::nullopt; + if (tt->state == TableState::Unsealed) + { + // FIXME this is greedy. + + FreeType freeResult{tt->scope, builtinTypes->neverType, builtinTypes->unknownType}; + emplace(constraint, resultType, freeResult); + + tt->indexer = TableIndexer{indexType, resultType}; + return true; + } + } + else if (auto mt = get(subjectType)) + return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); + else if (auto ct = get(subjectType)) + { + if (auto indexer = ct->indexer) + { + unify(constraint, indexType, indexer->indexType); + bind(constraint, resultType, indexer->indexResultType); + return true; + } + else if (isString(indexType)) + { + bind(constraint, resultType, builtinTypes->unknownType); + return true; + } } + else if (auto it = get(subjectType)) + { + // subjectType <: {[indexType]: resultType} + // + // 'a & ~(false | nil) <: {[indexType]: resultType} + // + // 'a <: {[indexType]: resultType} + // ~(false | nil) <: {[indexType]: resultType} + + Set parts{nullptr}; + for (TypeId part : it) + parts.insert(follow(part)); + + Set results{nullptr}; + + for (TypeId part : parts) + { + TypeId r = arena->addType(BlockedType{}); + getMutable(r)->setOwner(const_cast(constraint.get())); - const TypeId res = shallowClone(ty, arena); - TypeId t = res; + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); + // If we've cut a recursive loop short, skip it. + if (!ok) + continue; + + r = follow(r); + if (!get(r)) + results.insert(r); + } - for (size_t i = 0; i < path.size() - 1; ++i) + if (0 == results.size()) + bind(constraint, resultType, builtinTypes->errorType); + else if (1 == results.size()) + bind(constraint, resultType, *results.begin()); + else + emplace(constraint, resultType, std::vector(results.begin(), results.end())); + + return true; + } + else if (auto ut = get(subjectType)) { - const std::string segment = path[i]; + Set parts{nullptr}; + for (TypeId part : ut) + parts.insert(follow(part)); - TableType* ttv = getMutable(t); - LUAU_ASSERT(ttv); + Set results{nullptr}; - auto propIt = ttv->props.find(segment); - if (propIt != ttv->props.end()) + for (TypeId part : parts) { - LUAU_ASSERT(isUnsealedTable(propIt->second.type)); - t = shallowClone(follow(propIt->second.type), arena); - ttv->props[segment].type = t; + TypeId r = arena->addType(BlockedType{}); + getMutable(r)->setOwner(const_cast(constraint.get())); + + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); + // If we've cut a recursive loop short, skip it. + if (!ok) + continue; + + r = follow(r); + if (!get(r)) + results.insert(r); + } + + if (0 == results.size()) + bind(constraint, resultType, builtinTypes->errorType); + else if (1 == results.size()) + { + TypeId firstResult = *results.begin(); + shiftReferences(resultType, firstResult); + bind(constraint, resultType, firstResult); } else - return std::nullopt; + emplace(constraint, resultType, std::vector(results.begin(), results.end())); + + return true; } - TableType* ttv = getMutable(t); - LUAU_ASSERT(ttv); + bind(constraint, resultType, builtinTypes->errorType); - const std::string lastSegment = path.back(); - LUAU_ASSERT(0 == ttv->props.count(lastSegment)); - ttv->props[lastSegment] = Property{replaceTy}; - return res; + return true; } -bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force) +namespace { - TypeId subjectType = follow(c.subjectType); + +struct BlockedTypeFinder : TypeOnceVisitor +{ + std::optional blocked; + + bool visit(TypeId ty) override + { + // If we've already found one, stop traversing. + return !blocked.has_value(); + } + + bool visit(TypeId ty, const BlockedType&) override + { + blocked = ty; + return false; + } +}; + +} // namespace + +bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull constraint) +{ + const TypeId subjectType = follow(c.subjectType); + const TypeId indexType = follow(c.indexType); if (isBlocked(subjectType)) return block(subjectType, constraint); - if (!force && get(subjectType)) - return block(subjectType, constraint); + if (isBlocked(indexType)) + return block(indexType, constraint); - std::optional existingPropType = subjectType; - for (const std::string& segment : c.path) - { - if (!existingPropType) - break; + BlockedTypeFinder btf; - auto [blocked, result] = lookupTableProp(*existingPropType, segment); - if (!blocked.empty()) - { - for (TypeId blocked : blocked) - block(blocked, constraint); - return false; - } + btf.visit(subjectType); - existingPropType = result; - } + if (btf.blocked) + return block(*btf.blocked, constraint); + int recursionDepth = 0; - auto bind = [](TypeId a, TypeId b) { - asMutable(a)->ty.emplace(b); - }; + Set seen{nullptr}; - if (existingPropType) - { - if (!isBlocked(c.propType)) - unify(c.propType, *existingPropType, constraint->scope); - bind(c.resultType, c.subjectType); - return true; - } + return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); +} + +bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) +{ + TypeId lhsType = follow(c.lhsType); + const std::string& propName = c.propName; + const TypeId rhsType = follow(c.rhsType); + + if (isBlocked(lhsType)) + return block(lhsType, constraint); + + // 1. lhsType is a class that already has the prop + // 2. lhsType is a table that already has the prop (or a union or + // intersection that has the prop in aggregate) + // 3. lhsType has a metatable that already has the prop + // 4. lhsType is an unsealed table that does not have the prop, but has a + // string indexer + // 5. lhsType is an unsealed table that does not have the prop or a string + // indexer - if (auto mt = get(subjectType)) - subjectType = follow(mt->table); + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. - if (get(subjectType) || get(subjectType) || get(subjectType)) + if (auto lhsClass = get(lhsType)) { - bind(c.resultType, subjectType); + const Property* prop = lookupClassProp(lhsClass, propName); + if (!prop || !prop->writeTy.has_value()) + { + bind(constraint, c.propType, builtinTypes->anyType); + return true; + } + + bind(constraint, c.propType, *prop->writeTy); + unify(constraint, rhsType, *prop->writeTy); return true; } - if (get(subjectType)) + if (auto lhsFree = getMutable(lhsType)) { - TypeId ty = arena->freshType(constraint->scope); - - // Mint a chain of free tables per c.path - for (auto it = rbegin(c.path); it != rend(c.path); ++it) + auto lhsFreeUpperBound = DFInt::LuauTypeSolverRelease >= 648 ? follow(lhsFree->upperBound) : lhsFree->upperBound; + if (get(lhsFreeUpperBound) || get(lhsFreeUpperBound)) + lhsType = lhsFreeUpperBound; + else { - TableType t{TableState::Free, TypeLevel{}, constraint->scope}; - t.props[*it] = {ty}; + TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); + TableType* upperTable = getMutable(newUpperBound); + LUAU_ASSERT(upperTable); - ty = arena->addType(std::move(t)); + upperTable->props[c.propName] = rhsType; + + // Food for thought: Could we block if simplification encounters a blocked type? + lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result; + + bind(constraint, c.propType, rhsType); + return true; } + } - LUAU_ASSERT(ty); + // Handle the case that lhsType is a table that already has the property or + // a matching indexer. This also handles unions and intersections. + const auto [blocked, maybeTy] = lookupTableProp(constraint, lhsType, propName, ValueContext::LValue); + if (!blocked.empty()) + { + for (TypeId t : blocked) + block(t, constraint); + return false; + } - bind(subjectType, ty); - if (follow(c.resultType) != follow(ty)) - bind(c.resultType, ty); + if (maybeTy) + { + const TypeId propTy = *maybeTy; + bind(constraint, c.propType, propTy); + unify(constraint, rhsType, propTy); return true; } - else if (auto ttv = getMutable(subjectType)) + + if (auto lhsMeta = get(lhsType)) + lhsType = follow(lhsMeta->table); + + // Handle the case where the lhs type is a table that does not have the + // named property. It could be a table with a string indexer, or an unsealed + // or free table that can grow. + if (auto lhsTable = getMutable(lhsType)) { - if (ttv->state == TableState::Free) + if (auto it = lhsTable->props.find(propName); it != lhsTable->props.end()) { - LUAU_ASSERT(!subjectType->persistent); + Property& prop = it->second; - ttv->props[c.path[0]] = Property{c.propType}; - bind(c.resultType, c.subjectType); - return true; + if (prop.writeTy.has_value()) + { + bind(constraint, c.propType, *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + LUAU_ASSERT(prop.isReadOnly()); + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + prop.writeTy = prop.readTy; + bind(constraint, c.propType, *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + bind(constraint, c.propType, builtinTypes->errorType); + return true; + } + } } - else if (ttv->state == TableState::Unsealed) - { - LUAU_ASSERT(!subjectType->persistent); - std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); - bind(c.resultType, augmented.value_or(subjectType)); - bind(subjectType, c.resultType); + if (lhsTable->indexer && maybeString(lhsTable->indexer->indexType)) + { + bind(constraint, c.propType, rhsType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); return true; } - else + + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - bind(c.resultType, subjectType); + bind(constraint, c.propType, rhsType); + Property& newProp = lhsTable->props[propName]; + newProp.readTy = rhsType; + newProp.writeTy = rhsType; + newProp.location = c.propLocation; + + if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) + { + LUAU_ASSERT(lhsTable->remainingProps > 0); + lhsTable->remainingProps -= 1; + } + return true; } } - else if (get(subjectType)) - { - // Classes and intersections never change shape as a result of property - // assignments. The result is always the subject. - bind(c.resultType, subjectType); - return true; - } - LUAU_ASSERT(0); + bind(constraint, c.propType, builtinTypes->errorType); + return true; } -bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull constraint) { - TypeId subjectType = follow(c.subjectType); - if (isBlocked(subjectType)) - return block(subjectType, constraint); + const TypeId lhsType = follow(c.lhsType); + const TypeId indexType = follow(c.indexType); + const TypeId rhsType = follow(c.rhsType); - if (auto ft = get(subjectType)) - { - Scope* scope = ft->scope; - TableType* tt = &asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, scope); - tt->indexer = TableIndexer{c.indexType, c.propType}; + if (isBlocked(lhsType)) + return block(lhsType, constraint); - asMutable(c.resultType)->ty.emplace(subjectType); - asMutable(c.propType)->ty.emplace(scope); - unblock(c.propType); - unblock(c.resultType); + // 0. lhsType could be an intersection or union. + // 1. lhsType is a class with an indexer + // 2. lhsType is a table with an indexer, or it has a metatable that has an indexer + // 3. lhsType is a free or unsealed table and can grow an indexer - return true; - } - else if (auto tt = get(subjectType)) + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. + + auto tableStuff = [&](TableType* lhsTable) -> std::optional { - if (tt->indexer) + if (lhsTable->indexer) { - // TODO This probably has to be invariant. - unify(c.indexType, tt->indexer->indexType, constraint->scope); - asMutable(c.propType)->ty.emplace(tt->indexer->indexResultType); - asMutable(c.resultType)->ty.emplace(subjectType); - unblock(c.propType); - unblock(c.resultType); + unify(constraint, indexType, lhsTable->indexer->indexType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); + bind(constraint, c.propType, lhsTable->indexer->indexResultType); return true; } - else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) + + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - auto mtt = getMutable(subjectType); - mtt->indexer = TableIndexer{c.indexType, c.propType}; - asMutable(c.propType)->ty.emplace(tt->scope); - asMutable(c.resultType)->ty.emplace(subjectType); - unblock(c.propType); - unblock(c.resultType); + lhsTable->indexer = TableIndexer{indexType, rhsType}; + bind(constraint, c.propType, rhsType); return true; } - // Do not augment sealed or generic tables that lack indexers + + return {}; + }; + + if (auto lhsFree = getMutable(lhsType)) + { + if (auto lhsTable = getMutable(lhsFree->upperBound)) + { + if (auto res = tableStuff(lhsTable)) + return *res; + } + + TypeId newUpperBound = + arena->addType(TableType{/*props*/ {}, TableIndexer{indexType, rhsType}, TypeLevel{}, constraint->scope, TableState::Free}); + const TableType* newTable = get(newUpperBound); + LUAU_ASSERT(newTable); + + unify(constraint, lhsType, newUpperBound); + + LUAU_ASSERT(newTable->indexer); + bind(constraint, c.propType, newTable->indexer->indexResultType); + return true; + } + + if (auto lhsTable = getMutable(lhsType)) + { + std::optional res = tableStuff(lhsTable); + if (res.has_value()) + return *res; + } + + if (auto lhsClass = get(lhsType)) + { + while (true) + { + if (lhsClass->indexer) + { + unify(constraint, indexType, lhsClass->indexer->indexType); + unify(constraint, rhsType, lhsClass->indexer->indexResultType); + bind(constraint, c.propType, lhsClass->indexer->indexResultType); + return true; + } + + if (lhsClass->parent) + lhsClass = get(lhsClass->parent); + else + break; + } + return true; } - asMutable(c.propType)->ty.emplace(builtinTypes->errorRecoveryType()); - asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - unblock(c.propType); - unblock(c.resultType); - return true; -} + if (auto lhsIntersection = getMutable(lhsType)) + { + std::set parts; -bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) -{ - if (isBlocked(c.discriminantType)) - return false; + for (TypeId t : lhsIntersection) + { + if (auto tbl = getMutable(follow(t))) + { + if (tbl->indexer) + { + unify(constraint, indexType, tbl->indexer->indexType); + parts.insert(tbl->indexer->indexResultType); + } - TypeId followed = follow(c.discriminantType); + if (tbl->state == TableState::Unsealed || tbl->state == TableState::Free) + { + tbl->indexer = TableIndexer{indexType, rhsType}; + parts.insert(rhsType); + } + } + else if (auto cls = get(follow(t))) + { + while (true) + { + if (cls->indexer) + { + unify(constraint, indexType, cls->indexer->indexType); + parts.insert(cls->indexer->indexResultType); + break; + } - // `nil` is a singleton type too! There's only one value of type `nil`. - if (c.negated && (get(followed) || isNil(followed))) - *asMutable(c.resultType) = NegationType{c.discriminantType}; - else if (!c.negated && get(followed)) - *asMutable(c.resultType) = BoundType{c.discriminantType}; - else - *asMutable(c.resultType) = BoundType{builtinTypes->unknownType}; + if (cls->parent) + cls = get(cls->parent); + else + break; + } + } + } + + TypeId res = simplifyIntersection(builtinTypes, arena, std::move(parts)).result; + + unify(constraint, rhsType, res); + } + + // Other types do not support index assignment. + bind(constraint, c.propType, builtinTypes->errorType); return true; } @@ -1585,114 +2022,209 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); - TypePackId resultPack = follow(c.resultPack); if (isBlocked(sourcePack)) - return block(sourcePack, constraint); - - if (isBlocked(resultPack)) - { - asMutable(resultPack)->ty.emplace(sourcePack); - unblock(resultPack); - return true; - } + return block(sourcePack, constraint); - TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, c.resultPack.size()); - auto destIter = begin(resultPack); - auto destEnd = end(resultPack); + auto resultIter = begin(c.resultPack); + auto resultEnd = end(c.resultPack); size_t i = 0; - while (destIter != destEnd) + while (resultIter != resultEnd) { if (i >= srcPack.head.size()) break; + TypeId srcTy = follow(srcPack.head[i]); + TypeId resultTy = follow(*resultIter); - if (isBlocked(*destIter)) + LUAU_ASSERT(get(resultTy)); + LUAU_ASSERT(canMutate(resultTy, constraint)); + + if (get(resultTy)) { - if (follow(srcTy) == *destIter) + if (follow(srcTy) == resultTy) { - // Cyclic type dependency. (????) - asMutable(*destIter)->ty.emplace(constraint->scope); + // It is sometimes the case that we find that a blocked type + // is only blocked on itself. This doesn't actually + // constitute any meaningful constraint, so we replace it + // with a free type. + TypeId f = freshType(arena, builtinTypes, constraint->scope); + shiftReferences(resultTy, f); + emplaceType(asMutable(resultTy), f); } else - asMutable(*destIter)->ty.emplace(srcTy); - unblock(*destIter); + bind(constraint, resultTy, srcTy); } else - unify(*destIter, srcTy, constraint->scope); + unify(constraint, srcTy, resultTy); + + unblock(resultTy, constraint->location); - ++destIter; + ++resultIter; ++i; } // We know that resultPack does not have a tail, but we don't know if // sourcePack is long enough to fill every value. Replace every remaining - // result TypeId with the error recovery type. + // result TypeId with `nil`. - while (destIter != destEnd) + while (resultIter != resultEnd) { - if (isBlocked(*destIter)) + TypeId resultTy = follow(*resultIter); + LUAU_ASSERT(canMutate(resultTy, constraint)); + if (get(resultTy) || get(resultTy)) { - asMutable(*destIter)->ty.emplace(builtinTypes->errorRecoveryType()); - unblock(*destIter); + bind(constraint, resultTy, builtinTypes->nilType); } - ++destIter; + ++resultIter; } return true; } -bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { - auto block_ = [&](auto&& t) { - if (force) + TypeId ty = follow(c.ty); + FunctionGraphReductionResult result = + reduceTypeFunctions(ty, constraint->location, TypeFunctionContext{NotNull{this}, constraint->scope, constraint}, force); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + + bool reductionFinished = result.blockedTypes.empty() && result.blockedPacks.empty(); + + ty = follow(ty); + // If we couldn't reduce this type function, stick it in the set! + if (get(ty)) + typeFunctionsToFinalize[ty] = constraint; + + if (force || reductionFinished) + { + // if we're completely dispatching this constraint, we want to record any uninhabited type functions to unblock. + for (auto error : result.errors) { - // TODO: I believe it is the case that, if we are asked to force - // this constraint, then we can do nothing but fail. I'd like to - // find a code sample that gets here. - LUAU_ASSERT(false); + if (auto utf = get(error)) + uninhabitedTypeFunctions.insert(utf->ty); + else if (auto utpf = get(error)) + uninhabitedTypeFunctions.insert(utpf->tp); } - else - block(t, constraint); - return false; - }; + } + + if (force) + return true; + + for (TypeId b : result.blockedTypes) + block(b, constraint); + + for (TypePackId b : result.blockedPacks) + block(b, constraint); + + return reductionFinished; +} + +bool ConstraintSolver::tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force) +{ + TypePackId tp = follow(c.tp); + FunctionGraphReductionResult result = + reduceTypeFunctions(tp, constraint->location, TypeFunctionContext{NotNull{this}, constraint->scope, constraint}, force); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + + bool reductionFinished = result.blockedTypes.empty() && result.blockedPacks.empty(); + + if (force || reductionFinished) + { + // if we're completely dispatching this constraint, we want to record any uninhabited type functions to unblock. + for (auto error : result.errors) + { + if (auto utf = get(error)) + uninhabitedTypeFunctions.insert(utf->ty); + else if (auto utpf = get(error)) + uninhabitedTypeFunctions.insert(utpf->tp); + } + } + + if (force) + return true; + + for (TypeId b : result.blockedTypes) + block(b, constraint); + + for (TypePackId b : result.blockedPacks) + block(b, constraint); + + return reductionFinished; +} - // We may have to block here if we don't know what the iteratee type is, - // if it's a free table, if we don't know it has a metatable, and so on. +bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull constraint) +{ + unify(constraint, c.resultType, c.assignmentType); + unify(constraint, c.assignmentType, c.resultType); + return true; +} + +bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) +{ iteratorTy = follow(iteratorTy); + if (get(iteratorTy)) - return block_(iteratorTy); + { + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); + getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; - auto anyify = [&](auto ty) { - Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; - std::optional anyified = anyify.substitute(ty); - if (!anyified) - reportError(CodeTooComplex{}, constraint->location); - else - unify(*anyified, ty, constraint->scope); - }; + pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{iteratorTy, tableTy}); - auto errorify = [&](auto ty) { - Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; - std::optional errorified = anyify.substitute(ty); - if (!errorified) - reportError(CodeTooComplex{}, constraint->location); - else - unify(*errorified, ty, constraint->scope); + auto it = begin(c.variables); + auto endIt = end(c.variables); + if (it != endIt) + { + bind(constraint, *it, keyTy); + ++it; + } + if (it != endIt) + bind(constraint, *it, valueTy); + + return true; + } + + auto unpack = [&](TypeId ty) + { + for (TypeId varTy : c.variables) + { + LUAU_ASSERT(get(varTy)); + LUAU_ASSERT(varTy != ty); + bind(constraint, varTy, ty); + } }; if (get(iteratorTy)) { - anyify(c.variables); + unpack(builtinTypes->anyType); return true; } if (get(iteratorTy)) { - errorify(c.variables); + unpack(builtinTypes->errorType); + return true; + } + + if (get(iteratorTy)) + { + unpack(builtinTypes->neverType); return true; } @@ -1701,16 +2233,33 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (auto iteratorTable = get(iteratorTy)) { - if (iteratorTable->state == TableState::Free) - return block_(iteratorTy); + /* + * We try not to dispatch IterableConstraints over free tables because + * it's possible that there are other constraints on the table that will + * clarify what we should do. + * + * We should eventually introduce a type function to talk about iteration. + */ + if (iteratorTable->state == TableState::Free && !force) + return block(iteratorTy, constraint); if (iteratorTable->indexer) { - TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); - unify(c.variables, expectedVariablePack, constraint->scope); + std::vector expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}; + while (c.variables.size() >= expectedVariables.size()) + expectedVariables.push_back(builtinTypes->errorRecoveryType()); + + for (size_t i = 0; i < c.variables.size(); ++i) + { + LUAU_ASSERT(c.variables[i] != expectedVariables[i]); + + unify(constraint, c.variables[i], expectedVariables[i]); + + bind(constraint, c.variables[i], expectedVariables[i]); + } } else - errorify(c.variables); + unpack(builtinTypes->errorType); } else if (std::optional iterFn = findMetatableEntry(builtinTypes, errors, iteratorTy, "__iter", Location{})) { @@ -1719,14 +2268,12 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return block(*iterFn, constraint); } - Instantiation instantiation(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); - - if (std::optional instantiatedIterFn = instantiation.substitute(*iterFn)) + if (std::optional instantiatedIterFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, *iterFn)) { if (auto iterFtv = get(*instantiatedIterFn)) { TypePackId expectedIterArgs = arena->addTypePack({iteratorTy}); - unify(iterFtv->argTypes, expectedIterArgs, constraint->scope); + unify(constraint, iterFtv->argTypes, expectedIterArgs); TypePack iterRets = extendTypePack(*arena, builtinTypes, iterFtv->retTypes, 2); @@ -1738,21 +2285,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl } TypeId nextFn = iterRets.head[0]; - TypeId table = iterRets.head.size() == 2 ? iterRets.head[1] : arena->freshType(constraint->scope); - if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) + if (std::optional instantiatedNextFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, nextFn)) { - const TypeId firstIndex = arena->freshType(constraint->scope); - - // nextTy : (iteratorTy, indexTy?) -> (indexTy, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); - const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + const FunctionType* nextFn = get(*instantiatedNextFn); - const TypeId expectedNextTy = arena->addType(FunctionType{nextArgPack, nextRetPack}); - unify(*instantiatedNextFn, expectedNextTy, constraint->scope); + // If nextFn is nullptr, then the iterator function has an improper signature. + if (nextFn) + unpackAndAssign(c.variables, nextFn->retTypes, constraint); - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); + return true; } else { @@ -1771,43 +2313,33 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl } else if (auto iteratorMetatable = get(iteratorTy)) { - TypeId metaTy = follow(iteratorMetatable->metatable); - if (get(metaTy)) - return block_(metaTy); - - LUAU_ASSERT(false); + // If the metatable does not contain a `__iter` metamethod, then we iterate over the table part of the metatable. + return tryDispatchIterableTable(iteratorMetatable->table, c, constraint, force); } + else if (auto primitiveTy = get(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) + unpack(builtinTypes->unknownType); else - errorify(c.variables); + { + unpack(builtinTypes->errorType); + } return true; } bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) + TypeId nextTy, + TypeId tableTy, + const IterableConstraint& c, + NotNull constraint +) { - // We need to know whether or not this type is nil or not. - // If we don't know, block and reschedule ourselves. - firstIndexTy = follow(firstIndexTy); - if (get(firstIndexTy)) - { - if (force) - LUAU_ASSERT(false); - else - block(firstIndexTy, constraint); - return false; - } + const FunctionType* nextFn = get(nextTy); + // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. + LUAU_ASSERT(nextFn); + const TypePackId nextRetPack = nextFn->retTypes; - const TypeId firstIndex = isNil(firstIndexTy) ? arena->freshType(constraint->scope) // FIXME: Surely this should be a union (free | nil) - : firstIndexTy; - - // nextTy : (tableTy, indexTy?) -> (indexTy?, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); - const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); - - const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); - unify(nextTy, expectedNextTy, constraint->scope); + // the type of the `nextAstFragment` is the `nextTy`. + (*c.astForInNextTypes)[c.nextAstFragment] = nextTy; auto it = begin(nextRetPack); std::vector modifiedNextRetHead; @@ -1827,21 +2359,58 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack}); + + auto unpackConstraint = unpackAndAssign(c.variables, modifiedNextRetPack, constraint); + + inheritBlocks(constraint, unpackConstraint); return true; } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +NotNull ConstraintSolver::unpackAndAssign( + const std::vector destTypes, + TypePackId srcTypes, + NotNull constraint +) { - std::unordered_set seen; - return lookupTableProp(subjectType, propName, seen); + auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); + + for (TypeId t : destTypes) + { + BlockedType* bt = getMutable(t); + LUAU_ASSERT(bt); + bt->replaceOwner(c); + } + + return c; } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +std::pair, std::optional> ConstraintSolver::lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional, + bool suppressSimplification +) { - if (!seen.insert(subjectType).second) + DenseHashSet seen{nullptr}; + return lookupTableProp(constraint, subjectType, propName, context, inConditional, suppressSimplification, seen); +} + +std::pair, std::optional> ConstraintSolver::lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional, + bool suppressSimplification, + DenseHashSet& seen +) +{ + if (seen.contains(subjectType)) return {}; + seen.insert(subjectType); subjectType = follow(subjectType); @@ -1854,19 +2423,64 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - return {{}, prop->second.type}; - else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + { + switch (context) + { + case ValueContext::RValue: + if (auto rt = prop->second.readTy) + return {{}, rt}; + break; + case ValueContext::LValue: + if (auto wt = prop->second.writeTy) + return {{}, wt}; + break; + } + } + + if (ttv->indexer && maybeString(ttv->indexer->indexType)) return {{}, ttv->indexer->indexResultType}; - else if (ttv->state == TableState::Free) + + if (ttv->state == TableState::Free) { - TypeId result = arena->freshType(ttv->scope); - ttv->props[propName] = Property{result}; + TypeId result = freshType(arena, builtinTypes, ttv->scope); + switch (context) + { + case ValueContext::RValue: + ttv->props[propName].readTy = result; + break; + case ValueContext::LValue: + if (auto it = ttv->props.find(propName); it != ttv->props.end() && it->second.isReadOnly()) + { + // We do infer read-only properties, but we do not infer + // separate read and write types. + // + // If we encounter a case where a free table has a read-only + // property that we subsequently sense a write to, we make + // the judgement that the property is read-write and that + // both the read and write types are the same. + + Property& prop = it->second; + + prop.writeTy = prop.readTy; + return {{}, *prop.readTy}; + } + else + ttv->props[propName] = Property::rw(result); + + break; + } return {{}, result}; } + + // if we are in a conditional context, we treat the property as present and `unknown` because + // we may be _refining_ a table to include that property. we will want to revisit this a bit + // in the future once luau has support for exact tables since this only applies when inexact. + if (inConditional) + return {{}, builtinTypes->unknownType}; } - else if (auto mt = get(subjectType)) + else if (auto mt = get(subjectType); mt && context == ValueContext::RValue) { - auto [blocked, result] = lookupTableProp(mt->table, propName, seen); + auto [blocked, result] = lookupTableProp(constraint, mt->table, propName, context, inConditional, suppressSimplification, seen); if (!blocked.empty() || result) return {blocked, result}; @@ -1882,18 +2496,34 @@ std::pair, std::optional> ConstraintSolver::lookupTa // TODO: __index can be an overloaded function. - TypeId indexType = follow(indexProp->second.type); + TypeId indexType = follow(indexProp->second.type()); if (auto ft = get(indexType)) - return {{}, first(ft->retTypes)}; + { + TypePack rets = extendTypePack(*arena, builtinTypes, ft->retTypes, 1); + if (1 == rets.head.size()) + return {{}, rets.head[0]}; + else + { + // This should probably be an error: We need the first result of the MT.__index method, + // but it returns 0 values. See CLI-68672 + return {{}, builtinTypes->nilType}; + } + } else - return lookupTableProp(indexType, propName, seen); + return lookupTableProp(constraint, indexType, propName, context, inConditional, suppressSimplification, seen); } + else if (get(mtt)) + return lookupTableProp(constraint, mtt, propName, context, inConditional, suppressSimplification, seen); } else if (auto ct = get(subjectType)) { if (auto p = lookupClassProp(ct, propName)) - return {{}, p->type}; + return {{}, context == ValueContext::RValue ? p->readTy : p->writeTy}; + if (ct->indexer) + { + return {{}, ct->indexer->indexResultType}; + } } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -1904,31 +2534,49 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (indexProp == metatable->props.end()) return {{}, std::nullopt}; - return lookupTableProp(indexProp->second.type, propName, seen); + return lookupTableProp(constraint, indexProp->second.type(), propName, context, inConditional, suppressSimplification, seen); } else if (auto ft = get(subjectType)) { - Scope* scope = ft->scope; + const TypeId upperBound = follow(ft->upperBound); + + if (get(upperBound) || get(upperBound)) + return lookupTableProp(constraint, upperBound, propName, context, inConditional, suppressSimplification, seen); + + // TODO: The upper bound could be an intersection that contains suitable tables or classes. + + NotNull scope{ft->scope}; + + const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope}); + TableType* tt = getMutable(newUpperBound); + LUAU_ASSERT(tt); + TypeId propType = freshType(arena, builtinTypes, scope); + + switch (context) + { + case ValueContext::RValue: + tt->props[propName] = Property::readonly(propType); + break; + case ValueContext::LValue: + tt->props[propName] = Property::rw(propType); + break; + } - TableType* tt = &asMutable(subjectType)->ty.emplace(); - tt->state = TableState::Free; - tt->scope = scope; - TypeId propType = arena->freshType(scope); - tt->props[propName] = Property{propType}; + unify(constraint, subjectType, newUpperBound); return {{}, propType}; } else if (auto utv = get(subjectType)) { std::vector blocked; - std::vector options; + std::set options; for (TypeId ty : utv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) - options.push_back(*innerResult); + options.insert(*innerResult); } if (!blocked.empty()) @@ -1937,21 +2585,35 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (options.empty()) return {{}, std::nullopt}; else if (options.size() == 1) - return {{}, options[0]}; + return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + + // if we're in an lvalue context, we need the _common_ type here. + if (context == ValueContext::LValue) + return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + + return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; + } + // if we're in an lvalue context, we need the _common_ type here. + else if (context == ValueContext::LValue) + return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; else - return {{}, arena->addType(UnionType{std::move(options)})}; + return {{}, arena->addType(UnionType{std::vector(begin(options), end(options))})}; } else if (auto itv = get(subjectType)) { std::vector blocked; - std::vector options; + std::set options; for (TypeId ty : itv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) - options.push_back(*innerResult); + options.insert(*innerResult); } if (!blocked.empty()) @@ -1960,57 +2622,133 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (options.empty()) return {{}, std::nullopt}; else if (options.size() == 1) - return {{}, options[0]}; + return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + } else - return {{}, arena->addType(IntersectionType{std::move(options)})}; + return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; + } + else if (auto pt = get(subjectType)) + { + // if we are in a conditional context, we treat the property as present and `unknown` because + // we may be _refining_ a table to include that property. we will want to revisit this a bit + // in the future once luau has support for exact tables since this only applies when inexact. + if (inConditional && pt->type == PrimitiveType::Table) + return {{}, builtinTypes->unknownType}; } return {{}, std::nullopt}; } -void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) +template +bool ConstraintSolver::unify(NotNull constraint, TID subTy, TID superTy) +{ + Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}, &uninhabitedTypeFunctions}; + + const bool ok = u2.unify(subTy, superTy); + + for (ConstraintV& c : u2.incompleteSubtypes) + { + NotNull addition = pushConstraint(constraint->scope, constraint->location, std::move(c)); + inheritBlocks(constraint, addition); + } + + if (ok) + { + for (const auto& [expanded, additions] : u2.expandedFreeTypes) + { + for (TypeId addition : additions) + upperBoundContributors[expanded].emplace_back(constraint->location, addition); + } + } + else + { + reportError(OccursCheckFailed{}, constraint->location); + return false; + } + + return true; +} + +bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { - blocked[target].push_back(constraint); + // If a set is not present for the target, construct a new DenseHashSet for it, + // else grab the address of the existing set. + auto [iter, inserted] = blocked.try_emplace(target, nullptr); + auto& [key, blockVec] = *iter; - auto& count = blockedConstraints[constraint]; + if (blockVec.find(constraint)) + return false; + + blockVec.insert(constraint); + + size_t& count = blockedConstraints[constraint]; count += 1; + + return true; } void ConstraintSolver::block(NotNull target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); - - if (FFlag::DebugLuauLogSolver) - printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); + const bool newBlock = block_(target.get(), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - block_(target.get(), constraint); + if (FFlag::DebugLuauLogSolver) + printf("%s depends on constraint %s\n", toString(*constraint, opts).c_str(), toString(*target, opts).c_str()); + } } bool ConstraintSolver::block(TypeId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(follow(target), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("%s depends on TypeId %s\n", toString(*constraint, opts).c_str(), toString(target, opts).c_str()); + } - block_(follow(target), constraint); return false; } bool ConstraintSolver::block(TypePackId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(target, constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("%s depends on TypePackId %s\n", toString(*constraint, opts).c_str(), toString(target, opts).c_str()); + } - block_(target, constraint); return false; } +void ConstraintSolver::inheritBlocks(NotNull source, NotNull addition) +{ + // Anything that is blocked on this constraint must also be blocked on our + // synthesized constraints. + auto blockedIt = blocked.find(source.get()); + if (blockedIt != blocked.end()) + { + for (const Constraint* blockedConstraint : blockedIt->second) + { + block(addition, NotNull{blockedConstraint}); + } + } +} + struct Blocker : TypeOnceVisitor { NotNull solver; @@ -2024,29 +2762,27 @@ struct Blocker : TypeOnceVisitor { } - bool visit(TypeId ty, const BlockedType&) + bool visit(TypeId ty, const PendingExpansionType&) override { blocked = true; solver->block(ty, constraint); return false; } - bool visit(TypeId ty, const PendingExpansionType&) + bool visit(TypeId ty, const ClassType&) override { - blocked = true; - solver->block(ty, constraint); return false; } }; -bool ConstraintSolver::recursiveBlock(TypeId target, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(target); return !blocker.blocked; } -bool ConstraintSolver::recursiveBlock(TypePackId pack, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypePackId pack, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(pack); @@ -2060,9 +2796,9 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) return; // unblocked should contain a value always, because of the above check - for (NotNull unblockedConstraint : it->second) + for (const Constraint* unblockedConstraint : it->second) { - auto& count = blockedConstraints[unblockedConstraint]; + auto& count = blockedConstraints[NotNull{unblockedConstraint}]; if (FFlag::DebugLuauLogSolver) printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint, opts).c_str()); @@ -2085,18 +2821,30 @@ void ConstraintSolver::unblock(NotNull progressed) return unblock_(progressed.get()); } -void ConstraintSolver::unblock(TypeId progressed) +void ConstraintSolver::unblock(TypeId ty, Location location) { - if (logger) - logger->popBlock(progressed); + DenseHashSet seen{nullptr}; + + TypeId progressed = ty; + while (true) + { + if (seen.find(progressed)) + iceReporter.ice("ConstraintSolver::unblock encountered a self-bound type!", location); + seen.insert(progressed); + + if (logger) + logger->popBlock(progressed); - unblock_(progressed); + unblock_(progressed); - if (auto bt = get(progressed)) - unblock(bt->boundTo); + if (auto bt = get(progressed)) + progressed = bt->boundTo; + else + break; + } } -void ConstraintSolver::unblock(TypePackId progressed) +void ConstraintSolver::unblock(TypePackId progressed, Location) { if (logger) logger->popBlock(progressed); @@ -2104,70 +2852,57 @@ void ConstraintSolver::unblock(TypePackId progressed) return unblock_(progressed); } -void ConstraintSolver::unblock(const std::vector& types) +void ConstraintSolver::unblock(const std::vector& types, Location location) { for (TypeId t : types) - unblock(t); + unblock(t, location); } -void ConstraintSolver::unblock(const std::vector& packs) +void ConstraintSolver::unblock(const std::vector& packs, Location location) { for (TypePackId t : packs) - unblock(t); -} - -bool ConstraintSolver::isBlocked(TypeId ty) -{ - return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); -} - -bool ConstraintSolver::isBlocked(TypePackId tp) -{ - return nullptr != get(follow(tp)); -} - -bool ConstraintSolver::isBlocked(NotNull constraint) -{ - auto blockedIt = blockedConstraints.find(constraint); - return blockedIt != blockedConstraints.end() && blockedIt->second > 0; + unblock(t, location); } -void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) +void ConstraintSolver::reproduceConstraints(NotNull scope, const Location& location, const Substitution& subst) { - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; - - u.tryUnify(subType, superType); + for (auto [_, newTy] : subst.newTypes) + { + if (get(newTy)) + pushConstraint(scope, location, ReduceConstraint{newTy}); + } - if (!u.errors.empty()) + for (auto [_, newPack] : subst.newPacks) { - TypeId errorType = errorRecoveryType(); - u.tryUnify(subType, errorType); - u.tryUnify(superType, errorType); + if (get(newPack)) + pushConstraint(scope, location, ReducePackConstraint{newPack}); } +} - const auto [changedTypes, changedPacks] = u.log.getChanges(); +bool ConstraintSolver::isBlocked(TypeId ty) const +{ + ty = follow(ty); - u.log.commit(); + if (auto tfit = get(ty)) + return uninhabitedTypeFunctions.contains(ty) == false; - unblock(changedTypes); - unblock(changedPacks); + return nullptr != get(ty) || nullptr != get(ty); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) +bool ConstraintSolver::isBlocked(TypePackId tp) const { - UnifierSharedState sharedState{&iceReporter}; - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; + tp = follow(tp); - u.tryUnify(subPack, superPack); + if (auto tfitp = get(tp)) + return uninhabitedTypeFunctions.contains(tp) == false; - const auto [changedTypes, changedPacks] = u.log.getChanges(); - - u.log.commit(); + return nullptr != get(tp); +} - unblock(changedTypes); - unblock(changedPacks); +bool ConstraintSolver::isBlocked(NotNull constraint) const +{ + auto blockedIt = blockedConstraints.find(constraint); + return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } NotNull ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) @@ -2175,7 +2910,7 @@ NotNull ConstraintSolver::pushConstraint(NotNull scope, const std::unique_ptr c = std::make_unique(scope, location, std::move(cv)); NotNull borrow = NotNull(c.get()); solverConstraints.push_back(std::move(c)); - unsolvedConstraints.push_back(borrow); + unsolvedConstraints.emplace_back(borrow); return borrow; } @@ -2188,11 +2923,9 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l return errorRecoveryType(); } - std::string humanReadableName = moduleResolver->getHumanReadableModuleName(info.name); - for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == humanReadableName) + if (!path.empty() && path.front() == info.name) return builtinTypes->anyType; } @@ -2200,14 +2933,14 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (!module) { if (!moduleResolver->moduleExists(info.name) && !info.optional) - reportError(UnknownRequire{humanReadableName}, location); + reportError(UnknownRequire{moduleResolver->getHumanReadableModuleName(info.name)}, location); return errorRecoveryType(); } if (module->type != SourceCode::Type::Module) { - reportError(IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); + reportError(IllegalRequire{module->humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); return errorRecoveryType(); } @@ -2218,7 +2951,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l std::optional moduleType = first(modulePack); if (!moduleType) { - reportError(IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); + reportError(IllegalRequire{module->humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); return errorRecoveryType(); } @@ -2237,6 +2970,52 @@ void ConstraintSolver::reportError(TypeError e) errors.back().moduleName = currentModuleName; } +void ConstraintSolver::shiftReferences(TypeId source, TypeId target) +{ + target = follow(target); + + // if the target isn't a reference counted type, there's nothing to do. + // this stops us from keeping unnecessary counts for e.g. primitive types. + if (!isReferenceCountedType(target)) + return; + + auto sourceRefs = unresolvedConstraints.find(source); + if (!sourceRefs) + return; + + // we read out the count before proceeding to avoid hash invalidation issues. + size_t count = *sourceRefs; + + auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0); + targetRefs += count; +} + +std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables) +{ + TypeId t = follow(type); + if (get(t)) + { + auto refCount = unresolvedConstraints.find(t); + if (refCount && *refCount > 0) + return {}; + + // if no reference count is present, then that means the only constraints referring to + // this free type need only for it to be generalized. in principle, this means we could + // have actually never generated the free type in the first place, but we couldn't know + // that until all constraint generation is complete. + } + + return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type, avoidSealingTables); +} + +bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) +{ + if (auto refCount = unresolvedConstraints.find(ty)) + return *refCount > 0; + + return false; +} + TypeId ConstraintSolver::errorRecoveryType() const { return builtinTypes->errorRecoveryType(); @@ -2247,39 +3026,48 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const return builtinTypes->errorRecoveryTypePack(); } -TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) +TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) { - a = follow(a); - b = follow(b); + tp = follow(tp); - if (unifyFreeTypes && (get(a) || get(b))) + if (const VariadicTypePack* vtp = get(tp)) { - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; - u.tryUnify(b, a); + TypeId ty = follow(vtp->ty); + return get(ty) ? builtinTypes->anyTypePack : tp; + } - if (u.errors.empty()) - { - u.log.commit(); - return a; - } - else - { - return builtinTypes->errorRecoveryType(builtinTypes->anyType); - } + if (!get(follow(tp))) + return tp; + + std::vector resultTypes; + std::optional resultTail; + + TypePackIterator it = begin(tp); + + for (TypePackIterator e = end(tp); it != e; ++it) + { + TypeId ty = follow(*it); + resultTypes.push_back(get(ty) ? builtinTypes->anyType : ty); } - if (*a == *b) - return a; + if (std::optional tail = it.tail()) + resultTail = anyifyModuleReturnTypePackGenerics(*tail); - std::vector types = reduceUnion({a, b}); - if (types.empty()) - return builtinTypes->neverType; + return arena->addTypePack(resultTypes, resultTail); +} - if (types.size() == 1) - return types[0]; +LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() const +{ + throw TimeLimitError(currentModuleName); +} - return arena->addType(UnionType{types}); +LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() const +{ + throw UserCancelError(currentModuleName); } +// Instantiate private template implementations for external callers +template bool ConstraintSolver::unify(NotNull constraint, TypeId subTy, TypeId superTy); +template bool ConstraintSolver::unify(NotNull constraint, TypePackId subTy, TypePackId superTy); + } // namespace Luau diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index e73c7e8c9..4225942b9 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -1,305 +1,743 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" -#include "Luau/Breadcrumb.h" +#include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Def.h" +#include "Luau/Common.h" #include "Luau/Error.h" -#include "Luau/Refinement.h" +#include "Luau/TimeTrace.h" + +#include +#include LUAU_FASTFLAG(DebugLuauFreezeArena) -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauTypestateBuiltins) namespace Luau { -NullableBreadcrumbId DataFlowGraph::getBreadcrumb(const AstExpr* expr) const +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp + +struct ReferencedDefFinder : public AstVisitor { - // We need to skip through AstExprGroup because DFG doesn't try its best to transitively - while (auto group = expr->as()) - expr = group->expr; - if (auto bc = astBreadcrumbs.find(expr)) - return *bc; - return nullptr; + bool visit(AstExprLocal* local) override + { + referencedLocalDefs.push_back(local->local); + return true; + } + // ast defs is just a mapping from expr -> def in general + // will get built up by the dfg builder + + // localDefs, we need to copy over + std::vector referencedLocalDefs; +}; + +struct PushScope +{ + ScopeStack& stack; + + PushScope(ScopeStack& stack, DfgScope* scope) + : stack(stack) + { + // `scope` should never be `nullptr` here. + LUAU_ASSERT(scope); + stack.push_back(scope); + } + + ~PushScope() + { + stack.pop_back(); + } +}; + +const RefinementKey* RefinementKeyArena::leaf(DefId def) +{ + return allocator.allocate(RefinementKey{nullptr, def, std::nullopt}); } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstLocal* local) const +const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId def, const std::string& propName) { - auto bc = localBreadcrumbs.find(local); - LUAU_ASSERT(bc); - return NotNull{*bc}; + return allocator.allocate(RefinementKey{parent, def, propName}); } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprLocal* local) const +DefId DataFlowGraph::getDef(const AstExpr* expr) const { - auto bc = astBreadcrumbs.find(local); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = astDefs.find(expr); + LUAU_ASSERT(def); + return NotNull{*def}; } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprGlobal* global) const +std::optional DataFlowGraph::getDefOptional(const AstExpr* expr) const { - auto bc = astBreadcrumbs.find(global); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = astDefs.find(expr); + if (!def) + return std::nullopt; + return NotNull{*def}; } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareGlobal* global) const +std::optional DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const { - auto bc = declaredBreadcrumbs.find(global); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = compoundAssignDefs.find(expr); + return def ? std::optional(*def) : std::nullopt; } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareFunction* func) const +DefId DataFlowGraph::getDef(const AstLocal* local) const { - auto bc = declaredBreadcrumbs.find(func); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = localDefs.find(local); + LUAU_ASSERT(def); + return NotNull{*def}; } -NullableBreadcrumbId DfgScope::lookup(Symbol symbol) const +DefId DataFlowGraph::getDef(const AstStatDeclareGlobal* global) const +{ + auto def = declaredDefs.find(global); + LUAU_ASSERT(def); + return NotNull{*def}; +} + +DefId DataFlowGraph::getDef(const AstStatDeclareFunction* func) const +{ + auto def = declaredDefs.find(func); + LUAU_ASSERT(def); + return NotNull{*def}; +} + +const RefinementKey* DataFlowGraph::getRefinementKey(const AstExpr* expr) const +{ + if (auto key = astRefinementKeys.find(expr)) + return *key; + + return nullptr; +} + +std::optional DfgScope::lookup(Symbol symbol) const { for (const DfgScope* current = this; current; current = current->parent) { - if (auto breadcrumb = current->bindings.find(symbol)) - return *breadcrumb; + if (auto def = current->bindings.find(symbol)) + return NotNull{*def}; } - return nullptr; + return std::nullopt; } -NullableBreadcrumbId DfgScope::lookup(DefId def, const std::string& key) const +std::optional DfgScope::lookup(DefId def, const std::string& key) const { for (const DfgScope* current = this; current; current = current->parent) { - if (auto map = props.find(def)) + if (auto props = current->props.find(def)) { - if (auto it = map->find(key); it != map->end()) - return it->second; + if (auto it = props->find(key); it != props->end()) + return NotNull{it->second}; } } - return nullptr; + return std::nullopt; +} + +void DfgScope::inherit(const DfgScope* childScope) +{ + for (const auto& [k, a] : childScope->bindings) + { + if (lookup(k)) + bindings[k] = a; + } + + for (const auto& [k1, a1] : childScope->props) + { + for (const auto& [k2, a2] : a1) + props[k1][k2] = a2; + } +} + +bool DfgScope::canUpdateDefinition(Symbol symbol) const +{ + for (const DfgScope* current = this; current; current = current->parent) + { + if (current->bindings.find(symbol)) + return true; + else if (current->scopeType == DfgScope::Loop) + return false; + } + + return true; +} + +bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const +{ + for (const DfgScope* current = this; current; current = current->parent) + { + if (auto props = current->props.find(def)) + return true; + else if (current->scopeType == DfgScope::Loop) + return false; + } + + return true; } DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); + + LUAU_ASSERT(FFlag::LuauSolverV2); DataFlowGraphBuilder builder; builder.handle = handle; - builder.moduleScope = builder.childScope(nullptr); // nullptr is the root DFG scope. - builder.visitBlockWithoutChildScope(builder.moduleScope, block); + DfgScope* moduleScope = builder.makeChildScope(block->location); + PushScope ps{builder.scopeStack, moduleScope}; + builder.visitBlockWithoutChildScope(block); + builder.resolveCaptures(); if (FFlag::DebugLuauFreezeArena) { - builder.defs->allocator.freeze(); - builder.breadcrumbs->allocator.freeze(); + builder.defArena->allocator.freeze(); + builder.keyArena->allocator.freeze(); } return std::move(builder.graph); } -DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) +std::pair, std::vector>> DataFlowGraphBuilder::buildShared( + AstStatBlock* block, + NotNull handle +) +{ + + LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); + + LUAU_ASSERT(FFlag::LuauSolverV2); + + DataFlowGraphBuilder builder; + builder.handle = handle; + DfgScope* moduleScope = builder.makeChildScope(block->location); + PushScope ps{builder.scopeStack, moduleScope}; + builder.visitBlockWithoutChildScope(block); + builder.resolveCaptures(); + + if (FFlag::DebugLuauFreezeArena) + { + builder.defArena->allocator.freeze(); + builder.keyArena->allocator.freeze(); + } + + return {std::make_shared(std::move(builder.graph)), std::move(builder.scopes)}; +} + +DataFlowGraph DataFlowGraphBuilder::updateGraph( + const DataFlowGraph& staleGraph, + const std::vector>& scopes, + AstStatBlock* fragment, + const Position& cursorPos, + NotNull handle +) +{ + LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); + LUAU_ASSERT(FFlag::LuauSolverV2); + + DataFlowGraphBuilder builder; + builder.handle = handle; + // Generate a list of prepopulated locals + ReferencedDefFinder finder; + fragment->visit(&finder); + for (AstLocal* loc : finder.referencedLocalDefs) + { + if (staleGraph.localDefs.contains(loc)) + { + builder.graph.localDefs[loc] = *staleGraph.localDefs.find(loc); + } + } + + // Figure out which scope we should start re-accumulating DFG information from again + DfgScope* nearest = nullptr; + for (auto& sc : scopes) + { + if (nearest == nullptr || (sc->location.begin <= cursorPos && nearest->location.begin < sc->location.begin)) + nearest = sc.get(); + } + + // The scope stack should start with the nearest enclosing scope so we can resume DFG'ing correctly + PushScope ps{builder.scopeStack, nearest}; + // Conspire for the current scope in the scope stack to be a fresh dfg scope, parented to the above nearest enclosing scope, so any insertions are + // isolated there + DfgScope* scope = builder.makeChildScope(fragment->location); + PushScope psAgain{builder.scopeStack, scope}; + + builder.visitBlockWithoutChildScope(fragment); + + if (FFlag::DebugLuauFreezeArena) + { + builder.defArena->allocator.freeze(); + builder.keyArena->allocator.freeze(); + } + + return std::move(builder.graph); +} + +void DataFlowGraphBuilder::resolveCaptures() +{ + for (const auto& [_, capture] : captures) + { + std::vector operands; + for (size_t i = capture.versionOffset; i < capture.allVersions.size(); ++i) + collectOperands(capture.allVersions[i], &operands); + + for (DefId captureDef : capture.captureDefs) + { + Phi* phi = const_cast(get(captureDef)); + LUAU_ASSERT(phi); + LUAU_ASSERT(phi->operands.empty()); + phi->operands = operands; + } + } +} + +DfgScope* DataFlowGraphBuilder::currentScope() +{ + if (scopeStack.empty()) + return nullptr; // nullptr is the root DFG scope. + return scopeStack.back(); +} + +DfgScope* DataFlowGraphBuilder::makeChildScope(Location loc, DfgScope::ScopeType scopeType) +{ + return scopes.emplace_back(new DfgScope{currentScope(), scopeType, loc}).get(); +} + +void DataFlowGraphBuilder::join(DfgScope* p, DfgScope* a, DfgScope* b) +{ + joinBindings(p, *a, *b); + joinProps(p, *a, *b); +} + +void DataFlowGraphBuilder::joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b) +{ + for (const auto& [sym, def1] : a.bindings) + { + if (auto def2 = b.bindings.find(sym)) + p->bindings[sym] = defArena->phi(NotNull{def1}, NotNull{*def2}); + else if (auto def2 = p->lookup(sym)) + p->bindings[sym] = defArena->phi(NotNull{def1}, NotNull{*def2}); + } + + for (const auto& [sym, def1] : b.bindings) + { + if (auto def2 = p->lookup(sym)) + p->bindings[sym] = defArena->phi(NotNull{def1}, NotNull{*def2}); + } +} + +void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const DfgScope& b) +{ + auto phinodify = [this](DfgScope* scope, const auto& a, const auto& b, DefId parent) mutable + { + auto& p = scope->props[parent]; + for (const auto& [k, defA] : a) + { + if (auto it = b.find(k); it != b.end()) + p[k] = defArena->phi(NotNull{it->second}, NotNull{defA}); + else if (auto it = p.find(k); it != p.end()) + p[k] = defArena->phi(NotNull{it->second}, NotNull{defA}); + else if (auto def2 = scope->lookup(parent, k)) + p[k] = defArena->phi(*def2, NotNull{defA}); + else + p[k] = defA; + } + + for (const auto& [k, defB] : b) + { + if (auto it = a.find(k); it != a.end()) + continue; + else if (auto it = p.find(k); it != p.end()) + p[k] = defArena->phi(NotNull{it->second}, NotNull{defB}); + else if (auto def2 = scope->lookup(parent, k)) + p[k] = defArena->phi(*def2, NotNull{defB}); + else + p[k] = defB; + } + }; + + for (const auto& [def, a1] : a.props) + { + result->props.try_insert(def, {}); + if (auto a2 = b.props.find(def)) + phinodify(result, a1, *a2, NotNull{def}); + else if (auto a2 = result->props.find(def)) + phinodify(result, a1, *a2, NotNull{def}); + } + + for (const auto& [def, a1] : b.props) + { + result->props.try_insert(def, {}); + if (a.props.find(def)) + continue; + else if (auto a2 = result->props.find(def)) + phinodify(result, a1, *a2, NotNull{def}); + } +} + +DefId DataFlowGraphBuilder::lookup(Symbol symbol) +{ + DfgScope* scope = currentScope(); + + // true if any of the considered scopes are a loop. + bool outsideLoopScope = false; + for (DfgScope* current = scope; current; current = current->parent) + { + outsideLoopScope = outsideLoopScope || current->scopeType == DfgScope::Loop; + + if (auto found = current->bindings.find(symbol)) + return NotNull{*found}; + else if (current->scopeType == DfgScope::Function) + { + FunctionCapture& capture = captures[symbol]; + DefId captureDef = defArena->phi({}); + capture.captureDefs.push_back(captureDef); + + // If we are outside of a loop scope, then we don't want to actually bind + // uses of `symbol` to this new phi node since it will not get populated. + if (!outsideLoopScope) + scope->bindings[symbol] = captureDef; + + return NotNull{captureDef}; + } + } + + DefId result = defArena->freshCell(); + scope->bindings[symbol] = result; + captures[symbol].allVersions.push_back(result); + return result; +} + +DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key) { - return scopes.emplace_back(new DfgScope{scope}).get(); + DfgScope* scope = currentScope(); + for (DfgScope* current = scope; current; current = current->parent) + { + if (auto props = current->props.find(def)) + { + if (auto it = props->find(key); it != props->end()) + return NotNull{it->second}; + } + else if (auto phi = get(def); phi && phi->operands.empty()) // Unresolved phi nodes + { + DefId result = defArena->freshCell(); + scope->props[def][key] = result; + return result; + } + } + + if (auto phi = get(def)) + { + std::vector defs; + for (DefId operand : phi->operands) + defs.push_back(lookup(operand, key)); + + DefId result = defArena->phi(defs); + scope->props[def][key] = result; + return result; + } + else if (get(def)) + { + DefId result = defArena->freshCell(); + scope->props[def][key] = result; + return result; + } + else + handle->ice("Inexhaustive lookup cases in DataFlowGraphBuilder::lookup"); } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visit(AstStatBlock* b) { - DfgScope* child = childScope(scope); - return visitBlockWithoutChildScope(child, b); + DfgScope* child = makeChildScope(b->location); + + ControlFlow cf; + { + PushScope ps{scopeStack, child}; + cf = visitBlockWithoutChildScope(b); + } + + currentScope()->inherit(child); + return cf; } -void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(AstStatBlock* b) { - for (AstStat* s : b->body) - visit(scope, s); + std::optional firstControlFlow; + for (AstStat* stat : b->body) + { + ControlFlow cf = visit(stat); + if (cf != ControlFlow::None && !firstControlFlow) + firstControlFlow = cf; + } + + return firstControlFlow.value_or(ControlFlow::None); } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +ControlFlow DataFlowGraphBuilder::visit(AstStat* s) { if (auto b = s->as()) - return visit(scope, b); + return visit(b); else if (auto i = s->as()) - return visit(scope, i); + return visit(i); else if (auto w = s->as()) - return visit(scope, w); + return visit(w); else if (auto r = s->as()) - return visit(scope, r); + return visit(r); else if (auto b = s->as()) - return visit(scope, b); + return visit(b); else if (auto c = s->as()) - return visit(scope, c); + return visit(c); else if (auto r = s->as()) - return visit(scope, r); + return visit(r); else if (auto e = s->as()) - return visit(scope, e); + return visit(e); else if (auto l = s->as()) - return visit(scope, l); + return visit(l); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto a = s->as()) - return visit(scope, a); + return visit(a); else if (auto c = s->as()) - return visit(scope, c); + return visit(c); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto l = s->as()) - return visit(scope, l); + return visit(l); else if (auto t = s->as()) - return visit(scope, t); + return visit(t); + else if (auto f = s->as()) + return visit(f); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto error = s->as()) - return visit(scope, error); + return visit(error); else handle->ice("Unknown AstStat in DataFlowGraphBuilder::visit"); } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +ControlFlow DataFlowGraphBuilder::visit(AstStatIf* i) { - // TODO: type states and control flow analysis - visitExpr(scope, i->condition); - visit(scope, i->thenbody); + visitExpr(i->condition); + + DfgScope* thenScope = makeChildScope(i->thenbody->location); + DfgScope* elseScope = makeChildScope(i->elsebody ? i->elsebody->location : i->location); + + ControlFlow thencf; + { + PushScope ps{scopeStack, thenScope}; + thencf = visit(i->thenbody); + } + + ControlFlow elsecf = ControlFlow::None; if (i->elsebody) - visit(scope, i->elsebody); + { + PushScope ps{scopeStack, elseScope}; + elsecf = visit(i->elsebody); + } + + DfgScope* scope = currentScope(); + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) + join(scope, scope, elseScope); + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) + join(scope, thenScope, scope); + else if ((thencf | elsecf) == ControlFlow::None) + join(scope, thenScope, elseScope); + + if (thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +ControlFlow DataFlowGraphBuilder::visit(AstStatWhile* w) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* whileScope = childScope(scope); - visitExpr(whileScope, w->condition); - visit(whileScope, w->body); + DfgScope* whileScope = makeChildScope(w->location, DfgScope::Loop); + + { + PushScope ps{scopeStack, whileScope}; + visitExpr(w->condition); + visit(w->body); + } + + currentScope()->inherit(whileScope); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +ControlFlow DataFlowGraphBuilder::visit(AstStatRepeat* r) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* repeatScope = childScope(scope); // TODO: loop scope. - visitBlockWithoutChildScope(repeatScope, r->body); - visitExpr(repeatScope, r->condition); + DfgScope* repeatScope = makeChildScope(r->location, DfgScope::Loop); + + { + PushScope ps{scopeStack, repeatScope}; + visitBlockWithoutChildScope(r->body); + visitExpr(r->condition); + } + + currentScope()->inherit(repeatScope); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +ControlFlow DataFlowGraphBuilder::visit(AstStatBreak* b) { - // TODO: Control flow analysis - return; // ok + return ControlFlow::Breaks; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +ControlFlow DataFlowGraphBuilder::visit(AstStatContinue* c) { - // TODO: Control flow analysis - return; // ok + return ControlFlow::Continues; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +ControlFlow DataFlowGraphBuilder::visit(AstStatReturn* r) { - // TODO: Control flow analysis for (AstExpr* e : r->list) - visitExpr(scope, e); + visitExpr(e); + + return ControlFlow::Returns; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +ControlFlow DataFlowGraphBuilder::visit(AstStatExpr* e) { - visitExpr(scope, e->expr); + visitExpr(e->expr); + if (auto call = e->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + else + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +ControlFlow DataFlowGraphBuilder::visit(AstStatLocal* l) { // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) - std::vector bcs; - bcs.reserve(l->values.size); + std::vector defs; + defs.reserve(l->values.size); for (AstExpr* e : l->values) - bcs.push_back(visitExpr(scope, e)); + defs.push_back(visitExpr(e).def); for (size_t i = 0; i < l->vars.size; ++i) { AstLocal* local = l->vars.data[i]; if (local->annotation) - visitType(scope, local->annotation); + visitType(local->annotation); - // We need to create a new breadcrumb with new defs to intentionally avoid alias tracking. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell(), i < bcs.size() ? bcs[i]->metadata : std::nullopt); - graph.localBreadcrumbs[local] = bc; - scope->bindings[local] = bc; + // We need to create a new def to intentionally avoid alias tracking, but we'd like to + // make sure that the non-aliased defs are also marked as a subscript for refinements. + bool subscripted = i < defs.size() && containsSubscriptedDefinition(defs[i]); + DefId def = defArena->freshCell(subscripted); + if (i < l->values.size) + { + AstExpr* e = l->values.data[i]; + if (const AstExprTable* tbl = e->as()) + { + def = defs[i]; + } + } + graph.localDefs[local] = def; + currentScope()->bindings[local] = def; + captures[local].allVersions.push_back(def); } + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatFor* f) { - DfgScope* forScope = childScope(scope); // TODO: loop scope. + DfgScope* forScope = makeChildScope(f->location, DfgScope::Loop); - visitExpr(scope, f->from); - visitExpr(scope, f->to); + visitExpr(f->from); + visitExpr(f->to); if (f->step) - visitExpr(scope, f->step); + visitExpr(f->step); + + { + PushScope ps{scopeStack, forScope}; - if (f->var->annotation) - visitType(forScope, f->var->annotation); + if (f->var->annotation) + visitType(f->var->annotation); - // TODO: RangeMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[f->var] = bc; - scope->bindings[f->var] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[f->var] = def; + currentScope()->bindings[f->var] = def; + captures[f->var].allVersions.push_back(def); - // TODO(controlflow): entry point has a back edge from exit point - visit(forScope, f->body); + // TODO(controlflow): entry point has a back edge from exit point + visit(f->body); + } + + currentScope()->inherit(forScope); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatForIn* f) { - DfgScope* forScope = childScope(scope); // TODO: loop scope. + DfgScope* forScope = makeChildScope(f->location, DfgScope::Loop); - for (AstLocal* local : f->vars) { - if (local->annotation) - visitType(forScope, local->annotation); + PushScope ps{scopeStack, forScope}; - // TODO: IterMetadata (different from RangeMetadata) - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[local] = bc; - forScope->bindings[local] = bc; + for (AstLocal* local : f->vars) + { + if (local->annotation) + visitType(local->annotation); + + DefId def = defArena->freshCell(); + graph.localDefs[local] = def; + currentScope()->bindings[local] = def; + captures[local].allVersions.push_back(def); + } + + // TODO(controlflow): entry point has a back edge from exit point + // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) + for (AstExpr* e : f->values) + visitExpr(e); + + visit(f->body); } - // TODO(controlflow): entry point has a back edge from exit point - // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) - for (AstExpr* e : f->values) - visitExpr(forScope, e); + currentScope()->inherit(forScope); - visit(forScope, f->body); + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +ControlFlow DataFlowGraphBuilder::visit(AstStatAssign* a) { - for (AstExpr* r : a->values) - visitExpr(scope, r); + std::vector defs; + defs.reserve(a->values.size); + for (AstExpr* e : a->values) + defs.push_back(visitExpr(e).def); - for (AstExpr* l : a->vars) - visitLValue(scope, l); + for (size_t i = 0; i < a->vars.size; ++i) + { + AstExpr* v = a->vars.data[i]; + visitLValue(v, i < defs.size() ? defs[i] : defArena->freshCell()); + } + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +ControlFlow DataFlowGraphBuilder::visit(AstStatCompoundAssign* c) { - // TODO: This needs revisiting because this is incorrect. The `c->var` part is both being read and written to, - // but the `c->var` only has one pointer address, so we need to come up with a way to store both. - // For now, it's not important because we don't have type states, but it is going to be important, e.g. - // - // local a = 5 -- a[1] - // a += 5 -- a[2] = a[1] + 5 - // - // We can't just visit `c->var` as a rvalue and then separately traverse `c->var` as an lvalue, since that's O(n^2). - visitLValue(scope, c->var); - visitExpr(scope, c->value); + (void)visitExpr(c->value); + (void)visitExpr(c->var); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatFunction* f) { // In the old solver, we assumed that the name of the function is always a function in the body // but this isn't true, e.g. the following example will print `5`, not a function address. @@ -311,210 +749,294 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) // // which is evidence that references to variables must be a phi node of all possible definitions, // but for bug compatibility, we'll assume the same thing here. - visitLValue(scope, f->name); - visitExpr(scope, f->func); + visitLValue(f->name, defArena->freshCell()); + visitExpr(f->func); + + if (auto local = f->name->as()) + { + // local f + // function f() + // if cond() then + // f() -- should reference only the function version and other future version, and nothing prior + // end + // end + FunctionCapture& capture = captures[local->local]; + capture.versionOffset = capture.allVersions.size() - 1; + } + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +ControlFlow DataFlowGraphBuilder::visit(AstStatLocalFunction* l) { - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[l->name] = bc; - scope->bindings[l->name] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[l->name] = def; + currentScope()->bindings[l->name] = def; + captures[l->name].allVersions.push_back(def); + visitExpr(l->func); - visitExpr(scope, l->func); + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) +ControlFlow DataFlowGraphBuilder::visit(AstStatTypeAlias* t) { - DfgScope* unreachable = childScope(scope); - visitGenerics(unreachable, t->generics); - visitGenericPacks(unreachable, t->genericPacks); - visitType(unreachable, t->type); + DfgScope* unreachable = makeChildScope(t->location); + PushScope ps{scopeStack, unreachable}; + + visitGenerics(t->generics); + visitGenericPacks(t->genericPacks); + visitType(t->type); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatTypeFunction* f) { - // TODO: AmbientDeclarationMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.declaredBreadcrumbs[d] = bc; - scope->bindings[d->name] = bc; + DfgScope* unreachable = makeChildScope(f->location); + PushScope ps{scopeStack, unreachable}; - visitType(scope, d->type); + visitExpr(f->body); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareGlobal* d) { - // TODO: AmbientDeclarationMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.declaredBreadcrumbs[d] = bc; - scope->bindings[d->name] = bc; + DefId def = defArena->freshCell(); + graph.declaredDefs[d] = def; + currentScope()->bindings[d->name] = def; + captures[d->name].allVersions.push_back(def); + + visitType(d->type); - DfgScope* unreachable = childScope(scope); - visitGenerics(unreachable, d->generics); - visitGenericPacks(unreachable, d->genericPacks); - visitTypeList(unreachable, d->params); - visitTypeList(unreachable, d->retTypes); + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareFunction* d) +{ + DefId def = defArena->freshCell(); + graph.declaredDefs[d] = def; + currentScope()->bindings[d->name] = def; + captures[d->name].allVersions.push_back(def); + + DfgScope* unreachable = makeChildScope(d->location); + PushScope ps{scopeStack, unreachable}; + + visitGenerics(d->generics); + visitGenericPacks(d->genericPacks); + visitTypeList(d->params); + visitTypeList(d->retTypes); + + return ControlFlow::None; +} + +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareClass* d) { // This declaration does not "introduce" any bindings in value namespace, // so there's no symbolic value to begin with. We'll traverse the properties // because their type annotations may depend on something in the value namespace. - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(d->location); + PushScope ps{scopeStack, unreachable}; + for (AstDeclaredClassProp prop : d->props) - visitType(unreachable, prop.ty); + visitType(prop.ty); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error) +ControlFlow DataFlowGraphBuilder::visit(AstStatError* error) { - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(error->location); + PushScope ps{scopeStack, unreachable}; + for (AstStat* s : error->statements) - visit(unreachable, s); + visit(s); for (AstExpr* e : error->expressions) - visitExpr(unreachable, e); -} - -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) -{ - if (auto g = e->as()) - return visitExpr(scope, g->expr); - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto l = e->as()) - return visitExpr(scope, l); - else if (auto g = e->as()) - return visitExpr(scope, g); - else if (auto v = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return visitExpr(scope, c); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto f = e->as()) - return visitExpr(scope, f); - else if (auto t = e->as()) - return visitExpr(scope, t); - else if (auto u = e->as()) - return visitExpr(scope, u); - else if (auto b = e->as()) - return visitExpr(scope, b); - else if (auto t = e->as()) - return visitExpr(scope, t); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto error = e->as()) - return visitExpr(scope, error); - else - handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); + visitExpr(e); + + return ControlFlow::None; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExpr* e) { - NullableBreadcrumbId breadcrumb = scope->lookup(l->local); - if (!breadcrumb) - handle->ice("AstExprLocal came before its declaration?"); + // Some subexpressions could be visited two times. If we've already seen it, just extract it. + if (auto def = graph.astDefs.find(e)) + { + auto key = graph.astRefinementKeys.find(e); + return {NotNull{*def}, key ? *key : nullptr}; + } + + auto go = [&]() -> DataFlowResult + { + if (auto g = e->as()) + return visitExpr(g); + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto l = e->as()) + return visitExpr(l); + else if (auto g = e->as()) + return visitExpr(g); + else if (auto v = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return visitExpr(c); + else if (auto i = e->as()) + return visitExpr(i); + else if (auto i = e->as()) + return visitExpr(i); + else if (auto f = e->as()) + return visitExpr(f); + else if (auto t = e->as()) + return visitExpr(t); + else if (auto u = e->as()) + return visitExpr(u); + else if (auto b = e->as()) + return visitExpr(b); + else if (auto t = e->as()) + return visitExpr(t); + else if (auto i = e->as()) + return visitExpr(i); + else if (auto i = e->as()) + return visitExpr(i); + else if (auto error = e->as()) + return visitExpr(error); + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); + }; - graph.astBreadcrumbs[l] = breadcrumb; - return NotNull{breadcrumb}; + auto [def, key] = go(); + graph.astDefs[e] = def; + if (key) + graph.astRefinementKeys[e] = key; + return {def, key}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGroup* group) { - NullableBreadcrumbId bc = scope->lookup(g->name); - if (!bc) - { - bc = breadcrumbs->add(nullptr, defs->freshCell()); - moduleScope->bindings[g->name] = bc; - } + return visitExpr(group->expr); +} - graph.astBreadcrumbs[g] = bc; - return NotNull{bc}; +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprLocal* l) +{ + DefId def = lookup(l->local); + const RefinementKey* key = keyArena->leaf(def); + return {def, key}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGlobal* g) { - visitExpr(scope, c->func); + DefId def = lookup(g->name); + return {def, keyArena->leaf(def)}; +} + +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c) +{ + visitExpr(c->func); + + if (FFlag::LuauTypestateBuiltins && shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) + { + AstExpr* firstArg = *c->args.begin(); + + // this logic has to handle the name-like subset of expressions. + std::optional result; + if (auto l = firstArg->as()) + result = visitExpr(l); + else if (auto g = firstArg->as()) + result = visitExpr(g); + else if (auto i = firstArg->as()) + result = visitExpr(i); + else if (auto i = firstArg->as()) + result = visitExpr(i); + else + LUAU_UNREACHABLE(); // This is unreachable because the whole thing is guarded by `isLValue`. + + LUAU_ASSERT(result); + + Location location = currentScope()->location; + // This scope starts at the end of the call site and continues to the end of the original scope. + location.begin = c->location.end; + DfgScope* child = makeChildScope(location); + scopeStack.push_back(child); + + auto [def, key] = *result; + graph.astDefs[firstArg] = def; + if (key) + graph.astRefinementKeys[firstArg] = key; + + visitLValue(firstArg, def); + } for (AstExpr* arg : c->args) - visitExpr(scope, arg); + visitExpr(arg); - return breadcrumbs->add(nullptr, defs->freshCell()); + // calls should be treated as subscripted. + return {defArena->freshCell(/* subscripted */ true), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i) { - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + auto [parentDef, parentKey] = visitExpr(i->expr); - std::string key = i->index.value; - NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; - if (!propBreadcrumb) - propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + std::string index = i->index.value; - graph.astBreadcrumbs[i] = propBreadcrumb; - return NotNull{propBreadcrumb}; + DefId def = lookup(parentDef, index); + return {def, keyArena->node(parentKey, def, index)}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexExpr* i) { - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); - BreadcrumbId key = visitExpr(scope, i->index); + auto [parentDef, parentKey] = visitExpr(i->expr); + visitExpr(i->index); if (auto string = i->index->as()) { - std::string key{string->value.data, string->value.size}; - NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; - if (!propBreadcrumb) - propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + std::string index{string->value.data, string->value.size}; - graph.astBreadcrumbs[i] = NotNull{propBreadcrumb}; - return NotNull{propBreadcrumb}; + DefId def = lookup(parentDef, index); + return {def, keyArena->node(parentKey, def, index)}; } - return breadcrumbs->emplace(nullptr, defs->freshCell(), key); + return {defArena->freshCell(/* subscripted= */ true), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprFunction* f) { - DfgScope* signatureScope = childScope(scope); + DfgScope* signatureScope = makeChildScope(f->location, DfgScope::Function); + PushScope ps{scopeStack, signatureScope}; if (AstLocal* self = f->self) { // There's no syntax for `self` to have an annotation if using `function t:m()` LUAU_ASSERT(!self->annotation); - // TODO: ParameterMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[self] = bc; - signatureScope->bindings[self] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[self] = def; + signatureScope->bindings[self] = def; + captures[self].allVersions.push_back(def); } for (AstLocal* param : f->args) { if (param->annotation) - visitType(signatureScope, param->annotation); + visitType(param->annotation); - // TODO: ParameterMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[param] = bc; - signatureScope->bindings[param] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[param] = def; + signatureScope->bindings[param] = def; + captures[param].allVersions.push_back(def); } if (f->varargAnnotation) - visitTypePack(scope, f->varargAnnotation); + visitTypePack(f->varargAnnotation); if (f->returnAnnotation) - visitTypeList(signatureScope, *f->returnAnnotation); + visitTypeList(*f->returnAnnotation); // TODO: function body can be re-entrant, as in mutations that occurs at the end of the function can also be // visible to the beginning of the function, so statically speaking, the body of the function has an exit point @@ -524,164 +1046,190 @@ BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f // local g = f // g() --> function: address // g() --> 5 - visit(signatureScope, f->body); + visit(f->body); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTable* t) { + DefId tableCell = defArena->freshCell(); + currentScope()->props[tableCell] = {}; for (AstExprTable::Item item : t->items) { + DataFlowResult result = visitExpr(item.value); if (item.key) - visitExpr(scope, item.key); - visitExpr(scope, item.value); + { + visitExpr(item.key); + if (auto string = item.key->as()) + currentScope()->props[tableCell][string->value.data] = result.def; + } } - return breadcrumbs->add(nullptr, defs->freshCell()); + return {tableCell, nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprUnary* u) { - visitExpr(scope, u->expr); + visitExpr(u->expr); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprBinary* b) { - visitExpr(scope, b->left); - visitExpr(scope, b->right); + visitExpr(b->left); + visitExpr(b->right); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTypeAssertion* t) { - // TODO: TypeAssertionMetadata? - BreadcrumbId bc = visitExpr(scope, t->expr); - visitType(scope, t->annotation); + auto [def, key] = visitExpr(t->expr); + visitType(t->annotation); - return bc; + return {def, key}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIfElse* i) { - visitExpr(scope, i->condition); - visitExpr(scope, i->trueExpr); - visitExpr(scope, i->falseExpr); + visitExpr(i->condition); + visitExpr(i->trueExpr); + visitExpr(i->falseExpr); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprInterpString* i) { for (AstExpr* e : i->expressions) - visitExpr(scope, e); + visitExpr(e); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprError* error) { - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(error->location); + PushScope ps{scopeStack, unreachable}; + for (AstExpr* e : error->expressions) - visitExpr(unreachable, e); + visitExpr(e); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e) +void DataFlowGraphBuilder::visitLValue(AstExpr* e, DefId incomingDef) { - if (auto l = e->as()) - return visitLValue(scope, l); - else if (auto g = e->as()) - return visitLValue(scope, g); - else if (auto i = e->as()) - return visitLValue(scope, i); - else if (auto i = e->as()) - return visitLValue(scope, i); - else if (auto error = e->as()) + auto go = [&]() { - visitExpr(scope, error); // TODO: is this right? - return; - } - else - handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); + if (auto l = e->as()) + return visitLValue(l, incomingDef); + else if (auto g = e->as()) + return visitLValue(g, incomingDef); + else if (auto i = e->as()) + return visitLValue(i, incomingDef); + else if (auto i = e->as()) + return visitLValue(i, incomingDef); + else if (auto error = e->as()) + return visitLValue(error, incomingDef); + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); + }; + + graph.astDefs[e] = go(); } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l) +DefId DataFlowGraphBuilder::visitLValue(AstExprLocal* l, DefId incomingDef) { - // Bug compatibility: we don't support type states yet, so we need to do this. - NullableBreadcrumbId bc = scope->lookup(l->local); - LUAU_ASSERT(bc); + DfgScope* scope = currentScope(); - graph.astBreadcrumbs[l] = bc; - scope->bindings[l->local] = bc; + // In order to avoid alias tracking, we need to clip the reference to the parent def. + if (scope->canUpdateDefinition(l->local)) + { + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->bindings[l->local] = updated; + captures[l->local].allVersions.push_back(updated); + return updated; + } + else + return visitExpr(static_cast(l)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g) +DefId DataFlowGraphBuilder::visitLValue(AstExprGlobal* g, DefId incomingDef) { - // Bug compatibility: we don't support type states yet, so we need to do this. - NullableBreadcrumbId bc = scope->lookup(g->name); - if (!bc) - bc = breadcrumbs->add(nullptr, defs->freshCell()); + DfgScope* scope = currentScope(); - graph.astBreadcrumbs[g] = bc; - scope->bindings[g->name] = bc; + // In order to avoid alias tracking, we need to clip the reference to the parent def. + if (scope->canUpdateDefinition(g->name)) + { + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->bindings[g->name] = updated; + captures[g->name].allVersions.push_back(updated); + return updated; + } + else + return visitExpr(static_cast(g)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i) +DefId DataFlowGraphBuilder::visitLValue(AstExprIndexName* i, DefId incomingDef) { - // Bug compatibility: we don't support type states yet, so we need to do this. - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + DefId parentDef = visitExpr(i->expr).def; - std::string key = i->index.value; - NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); - if (!propBreadcrumb) + DfgScope* scope = currentScope(); + if (scope->canUpdateDefinition(parentDef, i->index.value)) { - propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); - moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->props[parentDef][i->index.value] = updated; + return updated; } - - graph.astBreadcrumbs[i] = propBreadcrumb; + else + return visitExpr(static_cast(i)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i) +DefId DataFlowGraphBuilder::visitLValue(AstExprIndexExpr* i, DefId incomingDef) { - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); - visitExpr(scope, i->index); + DefId parentDef = visitExpr(i->expr).def; + visitExpr(i->index); + DfgScope* scope = currentScope(); if (auto string = i->index->as()) { - std::string key{string->value.data, string->value.size}; - NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); - if (!propBreadcrumb) + if (scope->canUpdateDefinition(parentDef, string->value.data)) { - propBreadcrumb = breadcrumbs->add(parentBreadcrumb, parentBreadcrumb->def); - moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->props[parentDef][string->value.data] = updated; + return updated; } - - graph.astBreadcrumbs[i] = propBreadcrumb; + else + return visitExpr(static_cast(i)).def; } + else + return defArena->freshCell(/*subscripted=*/true); +} + +DefId DataFlowGraphBuilder::visitLValue(AstExprError* error, DefId incomingDef) +{ + return visitExpr(error).def; } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) +void DataFlowGraphBuilder::visitType(AstType* t) { if (auto r = t->as()) - return visitType(scope, r); + return visitType(r); else if (auto table = t->as()) - return visitType(scope, table); + return visitType(table); else if (auto f = t->as()) - return visitType(scope, f); + return visitType(f); else if (auto tyof = t->as()) - return visitType(scope, tyof); + return visitType(tyof); else if (auto u = t->as()) - return visitType(scope, u); + return visitType(u); else if (auto i = t->as()) - return visitType(scope, i); + return visitType(i); else if (auto e = t->as()) - return visitType(scope, e); + return visitType(e); else if (auto s = t->as()) return; // ok else if (auto s = t->as()) @@ -690,106 +1238,106 @@ void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeReference* r) +void DataFlowGraphBuilder::visitType(AstTypeReference* r) { for (AstTypeOrPack param : r->parameters) { if (param.type) - visitType(scope, param.type); + visitType(param.type); else - visitTypePack(scope, param.typePack); + visitTypePack(param.typePack); } } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTable* t) +void DataFlowGraphBuilder::visitType(AstTypeTable* t) { for (AstTableProp p : t->props) - visitType(scope, p.type); + visitType(p.type); if (t->indexer) { - visitType(scope, t->indexer->indexType); - visitType(scope, t->indexer->resultType); + visitType(t->indexer->indexType); + visitType(t->indexer->resultType); } } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeFunction* f) +void DataFlowGraphBuilder::visitType(AstTypeFunction* f) { - visitGenerics(scope, f->generics); - visitGenericPacks(scope, f->genericPacks); - visitTypeList(scope, f->argTypes); - visitTypeList(scope, f->returnTypes); + visitGenerics(f->generics); + visitGenericPacks(f->genericPacks); + visitTypeList(f->argTypes); + visitTypeList(f->returnTypes); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTypeof* t) +void DataFlowGraphBuilder::visitType(AstTypeTypeof* t) { - visitExpr(scope, t->expr); + visitExpr(t->expr); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeUnion* u) +void DataFlowGraphBuilder::visitType(AstTypeUnion* u) { for (AstType* t : u->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeIntersection* i) +void DataFlowGraphBuilder::visitType(AstTypeIntersection* i) { for (AstType* t : i->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeError* error) +void DataFlowGraphBuilder::visitType(AstTypeError* error) { for (AstType* t : error->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePack* p) +void DataFlowGraphBuilder::visitTypePack(AstTypePack* p) { if (auto e = p->as()) - return visitTypePack(scope, e); + return visitTypePack(e); else if (auto v = p->as()) - return visitTypePack(scope, v); + return visitTypePack(v); else if (auto g = p->as()) return; // ok else handle->ice("Unknown AstTypePack in DataFlowGraphBuilder::visitTypePack"); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackExplicit* e) +void DataFlowGraphBuilder::visitTypePack(AstTypePackExplicit* e) { - visitTypeList(scope, e->typeList); + visitTypeList(e->typeList); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackVariadic* v) +void DataFlowGraphBuilder::visitTypePack(AstTypePackVariadic* v) { - visitType(scope, v->variadicType); + visitType(v->variadicType); } -void DataFlowGraphBuilder::visitTypeList(DfgScope* scope, AstTypeList l) +void DataFlowGraphBuilder::visitTypeList(AstTypeList l) { for (AstType* t : l.types) - visitType(scope, t); + visitType(t); if (l.tailType) - visitTypePack(scope, l.tailType); + visitTypePack(l.tailType); } -void DataFlowGraphBuilder::visitGenerics(DfgScope* scope, AstArray g) +void DataFlowGraphBuilder::visitGenerics(AstArray g) { for (AstGenericType generic : g) { if (generic.defaultValue) - visitType(scope, generic.defaultValue); + visitType(generic.defaultValue); } } -void DataFlowGraphBuilder::visitGenericPacks(DfgScope* scope, AstArray g) +void DataFlowGraphBuilder::visitGenericPacks(AstArray g) { for (AstGenericTypePack generic : g) { if (generic.defaultValue) - visitTypePack(scope, generic.defaultValue); + visitTypePack(generic.defaultValue); } } diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp index 9f66b022a..f013b9852 100644 --- a/Analysis/src/DcrLogger.cpp +++ b/Analysis/src/DcrLogger.cpp @@ -124,7 +124,8 @@ void write(JsonEmitter& emitter, const ConstraintBlock& block) ObjectEmitter o = emitter.writeObject(); o.writePair("stringification", block.stringification); - auto go = [&o](auto&& t) { + auto go = [&o](auto&& t) + { using T = std::decay_t; o.writePair("id", toPointerId(t)); @@ -350,8 +351,12 @@ void DcrLogger::popBlock(NotNull block) } } -static void snapshotTypeStrings(const std::vector& interestedExprs, - const std::vector& interestedAnnots, DenseHashMap& map, ToStringOptions& opts) +static void snapshotTypeStrings( + const std::vector& interestedExprs, + const std::vector& interestedAnnots, + DenseHashMap& map, + ToStringOptions& opts +) { for (const ExprTypesAtLocation& tys : interestedExprs) { @@ -368,7 +373,10 @@ static void snapshotTypeStrings(const std::vector& interest } void DcrLogger::captureBoundaryState( - BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints) + BoundarySnapshot& target, + const Scope* rootScope, + const std::vector>& unsolvedConstraints +) { target.rootScope = snapshotScope(rootScope, opts); target.unsolvedConstraints.clear(); @@ -391,7 +399,11 @@ void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vec } StepSnapshot DcrLogger::prepareStepSnapshot( - const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) + const Scope* rootScope, + NotNull current, + bool force, + const std::vector>& unsolvedConstraints +) { ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); DenseHashMap constraints{nullptr}; diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp index 7be075c25..6d58b28fe 100644 --- a/Analysis/src/Def.cpp +++ b/Analysis/src/Def.cpp @@ -1,12 +1,62 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Def.h" +#include "Luau/Common.h" + +#include + namespace Luau { -DefId DefArena::freshCell() +bool containsSubscriptedDefinition(DefId def) +{ + if (auto cell = get(def)) + return cell->subscripted; + else if (auto phi = get(def)) + return std::any_of(phi->operands.begin(), phi->operands.end(), containsSubscriptedDefinition); + else + return false; +} + +void collectOperands(DefId def, std::vector* operands) +{ + LUAU_ASSERT(operands); + if (std::find(operands->begin(), operands->end(), def) != operands->end()) + return; + else if (get(def)) + operands->push_back(def); + else if (auto phi = get(def)) + { + // A trivial phi node has no operands to populate, so we push this definition in directly. + if (phi->operands.empty()) + return operands->push_back(def); + + for (const Def* operand : phi->operands) + collectOperands(NotNull{operand}, operands); + } +} + +DefId DefArena::freshCell(bool subscripted) +{ + return NotNull{allocator.allocate(Def{Cell{subscripted}})}; +} + +DefId DefArena::phi(DefId a, DefId b) +{ + return phi({a, b}); +} + +DefId DefArena::phi(const std::vector& defs) { - return NotNull{allocator.allocate(Def{Cell{}})}; + std::vector operands; + for (DefId operand : defs) + collectOperands(operand, &operands); + + // There's no need to allocate a Phi node for a singleton set. + if (operands.size() == 1) + return operands[0]; + else + return NotNull{allocator.allocate(Def{Phi{std::move(operands)}})}; } } // namespace Luau diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp new file mode 100644 index 000000000..b2cebc0ba --- /dev/null +++ b/Analysis/src/Differ.cpp @@ -0,0 +1,970 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Differ.h" +#include "Luau/Common.h" +#include "Luau/Error.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/Unifiable.h" +#include +#include +#include +#include + +namespace Luau +{ + +std::string DiffPathNode::toString() const +{ + switch (kind) + { + case DiffPathNode::Kind::TableProperty: + { + if (!tableProperty.has_value()) + throw InternalCompilerError{"DiffPathNode has kind TableProperty but tableProperty is nullopt"}; + return *tableProperty; + break; + } + case DiffPathNode::Kind::FunctionArgument: + { + if (!index.has_value()) + return "Arg[Variadic]"; + // Add 1 because Lua is 1-indexed + return "Arg[" + std::to_string(*index + 1) + "]"; + } + case DiffPathNode::Kind::FunctionReturn: + { + if (!index.has_value()) + return "Ret[Variadic]"; + // Add 1 because Lua is 1-indexed + return "Ret[" + std::to_string(*index + 1) + "]"; + } + case DiffPathNode::Kind::Negation: + { + return "Negation"; + } + default: + { + throw InternalCompilerError{"DiffPathNode::toString is not exhaustive"}; + } + } +} + +DiffPathNode DiffPathNode::constructWithTableProperty(Name tableProperty) +{ + return DiffPathNode{DiffPathNode::Kind::TableProperty, tableProperty, std::nullopt}; +} + +DiffPathNode DiffPathNode::constructWithKindAndIndex(Kind kind, size_t index) +{ + return DiffPathNode{kind, std::nullopt, index}; +} + +DiffPathNode DiffPathNode::constructWithKind(Kind kind) +{ + return DiffPathNode{kind, std::nullopt, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsNormal(TypeId ty) +{ + return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsTableProperty(TypeId ty, Name tableProperty) +{ + return DiffPathNodeLeaf{ty, tableProperty, std::nullopt, false, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsUnionIndex(TypeId ty, size_t index) +{ + return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, index}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsLength(int minLength, bool isVariadic) +{ + return DiffPathNodeLeaf{std::nullopt, std::nullopt, minLength, isVariadic, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::nullopts() +{ + return DiffPathNodeLeaf{std::nullopt, std::nullopt, std::nullopt, false, std::nullopt}; +} + +std::string DiffPath::toString(bool prependDot) const +{ + std::string pathStr; + bool isFirstInForLoop = !prependDot; + for (auto node = path.rbegin(); node != path.rend(); node++) + { + if (isFirstInForLoop) + { + isFirstInForLoop = false; + } + else + { + pathStr += "."; + } + pathStr += node->toString(); + } + return pathStr; +} +std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const +{ + std::string conditionalNewline = multiLine ? "\n" : " "; + std::string conditionalIndent = multiLine ? " " : ""; + std::string pathStr{rootName + diffPath.toString(true)}; + switch (kind) + { + case DiffError::Kind::Normal: + { + checkNonMissingPropertyLeavesHaveNulloptTableProperty(); + return pathStr + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); + } + case DiffError::Kind::MissingTableProperty: + { + if (leaf.ty.has_value()) + { + if (!leaf.tableProperty.has_value()) + throw InternalCompilerError{"leaf.tableProperty is nullopt"}; + return pathStr + "." + *leaf.tableProperty + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + if (!otherLeaf.tableProperty.has_value()) + throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"}; + return pathStr + conditionalNewline + "is missing the property" + conditionalNewline + conditionalIndent + *otherLeaf.tableProperty; + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + case DiffError::Kind::MissingUnionMember: + { + // TODO: do normal case + if (leaf.ty.has_value()) + { + if (!leaf.unionIndex.has_value()) + throw InternalCompilerError{"leaf.unionIndex is nullopt"}; + return pathStr + conditionalNewline + "is a union containing type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + return pathStr + conditionalNewline + "is a union missing type" + conditionalNewline + conditionalIndent + Luau::toString(*otherLeaf.ty); + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + case DiffError::Kind::MissingIntersectionMember: + { + // TODO: better message for intersections + // An intersection of just functions is always an "overloaded function" + // An intersection of just tables is always a "joined table" + if (leaf.ty.has_value()) + { + if (!leaf.unionIndex.has_value()) + throw InternalCompilerError{"leaf.unionIndex is nullopt"}; + return pathStr + conditionalNewline + "is an intersection containing type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + return pathStr + conditionalNewline + "is an intersection missing type" + conditionalNewline + conditionalIndent + + Luau::toString(*otherLeaf.ty); + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + case DiffError::Kind::LengthMismatchInFnArgs: + { + if (!leaf.minLength.has_value()) + throw InternalCompilerError{"leaf.minLength is nullopt"}; + return pathStr + conditionalNewline + "takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments"; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + if (!leaf.minLength.has_value()) + throw InternalCompilerError{"leaf.minLength is nullopt"}; + return pathStr + conditionalNewline + "returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values"; + } + default: + { + throw InternalCompilerError{"DiffPath::toStringALeaf is not exhaustive"}; + } + } +} + +void DiffError::checkNonMissingPropertyLeavesHaveNulloptTableProperty() const +{ + if (left.tableProperty.has_value() || right.tableProperty.has_value()) + throw InternalCompilerError{"Non-MissingProperty DiffError should have nullopt tableProperty in both leaves"}; +} + +std::string getDevFixFriendlyName(const std::optional& maybeSymbol, TypeId ty) +{ + if (maybeSymbol.has_value()) + return *maybeSymbol; + + if (auto table = get(ty)) + { + if (table->name.has_value()) + return *table->name; + else if (table->syntheticName.has_value()) + return *table->syntheticName; + } + if (auto metatable = get(ty)) + { + if (metatable->syntheticName.has_value()) + { + return *metatable->syntheticName; + } + } + return ""; +} + +std::string DifferEnvironment::getDevFixFriendlyNameLeft() const +{ + return getDevFixFriendlyName(externalSymbolLeft, rootLeft); +} + +std::string DifferEnvironment::getDevFixFriendlyNameRight() const +{ + return getDevFixFriendlyName(externalSymbolRight, rootRight); +} + +std::string DiffError::toString(bool multiLine) const +{ + std::string conditionalNewline = multiLine ? "\n" : " "; + std::string conditionalIndent = multiLine ? " " : ""; + switch (kind) + { + case DiffError::Kind::IncompatibleGeneric: + { + std::string diffPathStr{diffPath.toString(true)}; + return "DiffError: these two types are not equal because the left generic at" + conditionalNewline + conditionalIndent + leftRootName + + diffPathStr + conditionalNewline + "cannot be the same type parameter as the right generic at" + conditionalNewline + + conditionalIndent + rightRootName + diffPathStr; + } + default: + { + return "DiffError: these two types are not equal because the left type at" + conditionalNewline + conditionalIndent + + toStringALeaf(leftRootName, left, right, multiLine) + "," + conditionalNewline + "while the right type at" + conditionalNewline + + conditionalIndent + toStringALeaf(rightRootName, right, left, multiLine); + } + } +} + +void DiffError::checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right) +{ + if (!left.ty.has_value() || !right.ty.has_value()) + { + // TODO: think about whether this should be always thrown! + // For example, Kind::Primitive doesn't make too much sense to have a TypeId + // throw InternalCompilerError{"Left and Right fields are leaf nodes and must have a TypeId"}; + } +} + +void DifferResult::wrapDiffPath(DiffPathNode node) +{ + if (!diffError.has_value()) + { + throw InternalCompilerError{"Cannot wrap diffPath because there is no diffError"}; + } + + diffError->diffPath.path.push_back(node); +} + +static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right); +struct FindSeteqCounterexampleResult +{ + // nullopt if no counterexample found + std::optional mismatchIdx; + // true if counterexample is in the left, false if cex is in the right + bool inLeft; +}; +static FindSeteqCounterexampleResult findSeteqCounterexample( + DifferEnvironment& env, + const std::vector& left, + const std::vector& right +); +static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right); +/** + * The last argument gives context info on which complex type contained the TypePack. + */ +static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); +static DifferResult diffCanonicalTpShape( + DifferEnvironment& env, + DiffError::Kind possibleNonNormalErrorKind, + const std::pair, std::optional>& left, + const std::pair, std::optional>& right +); +static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); +static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right); + +static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) +{ + const TableType* leftTable = get(left); + const TableType* rightTable = get(right); + LUAU_ASSERT(leftTable); + LUAU_ASSERT(rightTable); + + for (auto const& [field, value] : leftTable->props) + { + if (rightTable->props.find(field) == rightTable->props.end()) + { + // left has a field the right doesn't + return DifferResult{DiffError{ + DiffError::Kind::MissingTableProperty, + DiffPathNodeLeaf::detailsTableProperty(value.type(), field), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + } + for (auto const& [field, value] : rightTable->props) + { + if (leftTable->props.find(field) == leftTable->props.end()) + { + // right has a field the left doesn't + return DifferResult{DiffError{ + DiffError::Kind::MissingTableProperty, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsTableProperty(value.type(), field), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight() + }}; + } + } + // left and right have the same set of keys + for (auto const& [field, leftValue] : leftTable->props) + { + auto const& rightValue = rightTable->props.at(field); + DifferResult differResult = diffUsingEnv(env, leftValue.type(), rightValue.type()); + if (differResult.diffError.has_value()) + { + differResult.wrapDiffPath(DiffPathNode::constructWithTableProperty(field)); + return differResult; + } + } + return DifferResult{}; +} + +static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right) +{ + const MetatableType* leftMetatable = get(left); + const MetatableType* rightMetatable = get(right); + LUAU_ASSERT(leftMetatable); + LUAU_ASSERT(rightMetatable); + + DifferResult diffRes = diffUsingEnv(env, leftMetatable->table, rightMetatable->table); + if (diffRes.diffError.has_value()) + { + return diffRes; + } + + diffRes = diffUsingEnv(env, leftMetatable->metatable, rightMetatable->metatable); + if (diffRes.diffError.has_value()) + { + diffRes.wrapDiffPath(DiffPathNode::constructWithTableProperty("__metatable")); + return diffRes; + } + return DifferResult{}; +} + +static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right) +{ + const PrimitiveType* leftPrimitive = get(left); + const PrimitiveType* rightPrimitive = get(right); + LUAU_ASSERT(leftPrimitive); + LUAU_ASSERT(rightPrimitive); + + if (leftPrimitive->type != rightPrimitive->type) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + return DifferResult{}; +} + +static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right) +{ + const SingletonType* leftSingleton = get(left); + const SingletonType* rightSingleton = get(right); + LUAU_ASSERT(leftSingleton); + LUAU_ASSERT(rightSingleton); + + if (*leftSingleton != *rightSingleton) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + return DifferResult{}; +} + +static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right) +{ + const FunctionType* leftFunction = get(left); + const FunctionType* rightFunction = get(right); + LUAU_ASSERT(leftFunction); + LUAU_ASSERT(rightFunction); + + DifferResult differResult = diffTpi(env, DiffError::Kind::LengthMismatchInFnArgs, leftFunction->argTypes, rightFunction->argTypes); + if (differResult.diffError.has_value()) + return differResult; + return diffTpi(env, DiffError::Kind::LengthMismatchInFnRets, leftFunction->retTypes, rightFunction->retTypes); +} + +static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right) +{ + LUAU_ASSERT(get(left)); + LUAU_ASSERT(get(right)); + // Try to pair up the generics + bool isLeftFree = !env.genericMatchedPairs.contains(left); + bool isRightFree = !env.genericMatchedPairs.contains(right); + if (isLeftFree && isRightFree) + { + env.genericMatchedPairs[left] = right; + env.genericMatchedPairs[right] = left; + return DifferResult{}; + } + else if (isLeftFree || isRightFree) + { + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both generics are already paired up + if (*env.genericMatchedPairs.find(left) == right) + return DifferResult{}; + + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + const NegationType* rightNegation = get(right); + LUAU_ASSERT(leftNegation); + LUAU_ASSERT(rightNegation); + + DifferResult differResult = diffUsingEnv(env, leftNegation->ty, rightNegation->ty); + if (!differResult.diffError.has_value()) + return DifferResult{}; + + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::Negation)); + return differResult; +} + +static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right) +{ + const ClassType* leftClass = get(left); + const ClassType* rightClass = get(right); + LUAU_ASSERT(leftClass); + LUAU_ASSERT(rightClass); + + if (leftClass == rightClass) + { + return DifferResult{}; + } + + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +static FindSeteqCounterexampleResult findSeteqCounterexample( + DifferEnvironment& env, + const std::vector& left, + const std::vector& right +) +{ + std::unordered_set unmatchedRightIdxes; + for (size_t i = 0; i < right.size(); i++) + unmatchedRightIdxes.insert(i); + for (size_t leftIdx = 0; leftIdx < left.size(); leftIdx++) + { + bool leftIdxIsMatched = false; + auto unmatchedRightIdxIt = unmatchedRightIdxes.begin(); + while (unmatchedRightIdxIt != unmatchedRightIdxes.end()) + { + DifferResult differResult = diffUsingEnv(env, left[leftIdx], right[*unmatchedRightIdxIt]); + if (differResult.diffError.has_value()) + { + unmatchedRightIdxIt++; + continue; + } + // unmatchedRightIdxIt is matched with current leftIdx + env.recordProvenEqual(left[leftIdx], right[*unmatchedRightIdxIt]); + leftIdxIsMatched = true; + unmatchedRightIdxIt = unmatchedRightIdxes.erase(unmatchedRightIdxIt); + } + if (!leftIdxIsMatched) + { + return FindSeteqCounterexampleResult{leftIdx, true}; + } + } + if (unmatchedRightIdxes.empty()) + return FindSeteqCounterexampleResult{std::nullopt, false}; + return FindSeteqCounterexampleResult{*unmatchedRightIdxes.begin(), false}; +} + +static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + const UnionType* rightUnion = get(right); + LUAU_ASSERT(leftUnion); + LUAU_ASSERT(rightUnion); + + FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftUnion->options, rightUnion->options); + if (findSeteqCexResult.mismatchIdx.has_value()) + { + if (findSeteqCexResult.inLeft) + return DifferResult{DiffError{ + DiffError::Kind::MissingUnionMember, + DiffPathNodeLeaf::detailsUnionIndex(leftUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + else + return DifferResult{DiffError{ + DiffError::Kind::MissingUnionMember, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsUnionIndex(rightUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // TODO: somehow detect mismatch index, likely using heuristics + + return DifferResult{}; +} + +static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right) +{ + const IntersectionType* leftIntersection = get(left); + const IntersectionType* rightIntersection = get(right); + LUAU_ASSERT(leftIntersection); + LUAU_ASSERT(rightIntersection); + + FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftIntersection->parts, rightIntersection->parts); + if (findSeteqCexResult.mismatchIdx.has_value()) + { + if (findSeteqCexResult.inLeft) + return DifferResult{DiffError{ + DiffError::Kind::MissingIntersectionMember, + DiffPathNodeLeaf::detailsUnionIndex(leftIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + else + return DifferResult{DiffError{ + DiffError::Kind::MissingIntersectionMember, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsUnionIndex(rightIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // TODO: somehow detect mismatch index, likely using heuristics + + return DifferResult{}; +} + +static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + if (left->ty.index() != right->ty.index()) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both left and right are the same variant + + // Check cycles & caches + if (env.isAssumedEqual(left, right) || env.isProvenEqual(left, right)) + return DifferResult{}; + + if (isSimple(left)) + { + if (auto lp = get(left)) + return diffPrimitive(env, left, right); + else if (auto ls = get(left)) + { + return diffSingleton(env, left, right); + } + else if (auto la = get(left)) + { + // Both left and right must be Any if either is Any for them to be equal! + return DifferResult{}; + } + else if (auto lu = get(left)) + { + return DifferResult{}; + } + else if (auto ln = get(left)) + { + return DifferResult{}; + } + else if (auto ln = get(left)) + { + return diffNegation(env, left, right); + } + else if (auto lc = get(left)) + { + return diffClass(env, left, right); + } + + throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"}; + } + + // Both left and right are the same non-Simple + // Non-simple types must record visits in the DifferEnvironment + env.pushVisiting(left, right); + + if (auto lt = get(left)) + { + DifferResult diffRes = diffTable(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto lm = get(left)) + { + env.popVisiting(); + return diffMetatable(env, left, right); + } + if (auto lf = get(left)) + { + DifferResult diffRes = diffFunction(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto lg = get(left)) + { + DifferResult diffRes = diffGeneric(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto lu = get(left)) + { + DifferResult diffRes = diffUnion(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto li = get(left)) + { + DifferResult diffRes = diffIntersection(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto le = get(left)) + { + // TODO: return debug-friendly result state + env.popVisiting(); + return DifferResult{}; + } + + throw InternalCompilerError{"Unimplemented non-simple TypeId variant for diffing"}; +} + +static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right) +{ + left = follow(left); + right = follow(right); + + // Canonicalize + std::pair, std::optional> leftFlatTpi = flatten(left); + std::pair, std::optional> rightFlatTpi = flatten(right); + + // Check for shape equality + DifferResult diffResult = diffCanonicalTpShape(env, possibleNonNormalErrorKind, leftFlatTpi, rightFlatTpi); + if (diffResult.diffError.has_value()) + { + return diffResult; + } + + // Left and Right have the same shape + for (size_t i = 0; i < leftFlatTpi.first.size(); i++) + { + DifferResult differResult = diffUsingEnv(env, leftFlatTpi.first[i], rightFlatTpi.first[i]); + if (!differResult.diffError.has_value()) + continue; + + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionArgument, i)); + return differResult; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionReturn, i)); + return differResult; + } + default: + { + throw InternalCompilerError{"Unhandled Tpi diffing case with same shape"}; + } + } + } + if (!leftFlatTpi.second.has_value()) + return DifferResult{}; + + return diffHandleFlattenedTail(env, possibleNonNormalErrorKind, *leftFlatTpi.second, *rightFlatTpi.second); +} + +static DifferResult diffCanonicalTpShape( + DifferEnvironment& env, + DiffError::Kind possibleNonNormalErrorKind, + const std::pair, std::optional>& left, + const std::pair, std::optional>& right +) +{ + if (left.first.size() == right.first.size() && left.second.has_value() == right.second.has_value()) + return DifferResult{}; + + return DifferResult{DiffError{ + possibleNonNormalErrorKind, + DiffPathNodeLeaf::detailsLength(int(left.first.size()), left.second.has_value()), + DiffPathNodeLeaf::detailsLength(int(right.first.size()), right.second.has_value()), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right) +{ + left = follow(left); + right = follow(right); + + if (left->ty.index() != right->ty.index()) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->first), + DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->second), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both left and right are the same variant + + if (auto lv = get(left)) + { + auto rv = get(right); + DifferResult differResult = diffUsingEnv(env, lv->ty, rv->ty); + if (!differResult.diffError.has_value()) + return DifferResult{}; + + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument)); + return differResult; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn)); + return differResult; + } + default: + { + throw InternalCompilerError{"Unhandled flattened tail case for VariadicTypePack"}; + } + } + } + if (auto lg = get(left)) + { + DifferResult diffRes = diffGenericTp(env, left, right); + if (!diffRes.diffError.has_value()) + return DifferResult{}; + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument)); + return diffRes; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn)); + return diffRes; + } + default: + { + throw InternalCompilerError{"Unhandled flattened tail case for GenericTypePack"}; + } + } + } + + throw InternalCompilerError{"Unhandled tail type pack variant for flattened tails"}; +} + +static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right) +{ + LUAU_ASSERT(get(left)); + LUAU_ASSERT(get(right)); + // Try to pair up the generics + bool isLeftFree = !env.genericTpMatchedPairs.contains(left); + bool isRightFree = !env.genericTpMatchedPairs.contains(right); + if (isLeftFree && isRightFree) + { + env.genericTpMatchedPairs[left] = right; + env.genericTpMatchedPairs[right] = left; + return DifferResult{}; + } + else if (isLeftFree || isRightFree) + { + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both generics are already paired up + if (*env.genericTpMatchedPairs.find(left) == right) + return DifferResult{}; + + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +bool DifferEnvironment::isProvenEqual(TypeId left, TypeId right) const +{ + return provenEqual.find({left, right}) != provenEqual.end(); +} + +bool DifferEnvironment::isAssumedEqual(TypeId left, TypeId right) const +{ + return visiting.find({left, right}) != visiting.end(); +} + +void DifferEnvironment::recordProvenEqual(TypeId left, TypeId right) +{ + provenEqual.insert({left, right}); + provenEqual.insert({right, left}); +} + +void DifferEnvironment::pushVisiting(TypeId left, TypeId right) +{ + LUAU_ASSERT(visiting.find({left, right}) == visiting.end()); + LUAU_ASSERT(visiting.find({right, left}) == visiting.end()); + visitingStack.push_back({left, right}); + visiting.insert({left, right}); + visiting.insert({right, left}); +} + +void DifferEnvironment::popVisiting() +{ + auto tyPair = visitingStack.back(); + visiting.erase({tyPair.first, tyPair.second}); + visiting.erase({tyPair.second, tyPair.first}); + visitingStack.pop_back(); +} + +std::vector>::const_reverse_iterator DifferEnvironment::visitingBegin() const +{ + return visitingStack.crbegin(); +} + +std::vector>::const_reverse_iterator DifferEnvironment::visitingEnd() const +{ + return visitingStack.crend(); +} + + +DifferResult diff(TypeId ty1, TypeId ty2) +{ + DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt}; + return diffUsingEnv(differEnv, ty1, ty2); +} + + +DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2) +{ + DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2}; + return diffUsingEnv(differEnv, ty1, ty2); +} + +bool isSimple(TypeId ty) +{ + ty = follow(ty); + // TODO: think about GenericType, etc. + return get(ty) || get(ty) || get(ty) || get(ty) || get(ty) || + get(ty) || get(ty); +} + +} // namespace Luau diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 364244ad3..50e090ca7 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,71 +1,297 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(LuauMathMap) + namespace Luau { -static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( +// TODO: there has to be a better way, like splitting up per library +static const std::string kBuiltinDefinitionLuaSrcChecked_DEPRECATED = R"BUILTIN_SRC( + +declare bit32: { + band: @checked (...number) -> number, + bor: @checked (...number) -> number, + bxor: @checked (...number) -> number, + btest: @checked (number, ...number) -> boolean, + rrotate: @checked (x: number, disp: number) -> number, + lrotate: @checked (x: number, disp: number) -> number, + lshift: @checked (x: number, disp: number) -> number, + arshift: @checked (x: number, disp: number) -> number, + rshift: @checked (x: number, disp: number) -> number, + bnot: @checked (x: number) -> number, + extract: @checked (n: number, field: number, width: number?) -> number, + replace: @checked (n: number, v: number, field: number, width: number?) -> number, + countlz: @checked (n: number) -> number, + countrz: @checked (n: number) -> number, + byteswap: @checked (n: number) -> number, +} + +declare math: { + frexp: @checked (n: number) -> (number, number), + ldexp: @checked (s: number, e: number) -> number, + fmod: @checked (x: number, y: number) -> number, + modf: @checked (n: number) -> (number, number), + pow: @checked (x: number, y: number) -> number, + exp: @checked (n: number) -> number, + + ceil: @checked (n: number) -> number, + floor: @checked (n: number) -> number, + abs: @checked (n: number) -> number, + sqrt: @checked (n: number) -> number, + + log: @checked (n: number, base: number?) -> number, + log10: @checked (n: number) -> number, + + rad: @checked (n: number) -> number, + deg: @checked (n: number) -> number, + + sin: @checked (n: number) -> number, + cos: @checked (n: number) -> number, + tan: @checked (n: number) -> number, + sinh: @checked (n: number) -> number, + cosh: @checked (n: number) -> number, + tanh: @checked (n: number) -> number, + atan: @checked (n: number) -> number, + acos: @checked (n: number) -> number, + asin: @checked (n: number) -> number, + atan2: @checked (y: number, x: number) -> number, + + min: @checked (number, ...number) -> number, + max: @checked (number, ...number) -> number, + + pi: number, + huge: number, + + randomseed: @checked (seed: number) -> (), + random: @checked (number?, number?) -> number, + + sign: @checked (n: number) -> number, + clamp: @checked (n: number, min: number, max: number) -> number, + noise: @checked (x: number, y: number?, z: number?) -> number, + round: @checked (n: number) -> number, +} + +type DateTypeArg = { + year: number, + month: number, + day: number, + hour: number?, + min: number?, + sec: number?, + isdst: boolean?, +} + +type DateTypeResult = { + year: number, + month: number, + wday: number, + yday: number, + day: number, + hour: number, + min: number, + sec: number, + isdst: boolean, +} + +declare os: { + time: (time: DateTypeArg?) -> number, + date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, + clock: () -> number, +} + +@checked declare function require(target: any): any + +@checked declare function getfenv(target: any): { [string]: any } + +declare _G: any +declare _VERSION: string + +declare function gcinfo(): number + +declare function print(...: T...) + +declare function type(value: T): string +declare function typeof(value: T): string + +-- `assert` has a magic function attached that will give more detailed type information +declare function assert(value: T, errorMessage: string?): T +declare function error(message: T, level: number?): never + +declare function tostring(value: T): string +declare function tonumber(value: T, radix: number?): number? + +declare function rawequal(a: T1, b: T2): boolean +declare function rawget(tab: {[K]: V}, k: K): V +declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} +declare function rawlen(obj: {[K]: V} | string): number + +declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? + +declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number) + +declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) + +-- FIXME: The actual type of `xpcall` is: +-- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) +-- Since we can't represent the return value, we use (boolean, R1...). +declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) + +-- `select` has a magic function attached to provide more detailed type information +declare function select(i: string | number, ...: A...): ...any + +-- FIXME: This type is not entirely correct - `loadstring` returns a function or +-- (nil, string). +declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) + +@checked declare function newproxy(mt: boolean?): any + +declare coroutine: { + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), + running: () -> thread, + status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended", + wrap: (f: (A...) -> R...) -> ((A...) -> R...), + yield: (A...) -> R..., + isyieldable: () -> boolean, + close: @checked (co: thread) -> (boolean, any) +} + +declare table: { + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, + + unpack: (list: {V}, i: number?, j: number?) -> ...V, + pack: (...V) -> { n: number, [number]: V }, + + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), + foreachi: ({V}, (number, V) -> ()) -> (), + + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), + + isfrozen: (t: {[K]: V}) -> boolean, +} + +declare debug: { + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), +} + +declare utf8: { + char: @checked (...number) -> string, + charpattern: string, + codes: @checked (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: @checked (str: string, i: number?, j: number?) -> ...number, + len: @checked (s: string, i: number?, j: number?) -> (number?, number?), + offset: @checked (s: string, n: number?, i: number?) -> number, +} + +-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. +declare function unpack(tab: {V}, i: number?, j: number?): ...V + + +--- Buffer API +declare buffer: { + create: @checked (size: number) -> buffer, + fromstring: @checked (str: string) -> buffer, + tostring: @checked (b: buffer) -> string, + len: @checked (b: buffer) -> number, + copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (), + fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (), + readi8: @checked (b: buffer, offset: number) -> number, + readu8: @checked (b: buffer, offset: number) -> number, + readi16: @checked (b: buffer, offset: number) -> number, + readu16: @checked (b: buffer, offset: number) -> number, + readi32: @checked (b: buffer, offset: number) -> number, + readu32: @checked (b: buffer, offset: number) -> number, + readf32: @checked (b: buffer, offset: number) -> number, + readf64: @checked (b: buffer, offset: number) -> number, + writei8: @checked (b: buffer, offset: number, value: number) -> (), + writeu8: @checked (b: buffer, offset: number, value: number) -> (), + writei16: @checked (b: buffer, offset: number, value: number) -> (), + writeu16: @checked (b: buffer, offset: number, value: number) -> (), + writei32: @checked (b: buffer, offset: number, value: number) -> (), + writeu32: @checked (b: buffer, offset: number, value: number) -> (), + writef32: @checked (b: buffer, offset: number, value: number) -> (), + writef64: @checked (b: buffer, offset: number, value: number) -> (), + readstring: @checked (b: buffer, offset: number, count: number) -> string, + writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( declare bit32: { - band: (...number) -> number, - bor: (...number) -> number, - bxor: (...number) -> number, - btest: (number, ...number) -> boolean, - rrotate: (x: number, disp: number) -> number, - lrotate: (x: number, disp: number) -> number, - lshift: (x: number, disp: number) -> number, - arshift: (x: number, disp: number) -> number, - rshift: (x: number, disp: number) -> number, - bnot: (x: number) -> number, - extract: (n: number, field: number, width: number?) -> number, - replace: (n: number, v: number, field: number, width: number?) -> number, - countlz: (n: number) -> number, - countrz: (n: number) -> number, + band: @checked (...number) -> number, + bor: @checked (...number) -> number, + bxor: @checked (...number) -> number, + btest: @checked (number, ...number) -> boolean, + rrotate: @checked (x: number, disp: number) -> number, + lrotate: @checked (x: number, disp: number) -> number, + lshift: @checked (x: number, disp: number) -> number, + arshift: @checked (x: number, disp: number) -> number, + rshift: @checked (x: number, disp: number) -> number, + bnot: @checked (x: number) -> number, + extract: @checked (n: number, field: number, width: number?) -> number, + replace: @checked (n: number, v: number, field: number, width: number?) -> number, + countlz: @checked (n: number) -> number, + countrz: @checked (n: number) -> number, + byteswap: @checked (n: number) -> number, } declare math: { - frexp: (n: number) -> (number, number), - ldexp: (s: number, e: number) -> number, - fmod: (x: number, y: number) -> number, - modf: (n: number) -> (number, number), - pow: (x: number, y: number) -> number, - exp: (n: number) -> number, - - ceil: (n: number) -> number, - floor: (n: number) -> number, - abs: (n: number) -> number, - sqrt: (n: number) -> number, - - log: (n: number, base: number?) -> number, - log10: (n: number) -> number, - - rad: (n: number) -> number, - deg: (n: number) -> number, - - sin: (n: number) -> number, - cos: (n: number) -> number, - tan: (n: number) -> number, - sinh: (n: number) -> number, - cosh: (n: number) -> number, - tanh: (n: number) -> number, - atan: (n: number) -> number, - acos: (n: number) -> number, - asin: (n: number) -> number, - atan2: (y: number, x: number) -> number, - - min: (number, ...number) -> number, - max: (number, ...number) -> number, + frexp: @checked (n: number) -> (number, number), + ldexp: @checked (s: number, e: number) -> number, + fmod: @checked (x: number, y: number) -> number, + modf: @checked (n: number) -> (number, number), + pow: @checked (x: number, y: number) -> number, + exp: @checked (n: number) -> number, + + ceil: @checked (n: number) -> number, + floor: @checked (n: number) -> number, + abs: @checked (n: number) -> number, + sqrt: @checked (n: number) -> number, + + log: @checked (n: number, base: number?) -> number, + log10: @checked (n: number) -> number, + + rad: @checked (n: number) -> number, + deg: @checked (n: number) -> number, + + sin: @checked (n: number) -> number, + cos: @checked (n: number) -> number, + tan: @checked (n: number) -> number, + sinh: @checked (n: number) -> number, + cosh: @checked (n: number) -> number, + tanh: @checked (n: number) -> number, + atan: @checked (n: number) -> number, + acos: @checked (n: number) -> number, + asin: @checked (n: number) -> number, + atan2: @checked (y: number, x: number) -> number, + + min: @checked (number, ...number) -> number, + max: @checked (number, ...number) -> number, pi: number, huge: number, - randomseed: (seed: number) -> (), - random: (number?, number?) -> number, + randomseed: @checked (seed: number) -> (), + random: @checked (number?, number?) -> number, - sign: (n: number) -> number, - clamp: (n: number, min: number, max: number) -> number, - noise: (x: number, y: number?, z: number?) -> number, - round: (n: number) -> number, + sign: @checked (n: number) -> number, + clamp: @checked (n: number, min: number, max: number) -> number, + noise: @checked (x: number, y: number?, z: number?) -> number, + round: @checked (n: number) -> number, + map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number, } type DateTypeArg = { @@ -92,14 +318,14 @@ type DateTypeResult = { declare os: { time: (time: DateTypeArg?) -> number, - date: (formatString: string?, time: number?) -> DateTypeResult | string, + date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } -declare function require(target: any): any +@checked declare function require(target: any): any -declare function getfenv(target: any): { [string]: any } +@checked declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string @@ -141,18 +367,17 @@ declare function select(i: string | number, ...: A...): ...any -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) -declare function newproxy(mt: boolean?): any +@checked declare function newproxy(mt: boolean?): any declare coroutine: { create: (f: (A...) -> R...) -> thread, resume: (co: thread, A...) -> (boolean, R...), running: () -> thread, - status: (co: thread) -> "dead" | "running" | "normal" | "suspended", - -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: (f: (A...) -> R...) -> any, + status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended", + wrap: (f: (A...) -> R...) -> ((A...) -> R...), yield: (A...) -> R..., isyieldable: () -> boolean, - close: (co: thread) -> (boolean, any) + close: @checked (co: thread) -> (boolean, any) } declare table: { @@ -183,22 +408,51 @@ declare debug: { } declare utf8: { - char: (...number) -> string, + char: @checked (...number) -> string, charpattern: string, - codes: (str: string) -> ((string, number) -> (number, number), string, number), - codepoint: (str: string, i: number?, j: number?) -> ...number, - len: (s: string, i: number?, j: number?) -> (number?, number?), - offset: (s: string, n: number?, i: number?) -> number, + codes: @checked (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: @checked (str: string, i: number?, j: number?) -> ...number, + len: @checked (s: string, i: number?, j: number?) -> (number?, number?), + offset: @checked (s: string, n: number?, i: number?) -> number, } -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V + +--- Buffer API +declare buffer: { + create: @checked (size: number) -> buffer, + fromstring: @checked (str: string) -> buffer, + tostring: @checked (b: buffer) -> string, + len: @checked (b: buffer) -> number, + copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (), + fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (), + readi8: @checked (b: buffer, offset: number) -> number, + readu8: @checked (b: buffer, offset: number) -> number, + readi16: @checked (b: buffer, offset: number) -> number, + readu16: @checked (b: buffer, offset: number) -> number, + readi32: @checked (b: buffer, offset: number) -> number, + readu32: @checked (b: buffer, offset: number) -> number, + readf32: @checked (b: buffer, offset: number) -> number, + readf64: @checked (b: buffer, offset: number) -> number, + writei8: @checked (b: buffer, offset: number, value: number) -> (), + writeu8: @checked (b: buffer, offset: number, value: number) -> (), + writei16: @checked (b: buffer, offset: number, value: number) -> (), + writeu16: @checked (b: buffer, offset: number, value: number) -> (), + writei32: @checked (b: buffer, offset: number, value: number) -> (), + writeu32: @checked (b: buffer, offset: number, value: number) -> (), + writef32: @checked (b: buffer, offset: number, value: number) -> (), + writef64: @checked (b: buffer, offset: number, value: number) -> (), + readstring: @checked (b: buffer, offset: number, count: number) -> string, + writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), +} + )BUILTIN_SRC"; std::string getBuiltinDefinitionSource() { - std::string result = kBuiltinDefinitionLuaSrc; + std::string result = FFlag::LuauMathMap ? kBuiltinDefinitionLuaSrcChecked : kBuiltinDefinitionLuaSrcChecked_DEPRECATED; return result; } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index a527b2440..66b61d6bc 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -4,16 +4,27 @@ #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/FileResolver.h" +#include "Luau/NotNull.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include #include +#include #include +#include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false) +LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) static std::string wrongNumberOfArgsString( - size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) + size_t expectedCount, + std::optional maximumCount, + size_t actualCount, + const char* argPrefix = nullptr, + bool isVariadic = false +) { std::string s = "expects "; @@ -56,6 +67,30 @@ static std::string wrongNumberOfArgsString( namespace Luau { +// this list of binary operator type functions is used for better stringification of type functions errors +static const std::unordered_map kBinaryOps{ + {"add", "+"}, + {"sub", "-"}, + {"mul", "*"}, + {"div", "/"}, + {"idiv", "//"}, + {"pow", "^"}, + {"mod", "%"}, + {"concat", ".."}, + {"and", "and"}, + {"or", "or"}, + {"lt", "< or >="}, + {"le", "<= or >"}, + {"eq", "== or ~="} +}; + +// this list of unary operator type functions is used for better stringification of type functions errors +static const std::unordered_map kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}}; + +// this list of type functions will receive a special error indicating that the user should file a bug on the GitHub repository +// putting a type function in this list indicates that it is expected to _always_ reduce +static const std::unordered_set kUnreachableTypeFunctions{"refine", "singleton", "union", "intersect"}; + struct ErrorConverter { FileResolver* fileResolver = nullptr; @@ -67,6 +102,23 @@ struct ErrorConverter std::string result; + auto quote = [&](std::string s) + { + return "'" + s + "'"; + }; + + auto constructErrorMessage = + [&](std::string givenType, std::string wantedType, std::optional givenModule, std::optional wantedModule + ) -> std::string + { + std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType); + std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType); + size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength); + if (givenType.length() <= luauIndentTypeMismatchMaxTypeLength || wantedType.length() <= luauIndentTypeMismatchMaxTypeLength) + return "Type " + given + " could not be converted into " + wanted; + return "Type\n " + given + "\ncould not be converted into\n " + wanted; + }; + if (givenTypeName == wantedTypeName) { if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) @@ -77,20 +129,18 @@ struct ErrorConverter { std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); - result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName + - "' from '" + wantedModuleName + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, givenModuleName, wantedModuleName); } else { - result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + - "' from '" + *wantedDefinitionModule + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, *givenDefinitionModule, *wantedDefinitionModule); } } } } if (result.empty()) - result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, std::nullopt, std::nullopt); if (tm.error) @@ -98,7 +148,7 @@ struct ErrorConverter result += "\ncaused by:\n "; if (!tm.reason.empty()) - result += tm.reason + " "; + result += tm.reason + "\n"; result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); } @@ -106,7 +156,7 @@ struct ErrorConverter { result += "; " + tm.reason; } - else if (FFlag::LuauTypeMismatchInvarianceInError && tm.context == TypeMismatch::InvariantContext) + else if (tm.context == TypeMismatch::InvariantContext) { result += " in an invariant context"; } @@ -320,9 +370,66 @@ struct ErrorConverter return e.message; } + std::string operator()(const Luau::ConstraintSolvingIncompleteError& e) const + { + return "Type inference failed to complete, you may see some confusing types and type errors."; + } + + std::optional findCallMetamethod(TypeId type) const + { + type = follow(type); + + std::optional metatable; + if (const MetatableType* mtType = get(type)) + metatable = mtType->metatable; + else if (const ClassType* classType = get(type)) + metatable = classType->metatable; + + if (!metatable) + return std::nullopt; + + TypeId unwrapped = follow(*metatable); + + if (get(unwrapped)) + return unwrapped; + + const TableType* mtt = getTableType(unwrapped); + if (!mtt) + return std::nullopt; + + auto it = mtt->props.find("__call"); + if (it != mtt->props.end()) + return it->second.type(); + else + return std::nullopt; + } + std::string operator()(const Luau::CannotCallNonFunction& e) const { - return "Cannot call non-function " + toString(e.ty); + if (auto unionTy = get(follow(e.ty))) + { + std::string err = "Cannot call a value of the union type:"; + + for (auto option : unionTy) + { + option = follow(option); + + if (get(option) || findCallMetamethod(option)) + { + err += "\n | " + toString(option); + continue; + } + + // early-exit if we find something that isn't callable in the union. + return "Cannot call a value of type " + toString(option) + " in union:\n " + toString(e.ty); + } + + err += "\nWe are unable to determine the appropriate result type for such a call."; + + return err; + } + + return "Cannot call a value of type " + toString(e.ty); } std::string operator()(const Luau::ExtraInformation& e) const { @@ -349,7 +456,10 @@ struct ErrorConverter else s += " -> "; - s += name; + if (fileResolver != nullptr) + s += fileResolver->getHumanReadableModuleName(name); + else + s += name; } return s; @@ -473,13 +583,232 @@ struct ErrorConverter std::string operator()(const TypePackMismatch& e) const { - return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + std::string ss = "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + + if (!e.reason.empty()) + ss += "; " + e.reason; + + return ss; } std::string operator()(const DynamicPropertyLookupOnClassesUnsafe& e) const { return "Attempting a dynamic property access on type '" + Luau::toString(e.ty) + "' is unsafe and may cause exceptions at runtime"; } + + std::string operator()(const UninhabitedTypeFunction& e) const + { + auto tfit = get(e.ty); + LUAU_ASSERT(tfit); // Luau analysis has actually done something wrong if this type is not a type function. + if (!tfit) + return "Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type function."; + + // unary operators + if (auto unaryString = kUnaryOps.find(tfit->function->name); unaryString != kUnaryOps.end()) + { + std::string result = "Operator '" + std::string(unaryString->second) + "' could not be applied to "; + + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + { + result += "operand of type " + Luau::toString(tfit->typeArguments[0]); + + if (tfit->function->name != "not") + result += "; there is no corresponding overload for __" + tfit->function->name; + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + result += "operands of types "; + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + return result; + } + + // binary operators + if (auto binaryString = kBinaryOps.find(tfit->function->name); binaryString != kBinaryOps.end()) + { + std::string result = "Operator '" + std::string(binaryString->second) + "' could not be applied to operands of types "; + + if (tfit->typeArguments.size() == 2 && tfit->packArguments.empty()) + { + // this is the expected case. + result += Luau::toString(tfit->typeArguments[0]) + " and " + Luau::toString(tfit->typeArguments[1]); + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + result += "; there is no corresponding overload for __" + tfit->function->name; + + return result; + } + + // miscellaneous + + if ("keyof" == tfit->function->name || "rawkeyof" == tfit->function->name) + { + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + return "Type '" + toString(tfit->typeArguments[0]) + "' does not have keys, so '" + Luau::toString(e.ty) + "' is invalid"; + else + return "Type function instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + } + + if ("index" == tfit->function->name || "rawget" == tfit->function->name) + { + if (tfit->typeArguments.size() != 2) + return "Type function instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + + if (auto errType = get(tfit->typeArguments[1])) // Second argument to (index | rawget)<_,_> is not a type + return "Second argument to " + tfit->function->name + "<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type"; + else // Property `indexer` does not exist on type `indexee` + return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) + + "'"; + } + + if (kUnreachableTypeFunctions.count(tfit->function->name)) + { + return "Type function instance " + Luau::toString(e.ty) + " is uninhabited\n" + + "This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues"; + } + + // Everything should be specialized above to report a more descriptive error that hopefully does not mention "type functions" explicitly. + // If we produce this message, it's an indication that we've missed a specialization and it should be fixed! + return "Type function instance " + Luau::toString(e.ty) + " is uninhabited"; + } + + std::string operator()(const ExplicitFunctionAnnotationRecommended& r) const + { + std::string toReturn = toString(r.recommendedReturn); + std::string argAnnotations; + for (auto [arg, type] : r.recommendedArgs) + { + argAnnotations += arg + ": " + toString(type) + ", "; + } + if (argAnnotations.length() >= 2) + { + argAnnotations.pop_back(); + argAnnotations.pop_back(); + } + + if (argAnnotations.empty()) + return "Consider annotating the return with " + toReturn; + + return "Consider placing the following annotations on the arguments: " + argAnnotations + " or instead annotating the return as " + toReturn; + } + + std::string operator()(const UninhabitedTypePackFunction& e) const + { + return "Type pack function instance " + Luau::toString(e.tp) + " is uninhabited"; + } + + std::string operator()(const WhereClauseNeeded& e) const + { + return "Type function instance " + Luau::toString(e.ty) + + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this " + "time"; + } + + std::string operator()(const PackWhereClauseNeeded& e) const + { + return "Type pack function instance " + Luau::toString(e.tp) + + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this " + "time"; + } + + std::string operator()(const CheckedFunctionCallError& e) const + { + // TODO: What happens if checkedFunctionName cannot be found?? + return "Function '" + e.checkedFunctionName + "' expects '" + toString(e.expected) + "' at argument #" + std::to_string(e.argumentIndex) + + ", but got '" + Luau::toString(e.passed) + "'"; + } + + std::string operator()(const NonStrictFunctionDefinitionError& e) const + { + return "Argument " + e.argument + " with type '" + toString(e.argumentType) + "' in function '" + e.functionName + + "' is used in a way that will run time error"; + } + + std::string operator()(const PropertyAccessViolation& e) const + { + const std::string stringKey = isIdentifier(e.key) ? e.key : "\"" + e.key + "\""; + switch (e.context) + { + case PropertyAccessViolation::CannotRead: + return "Property " + stringKey + " of table '" + toString(e.table) + "' is write-only"; + case PropertyAccessViolation::CannotWrite: + return "Property " + stringKey + " of table '" + toString(e.table) + "' is read-only"; + } + + LUAU_UNREACHABLE(); + return ""; + } + + std::string operator()(const CheckedFunctionIncorrectArgs& e) const + { + return "Checked Function " + e.functionName + " expects " + std::to_string(e.expected) + " arguments, but received " + + std::to_string(e.actual); + } + + std::string operator()(const UnexpectedTypeInSubtyping& e) const + { + return "Encountered an unexpected type in subtyping: " + toString(e.ty); + } + + std::string operator()(const UnexpectedTypePackInSubtyping& e) const + { + return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); + } + + std::string operator()(const UserDefinedTypeFunctionError& e) const + { + return e.message; + } + + std::string operator()(const CannotAssignToNever& e) const + { + std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; + + switch (e.reason) + { + case CannotAssignToNever::Reason::PropertyNarrowed: + if (!e.cause.empty()) + { + result += "\ncaused by the property being given the following incompatible types:\n"; + for (auto ty : e.cause) + result += " " + toString(ty) + "\n"; + result += "There are no values that could safely satisfy all of these types at once."; + } + } + + return result; + } }; struct InvalidNameChecker @@ -572,6 +901,11 @@ bool UnknownProperty::operator==(const UnknownProperty& rhs) const return *table == *rhs.table && key == rhs.key; } +bool PropertyAccessViolation::operator==(const PropertyAccessViolation& rhs) const +{ + return *table == *rhs.table && key == rhs.key && context == rhs.context; +} + bool NotATable::operator==(const NotATable& rhs) const { return ty == rhs.ty; @@ -677,6 +1011,11 @@ bool InternalError::operator==(const InternalError& rhs) const return message == rhs.message; } +bool ConstraintSolvingIncompleteError::operator==(const ConstraintSolvingIncompleteError& rhs) const +{ + return true; +} + bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const { return ty == rhs.ty; @@ -782,6 +1121,77 @@ bool DynamicPropertyLookupOnClassesUnsafe::operator==(const DynamicPropertyLooku return ty == rhs.ty; } +bool UninhabitedTypeFunction::operator==(const UninhabitedTypeFunction& rhs) const +{ + return ty == rhs.ty; +} + + +bool ExplicitFunctionAnnotationRecommended::operator==(const ExplicitFunctionAnnotationRecommended& rhs) const +{ + return recommendedReturn == rhs.recommendedReturn && recommendedArgs == rhs.recommendedArgs; +} + +bool UninhabitedTypePackFunction::operator==(const UninhabitedTypePackFunction& rhs) const +{ + return tp == rhs.tp; +} + +bool WhereClauseNeeded::operator==(const WhereClauseNeeded& rhs) const +{ + return ty == rhs.ty; +} + +bool PackWhereClauseNeeded::operator==(const PackWhereClauseNeeded& rhs) const +{ + return tp == rhs.tp; +} + +bool CheckedFunctionCallError::operator==(const CheckedFunctionCallError& rhs) const +{ + return *expected == *rhs.expected && *passed == *rhs.passed && checkedFunctionName == rhs.checkedFunctionName && + argumentIndex == rhs.argumentIndex; +} + +bool NonStrictFunctionDefinitionError::operator==(const NonStrictFunctionDefinitionError& rhs) const +{ + return functionName == rhs.functionName && argument == rhs.argument && argumentType == rhs.argumentType; +} + +bool CheckedFunctionIncorrectArgs::operator==(const CheckedFunctionIncorrectArgs& rhs) const +{ + return functionName == rhs.functionName && expected == rhs.expected && actual == rhs.actual; +} + +bool UnexpectedTypeInSubtyping::operator==(const UnexpectedTypeInSubtyping& rhs) const +{ + return ty == rhs.ty; +} + +bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtyping& rhs) const +{ + return tp == rhs.tp; +} + +bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const +{ + return message == rhs.message; +} + +bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const +{ + if (cause.size() != rhs.cause.size()) + return false; + + for (size_t i = 0; i < cause.size(); ++i) + { + if (*cause[i] != *rhs.cause[i]) + return false; + } + + return *rhsType == *rhs.rhsType && reason == rhs.reason; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -799,13 +1209,15 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, CloneState cloneState) +void copyError(T& e, TypeArena& destArena, CloneState& cloneState) { - auto clone = [&](auto&& ty) { + auto clone = [&](auto&& ty) + { return ::Luau::clone(ty, destArena, cloneState); }; - auto visitErrorData = [&](auto&& e) { + auto visitErrorData = [&](auto&& e) + { copyError(e, destArena, cloneState); }; @@ -880,6 +1292,9 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.ty = clone(e.ty); @@ -940,15 +1355,58 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) } else if constexpr (std::is_same_v) e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + { + e.recommendedReturn = clone(e.recommendedReturn); + for (auto& [_, t] : e.recommendedArgs) + t = clone(t); + } + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + e.expected = clone(e.expected); + e.passed = clone(e.passed); + } + else if constexpr (std::is_same_v) + { + e.argumentType = clone(e.argumentType); + } + else if constexpr (std::is_same_v) + e.table = clone(e.table); + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.rhsType = clone(e.rhsType); + + for (auto& ty : e.cause) + ty = clone(ty); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } -void copyErrors(ErrorVec& errors, TypeArena& destArena) +void copyErrors(ErrorVec& errors, TypeArena& destArena, NotNull builtinTypes) { - CloneState cloneState; + CloneState cloneState{builtinTypes}; - auto visitErrorData = [&](auto&& e) { + auto visitErrorData = [&](auto&& e) + { copyError(e, destArena, cloneState); }; @@ -959,7 +1417,7 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) visit(visitErrorData, error.data); } -void InternalErrorReporter::ice(const std::string& message, const Location& location) +void InternalErrorReporter::ice(const std::string& message, const Location& location) const { InternalCompilerError error(message, moduleName, location); @@ -969,7 +1427,7 @@ void InternalErrorReporter::ice(const std::string& message, const Location& loca throw error; } -void InternalErrorReporter::ice(const std::string& message) +void InternalErrorReporter::ice(const std::string& message) const { InternalCompilerError error(message, moduleName); diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp new file mode 100644 index 000000000..d4f3ebd99 --- /dev/null +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -0,0 +1,353 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/FragmentAutocomplete.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Common.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" +#include "Luau/Module.h" +#include "Luau/TimeTrace.h" +#include "Luau/UnifierSharedState.h" +#include "Luau/TypeFunction.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/Frontend.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" +#include "Luau/Module.h" + +LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTINT(LuauTarjanChildLimit) +LUAU_FASTFLAG(LuauAllowFragmentParsing); +LUAU_FASTFLAG(LuauStoreDFGOnModule2); + +namespace +{ +template +void copyModuleVec(std::vector& result, const std::vector& input) +{ + result.insert(result.end(), input.begin(), input.end()); +} + +template +void copyModuleMap(Luau::DenseHashMap& result, const Luau::DenseHashMap& input) +{ + for (auto [k, v] : input) + result[k] = v; +} + +} // namespace + + +namespace Luau +{ + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos) +{ + std::vector ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos); + // Should always contain the root AstStat + LUAU_ASSERT(ancestry.size() >= 1); + DenseHashMap localMap{AstName()}; + std::vector localStack; + AstStat* nearestStatement = nullptr; + for (AstNode* node : ancestry) + { + if (auto block = node->as()) + { + for (auto stat : block->body) + { + if (stat->location.begin <= cursorPos) + nearestStatement = stat; + if (stat->location.begin < cursorPos && stat->location.begin.line < cursorPos.line) + { + // This statement precedes the current one + if (auto loc = stat->as()) + { + for (auto v : loc->vars) + { + localStack.push_back(v); + localMap[v->name] = v; + } + } + else if (auto locFun = stat->as()) + { + localStack.push_back(locFun->name); + localMap[locFun->name->name] = locFun->name; + } + } + } + } + } + + if (!nearestStatement) + nearestStatement = ancestry[0]->asStat(); + LUAU_ASSERT(nearestStatement); + return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; +} + +std::pair getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) +{ + unsigned int lineCount = 0; + unsigned int colCount = 0; + + unsigned int docOffset = 0; + unsigned int startOffset = 0; + unsigned int endOffset = 0; + bool foundStart = false; + bool foundEnd = false; + for (char c : src) + { + if (foundStart && foundEnd) + break; + + if (startPos.line == lineCount && startPos.column == colCount) + { + foundStart = true; + startOffset = docOffset; + } + + if (endPos.line == lineCount && endPos.column == colCount) + { + endOffset = docOffset; + foundEnd = true; + } + + if (c == '\n') + { + lineCount++; + colCount = 0; + } + else + colCount++; + docOffset++; + } + + + unsigned int min = std::min(startOffset, endOffset); + unsigned int len = std::max(startOffset, endOffset) - min; + return {min, len}; +} + +ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos) +{ + LUAU_ASSERT(module->hasModuleScope()); + + ScopePtr closest = module->getModuleScope(); + for (auto [loc, sc] : module->scopes) + { + if (loc.begin <= cursorPos && closest->location.begin <= loc.begin) + closest = sc; + } + + return closest; +} + +FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos) +{ + FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos); + ParseOptions opts; + opts.allowDeclarationSyntax = false; + opts.captureComments = false; + opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)}; + AstStat* enclosingStatement = result.nearestStatement; + + const Position& endPos = cursorPos; + // If the statement starts on a previous line, grab the statement beginning + // otherwise, grab the statement end to whatever is being typed right now + const Position& startPos = + enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end; + + auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); + + const char* srcStart = src.data() + offsetStart; + std::string_view dbg = src.substr(offsetStart, parseLength); + const std::shared_ptr& nameTbl = srcModule.names; + FragmentParseResult fragmentResult; + fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength); + // For the duration of the incremental parse, we want to allow the name table to re-use duplicate names + ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts); + + std::vector fabricatedAncestry = std::move(result.ancestry); + std::vector fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end); + fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); + if (enclosingStatement == nullptr) + enclosingStatement = p.root; + fragmentResult.root = std::move(p.root); + fragmentResult.ancestry = std::move(fabricatedAncestry); + return fragmentResult; +} + +ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) +{ + freeze(result->internalTypes); + freeze(result->interfaceTypes); + ModulePtr incrementalModule = std::make_shared(); + incrementalModule->name = result->name; + incrementalModule->humanReadableName = result->humanReadableName; + incrementalModule->allocator = std::move(alloc); + // Don't need to keep this alive (it's already on the source module) + copyModuleVec(incrementalModule->scopes, result->scopes); + copyModuleMap(incrementalModule->astTypes, result->astTypes); + copyModuleMap(incrementalModule->astTypePacks, result->astTypePacks); + copyModuleMap(incrementalModule->astExpectedTypes, result->astExpectedTypes); + // Don't need to clone astOriginalCallTypes + copyModuleMap(incrementalModule->astOverloadResolvedTypes, result->astOverloadResolvedTypes); + // Don't need to clone astForInNextTypes + copyModuleMap(incrementalModule->astForInNextTypes, result->astForInNextTypes); + // Don't need to clone astResolvedTypes + // Don't need to clone astResolvedTypePacks + // Don't need to clone upperBoundContributors + copyModuleMap(incrementalModule->astScopes, result->astScopes); + // Don't need to clone declared Globals; + return incrementalModule; +} + +FragmentTypeCheckResult typeCheckFragmentHelper( + Frontend& frontend, + AstStatBlock* root, + const ModulePtr& stale, + const ScopePtr& closestScope, + const Position& cursorPos, + std::unique_ptr astAllocator, + const FrontendOptions& opts +) +{ + freeze(stale->internalTypes); + freeze(stale->interfaceTypes); + ModulePtr incrementalModule = copyModule(stale, std::move(astAllocator)); + unfreeze(incrementalModule->internalTypes); + unfreeze(incrementalModule->interfaceTypes); + + /// Setup typecheck limits + TypeCheckLimits limits; + if (opts.moduleTimeLimitSec) + limits.finishTime = TimeTrace::getClock() + *opts.moduleTimeLimitSec; + else + limits.finishTime = std::nullopt; + limits.cancellationToken = opts.cancellationToken; + + /// Icehandler + NotNull iceHandler{&frontend.iceHandler}; + /// Make the shared state for the unifier (recursion + iteration limits) + UnifierSharedState unifierState{iceHandler}; + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); + + /// Initialize the normalizer + Normalizer normalizer{&incrementalModule->internalTypes, frontend.builtinTypes, NotNull{&unifierState}}; + + /// User defined type functions runtime + TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits}); + + /// Create a DataFlowGraph just for the surrounding context + auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler); + + /// Contraint Generator + ConstraintGenerator cg{ + incrementalModule, + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull{&frontend.moduleResolver}, + frontend.builtinTypes, + iceHandler, + frontend.globals.globalScope, + nullptr, + nullptr, + NotNull{&updatedDfg}, + {} + }; + cg.rootScope = stale->getModuleScope().get(); + // Any additions to the scope must occur in a fresh scope + auto freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope}); + + // closest Scope -> children = { ...., freshChildOfNearestScope} + // We need to trim nearestChild from the scope hierarcy + closestScope->children.push_back(NotNull{freshChildOfNearestScope.get()}); + // Visit just the root - we know the scope it should be in + cg.visitFragmentRoot(freshChildOfNearestScope, root); + // Trim nearestChild from the closestScope + Scope* back = closestScope->children.back().get(); + LUAU_ASSERT(back == freshChildOfNearestScope.get()); + closestScope->children.pop_back(); + + /// Initialize the constraint solver and run it + ConstraintSolver cs{ + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull(cg.rootScope), + borrowConstraints(cg.constraints), + incrementalModule->name, + NotNull{&frontend.moduleResolver}, + {}, + nullptr, + NotNull{&updatedDfg}, + limits + }; + + try + { + cs.run(); + } + catch (const TimeLimitError&) + { + stale->timeout = true; + } + catch (const UserCancelError&) + { + stale->cancelled = true; + } + + // In frontend we would forbid internal types + // because this is just for autocomplete, we don't actually care + // We also don't even need to typecheck - just synthesize types as best as we can + + freeze(incrementalModule->internalTypes); + freeze(incrementalModule->interfaceTypes); + return {std::move(incrementalModule), freshChildOfNearestScope.get()}; +} + + +FragmentTypeCheckResult typecheckFragment( + Frontend& frontend, + const ModuleName& moduleName, + const Position& cursorPos, + std::optional opts, + std::string_view src +) +{ + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); + if (!sourceModule) + { + LUAU_ASSERT(!"Expected Source Module for fragment typecheck"); + return {}; + } + + ModulePtr module = frontend.moduleResolver.getModule(moduleName); + const ScopePtr& closestScope = findClosestScope(module, cursorPos); + + + FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos); + FrontendOptions frontendOptions = opts.value_or(frontend.options); + return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions); +} + +AutocompleteResult fragmentAutocomplete( + Frontend& frontend, + std::string_view src, + const ModuleName& moduleName, + Position& cursorPosition, + const FrontendOptions& opts, + StringCompletionCallback callback +) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + LUAU_ASSERT(FFlag::LuauAllowFragmentParsing); + LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2); + return {}; +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b3e453db0..4072575a5 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,26 +1,34 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/AnyTypeSummary.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Config.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintSolver.h" #include "Luau/DataFlowGraph.h" #include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" +#include "Luau/NonStrictTypeChecker.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/ToString.h" +#include "Luau/Transpiler.h" +#include "Luau/TypeArena.h" #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeReduction.h" #include "Luau/Variant.h" +#include "Luau/VisitType.h" #include #include +#include +#include +#include #include #include @@ -29,13 +37,48 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) -LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); +LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles, false) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false) +LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes, false) +LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode, false) +LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) +LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection, false) + +LUAU_FASTFLAG(StudioReportLuauAny2) +LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2, false) namespace Luau { +struct BuildQueueItem +{ + ModuleName name; + ModuleName humanReadableName; + + // Parameters + std::shared_ptr sourceNode; + std::shared_ptr sourceModule; + Config config; + ScopePtr environmentScope; + std::vector requireCycles; + FrontendOptions options; + bool recordJsonLog = false; + + // Queue state + std::vector reverseDeps; + int dirtyDependencies = 0; + bool processing = false; + + // Result + std::exception_ptr exception; + ModulePtr module; + Frontend::Stats stats; +}; + std::optional parseMode(const std::vector& hotcomments) { for (const HotComment& hc : hotcomments) @@ -83,108 +126,45 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) } } -LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName) +static ParseResult parseSourceForModule(std::string_view source, Luau::SourceModule& sourceModule, bool captureComments) { - if (!FFlag::DebugLuauDeferredConstraintResolution) - return Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, source, packageName); - - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - ParseOptions options; options.allowDeclarationSyntax = true; + options.captureComments = captureComments; - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, nullptr}; - - Luau::SourceModule module; - module.root = parseResult.root; - module.mode = Mode::Definition; - - ModulePtr checkedModule = check(module, Mode::Definition, {}); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, checkedModule}; - - CloneState cloneState; - - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); - - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - - typesToPersist.push_back(globalTy); - } - - for (const auto& [name, ty] : checkedModule->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; - - typesToPersist.push_back(globalTy.type); - } + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); + sourceModule.root = parseResult.root; + sourceModule.mode = Mode::Definition; - for (TypeId ty : typesToPersist) + if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments) { - persist(ty); + sourceModule.hotcomments = parseResult.hotcomments; + sourceModule.commentLocations = parseResult.commentLocations; } - return LoadDefinitionFileResult{true, parseResult, checkedModule}; + return parseResult; } -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) +static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, ScopePtr targetScope, const std::string& packageName) { - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - - ParseOptions options; - options.allowDeclarationSyntax = true; - - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, nullptr}; - - Luau::SourceModule module; - module.root = parseResult.root; - module.mode = Mode::Definition; - - ModulePtr checkedModule = typeChecker.check(module, Mode::Definition); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, checkedModule}; - - CloneState cloneState; + CloneState cloneState{globals.builtinTypes}; std::vector typesToPersist; typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; typesToPersist.push_back(globalTy); } for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -196,83 +176,47 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t { persist(ty); } - - return LoadDefinitionFileResult{true, parseResult, checkedModule}; } -std::vector parsePathExpr(const AstExpr& pathExpr) +LoadDefinitionFileResult Frontend::loadDefinitionFile( + GlobalTypes& globals, + ScopePtr targetScope, + std::string_view source, + const std::string& packageName, + bool captureComments, + bool typeCheckForAutocomplete +) { - const AstExprIndexName* indexName = pathExpr.as(); - if (!indexName) - return {}; - - std::vector segments{indexName->index.value}; - - while (true) - { - if (AstExprIndexName* in = indexName->expr->as()) - { - segments.push_back(in->index.value); - indexName = in; - continue; - } - else if (AstExprGlobal* indexNameAsGlobal = indexName->expr->as()) - { - segments.push_back(indexNameAsGlobal->name.value); - break; - } - else if (AstExprLocal* indexNameAsLocal = indexName->expr->as()) - { - segments.push_back(indexNameAsLocal->local->name.value); - break; - } - else - return {}; - } - - std::reverse(segments.begin(), segments.end()); - return segments; -} - -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments) -{ - if (segments.empty()) - return std::nullopt; + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - std::vector result; + Luau::SourceModule sourceModule; + sourceModule.name = packageName; + sourceModule.humanReadableName = packageName; - auto it = segments.begin(); + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - if (*it == "script" && !currentModuleName.empty()) - { - result = split(currentModuleName, '/'); - ++it; - } + ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}, std::nullopt, /*forAutocomplete*/ false, /*recordJsonLog*/ false, {}); - for (; it != segments.end(); ++it) - { - if (result.size() > 1 && *it == "Parent") - result.pop_back(); - else - result.push_back(*it); - } + if (checkedModule->errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - return join(result, "/"); -} + persistCheckedTypes(checkedModule, globals, targetScope, packageName); -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr) -{ - std::vector segments = parsePathExpr(pathExpr); - return pathExprToModuleName(currentModuleName, segments); + return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } namespace { -ErrorVec accumulateErrors( - const std::unordered_map& sourceNodes, const std::unordered_map& modules, const ModuleName& name) +static ErrorVec accumulateErrors( + const std::unordered_map>& sourceNodes, + ModuleResolver& moduleResolver, + const ModuleName& name +) { - std::unordered_set seen; + DenseHashSet seen{{}}; std::vector queue{name}; ErrorVec result; @@ -282,7 +226,7 @@ ErrorVec accumulateErrors( ModuleName next = std::move(queue.back()); queue.pop_back(); - if (seen.count(next)) + if (seen.contains(next)) continue; seen.insert(next); @@ -290,21 +234,26 @@ ErrorVec accumulateErrors( if (it == sourceNodes.end()) continue; - const SourceNode& sourceNode = it->second; + const SourceNode& sourceNode = *it->second; queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. // The solution is probably to move errors from Module to SourceNode - auto it2 = modules.find(next); - if (it2 == modules.end()) + auto modulePtr = moduleResolver.getModule(next); + if (!modulePtr) continue; - Module& module = *it2->second; + Module& module = *modulePtr; - std::sort(module.errors.begin(), module.errors.end(), [](const TypeError& e1, const TypeError& e2) -> bool { - return e1.location.begin > e2.location.begin; - }); + std::sort( + module.errors.begin(), + module.errors.end(), + [](const TypeError& e1, const TypeError& e2) -> bool + { + return e1.location.begin > e2.location.begin; + } + ); result.insert(result.end(), module.errors.begin(), module.errors.end()); } @@ -314,12 +263,33 @@ ErrorVec accumulateErrors( return result; } +static void filterLintOptions(LintOptions& lintOptions, const std::vector& hotcomments, Mode mode) +{ + uint64_t ignoreLints = LintWarning::parseMask(hotcomments); + + lintOptions.warningMask &= ~ignoreLints; + + if (mode != Mode::NoCheck) + { + lintOptions.disableWarning(Luau::LintWarning::Code_UnknownGlobal); + } + + if (mode == Mode::Strict) + { + lintOptions.disableWarning(Luau::LintWarning::Code_ImplicitReturn); + } +} + // Given a source node (start), find all requires that start a transitive dependency path that ends back at start // For each such path, record the full path and the location of the require in the starting module. // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) std::vector getRequireCycles( - const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) + const FileResolver* resolver, + const std::unordered_map>& sourceNodes, + const SourceNode* start, + bool stopAtFirst = false +) { std::vector result; @@ -335,7 +305,7 @@ std::vector getRequireCycles( if (dit == sourceNodes.end()) continue; - stack.push_back(&dit->second); + stack.push_back(dit->second.get()); while (!stack.empty()) { @@ -353,9 +323,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(resolver->getHumanReadableModuleName(node->name)); + cycle.push_back(node->name); - cycle.push_back(resolver->getHumanReadableModuleName(top->name)); + cycle.push_back(top->name); break; } } @@ -376,7 +346,7 @@ std::vector getRequireCycles( auto rit = sourceNodes.find(reqName); if (rit != sourceNodes.end()) - stack.push_back(&rit->second); + stack.push_back(rit->second.get()); } } } @@ -414,17 +384,23 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , fileResolver(fileResolver) , moduleResolver(this) , moduleResolverForAutocomplete(this) - , typeChecker(&moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, builtinTypes, &iceHandler) + , globals(builtinTypes) + , globalsForAutocomplete(builtinTypes) , configResolver(configResolver) , options(options) - , globalScope(typeChecker.globalScope) { } -FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) - : frontend(frontend) +void Frontend::parse(const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + if (getCheckResult(name, false, false)) + return; + + std::vector buildQueue; + parseGraph(buildQueue, name, false); } CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) @@ -433,180 +409,351 @@ CheckResult Frontend::check(const ModuleName& name, std::optional result = getCheckResult(name, true, frontendOptions.forAutocomplete)) + return std::move(*result); + + std::vector buildQueue; + bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); + + DenseHashSet seen{{}}; + std::vector buildQueueItems; + addBuildQueueItems(buildQueueItems, buildQueue, cycleDetected, seen, frontendOptions); + LUAU_ASSERT(!buildQueueItems.empty()); + + if (FFlag::DebugLuauLogSolverToJson) + { + LUAU_ASSERT(buildQueueItems.back().name == name); + buildQueueItems.back().recordJsonLog = true; + } + + checkBuildQueueItems(buildQueueItems); + + // Collect results only for checked modules, 'getCheckResult' produces a different result CheckResult checkResult; - auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete)) + for (const BuildQueueItem& item : buildQueueItems) { - // No recheck required. - if (frontendOptions.forAutocomplete) - { - auto it2 = moduleResolverForAutocomplete.modules.find(name); - if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) - throw InternalCompilerError("Frontend::modules does not have data for " + name, name); - } - else + if (item.module->timeout) + checkResult.timeoutHits.push_back(item.name); + + // If check was manually cancelled, do not return partial results + if (item.module->cancelled) + return {}; + + checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end()); + + if (item.name == name) + checkResult.lintResult = item.module->lintResult; + + if (FFlag::StudioReportLuauAny2 && item.options.retainFullTypeGraphs) { - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw InternalCompilerError("Frontend::modules does not have data for " + name, name); + if (item.module) + { + const SourceModule& sourceModule = *item.sourceModule; + if (sourceModule.mode == Luau::Mode::Strict) + { + item.module->ats.root = toString(sourceModule.root); + } + item.module->ats.rootSrc = sourceModule.root; + item.module->ats.traverse(item.module.get(), sourceModule.root, NotNull{&builtinTypes_}); + } } - - return CheckResult{ - accumulateErrors(sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; } - std::vector buildQueue; - bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); + return checkResult; +} - for (const ModuleName& moduleName : buildQueue) +void Frontend::queueModuleCheck(const std::vector& names) +{ + moduleQueue.insert(moduleQueue.end(), names.begin(), names.end()); +} + +void Frontend::queueModuleCheck(const ModuleName& name) +{ + moduleQueue.push_back(name); +} + +std::vector Frontend::checkQueuedModules( + std::optional optionOverride, + std::function task)> executeTask, + std::function progress +) +{ + FrontendOptions frontendOptions = optionOverride.value_or(options); + if (FFlag::LuauSolverV2) + frontendOptions.forAutocomplete = false; + + // By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown + std::vector currModuleQueue; + std::swap(currModuleQueue, moduleQueue); + + DenseHashSet seen{{}}; + std::vector buildQueueItems; + + for (const ModuleName& name : currModuleQueue) { - LUAU_ASSERT(sourceNodes.count(moduleName)); - SourceNode& sourceNode = sourceNodes[moduleName]; + if (seen.contains(name)) + continue; - if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete)) + if (!isDirty(name, frontendOptions.forAutocomplete)) + { + seen.insert(name); continue; + } - LUAU_ASSERT(sourceModules.count(moduleName)); - SourceModule& sourceModule = sourceModules[moduleName]; + std::vector queue; + bool cycleDetected = parseGraph( + queue, + name, + frontendOptions.forAutocomplete, + [&seen](const ModuleName& name) + { + return seen.contains(name); + } + ); - const Config& config = configResolver->getConfig(moduleName); + addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions); + } - Mode mode = sourceModule.mode.value_or(config.mode); + if (buildQueueItems.empty()) + return {}; - ScopePtr environmentScope = getModuleEnvironment(sourceModule, config, frontendOptions.forAutocomplete); + // We need a mapping from modules to build queue slots + std::unordered_map moduleNameToQueue; - double timestamp = getTimestamp(); + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + moduleNameToQueue[item.name] = i; + } - std::vector requireCycles; + // Default task execution is single-threaded and immediate + if (!executeTask) + { + executeTask = [](std::function task) + { + task(); + }; + } - // in NoCheck mode we only need to compute the value of .cyclic for typeck - // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself - // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term - // all correct programs must be acyclic so this code triggers rarely - if (cycleDetected) - requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); + std::mutex mtx; + std::condition_variable cv; + std::vector readyQueueItems; - // This is used by the type checker to replace the resulting type of cyclic modules with any - sourceModule.cyclic = !requireCycles.empty(); + size_t processing = 0; + size_t remaining = buildQueueItems.size(); - if (frontendOptions.forAutocomplete) + auto itemTask = [&](size_t i) + { + BuildQueueItem& item = buildQueueItems[i]; + + try { - // The autocomplete typecheck is always in strict mode with DM awareness - // to provide better type information for IDE features - typeCheckerForAutocomplete.requireCycles = requireCycles; + checkBuildQueueItem(item); + } + catch (...) + { + item.exception = std::current_exception(); + } - double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + { + std::unique_lock guard(mtx); + readyQueueItems.push_back(i); + } - if (autocompleteTimeLimit != 0.0) - typeCheckerForAutocomplete.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; - else - typeCheckerForAutocomplete.finishTime = std::nullopt; + cv.notify_one(); + }; - // TODO: This is a dirty ad hoc solution for autocomplete timeouts - // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit - // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (FInt::LuauTarjanChildLimit > 0) - typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + auto sendItemTask = [&](size_t i) + { + BuildQueueItem& item = buildQueueItems[i]; - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckerForAutocomplete.unifierIterationLimit = - std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; + item.processing = true; + processing++; - ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution - ? check(sourceModule, mode, requireCycles, /*forAutocomplete*/ true, /*recordJsonLog*/ false) - : typeCheckerForAutocomplete.check(sourceModule, Mode::Strict, environmentScope); + executeTask( + [&itemTask, i]() + { + itemTask(i); + } + ); + }; + + auto sendCycleItemTask = [&] + { + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; - moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; + if (!item.processing) + { + sendItemTask(i); + break; + } + } + }; - double duration = getTimestamp() - timestamp; + // In a first pass, check modules that have no dependencies and record info of those modules that wait + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; - if (moduleForAutocomplete->timeout) + for (const ModuleName& dep : item.sourceNode->requireSet) + { + if (auto it = sourceNodes.find(dep); it != sourceNodes.end()) { - checkResult.timeoutHits.push_back(moduleName); + if (it->second->hasDirtyModule(frontendOptions.forAutocomplete)) + { + item.dirtyDependencies++; - sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + buildQueueItems[moduleNameToQueue[dep]].reverseDeps.push_back(i); + } } - else if (duration < autocompleteTimeLimit / 2.0) + } + + if (item.dirtyDependencies == 0) + sendItemTask(i); + } + + // Not a single item was found, a cycle in the graph was hit + if (processing == 0) + sendCycleItemTask(); + + std::vector nextItems; + std::optional itemWithException; + bool cancelled = false; + + while (remaining != 0) + { + { + std::unique_lock guard(mtx); + + // If nothing is ready yet, wait + cv.wait( + guard, + [&readyQueueItems] + { + return !readyQueueItems.empty(); + } + ); + + // Handle checked items + for (size_t i : readyQueueItems) { - sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); - } + const BuildQueueItem& item = buildQueueItems[i]; - stats.timeCheck += duration; - stats.filesStrict += 1; + // If exception was thrown, stop adding new items and wait for processing items to complete + if (item.exception) + itemWithException = i; - sourceNode.dirtyModuleForAutocomplete = false; - continue; - } + if (item.module && item.module->cancelled) + cancelled = true; - typeChecker.requireCycles = requireCycles; + if (itemWithException || cancelled) + break; - const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson && moduleName == name; + recordItemResult(item); - ModulePtr module = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, requireCycles, /*forAutocomplete*/ false, recordJsonLog) - : typeChecker.check(sourceModule, mode, environmentScope); + // Notify items that were waiting for this dependency + for (size_t reverseDep : item.reverseDeps) + { + BuildQueueItem& reverseDepItem = buildQueueItems[reverseDep]; - stats.timeCheck += getTimestamp() - timestamp; - stats.filesStrict += mode == Mode::Strict; - stats.filesNonstrict += mode == Mode::Nonstrict; + LUAU_ASSERT(reverseDepItem.dirtyDependencies != 0); + reverseDepItem.dirtyDependencies--; + + // In case of a module cycle earlier, check if unlocked an item that was already processed + if (!reverseDepItem.processing && reverseDepItem.dirtyDependencies == 0) + nextItems.push_back(reverseDep); + } + } - if (module == nullptr) - throw InternalCompilerError("Frontend::check produced a nullptr module for " + moduleName, moduleName); + LUAU_ASSERT(processing >= readyQueueItems.size()); + processing -= readyQueueItems.size(); - if (!frontendOptions.retainFullTypeGraphs) + LUAU_ASSERT(remaining >= readyQueueItems.size()); + remaining -= readyQueueItems.size(); + readyQueueItems.clear(); + } + + if (progress) { - // copyErrors needs to allocate into interfaceTypes as it copies - // types out of internalTypes, so we unfreeze it here. - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - - module->internalTypes.clear(); - - module->astTypes.clear(); - module->astTypePacks.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astOverloadResolvedTypes.clear(); - module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->astScopes.clear(); - - module->scopes.clear(); + if (!progress(buildQueueItems.size() - remaining, buildQueueItems.size())) + cancelled = true; } - if (mode != Mode::NoCheck) + // Items cannot be submitted while holding the lock + for (size_t i : nextItems) + sendItemTask(i); + nextItems.clear(); + + if (processing == 0) { - for (const RequireCycle& cyc : requireCycles) - { - TypeError te{cyc.location, moduleName, ModuleHasCyclicDependency{cyc.path}}; + // Typechecking might have been cancelled by user, don't return partial results + if (cancelled) + return {}; - module->errors.push_back(te); - } + // We might have stopped because of a pending exception + if (itemWithException) + recordItemResult(buildQueueItems[*itemWithException]); } - ErrorVec parseErrors; + // If we aren't done, but don't have anything processing, we hit a cycle + if (remaining != 0 && processing == 0) + sendCycleItemTask(); + } + + std::vector checkedModules; + checkedModules.reserve(buildQueueItems.size()); + + for (size_t i = 0; i < buildQueueItems.size(); i++) + checkedModules.push_back(std::move(buildQueueItems[i].name)); - for (const ParseError& pe : sourceModule.parseErrors) - parseErrors.push_back(TypeError{pe.getLocation(), moduleName, SyntaxError{pe.what()}}); + return checkedModules; +} + +std::optional Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete) +{ + if (FFlag::LuauSolverV2) + forAutocomplete = false; + + auto it = sourceNodes.find(name); + + if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete)) + return std::nullopt; - module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); + auto& resolver = forAutocomplete ? moduleResolverForAutocomplete : moduleResolver; + ModulePtr module = resolver.getModule(name); + + if (module == nullptr) + throw InternalCompilerError("Frontend does not have module: " + name, name); + + CheckResult checkResult; + + if (module->timeout) + checkResult.timeoutHits.push_back(name); + + if (accumulateNested) + checkResult.errors = accumulateErrors(sourceNodes, resolver, name); + else checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); - moduleResolver.modules[moduleName] = std::move(module); - sourceNode.dirtyModule = false; - } + // Get lint result only for top checked module + checkResult.lintResult = module->lintResult; return checkResult; } -bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete) +bool Frontend::parseGraph( + std::vector& buildQueue, + const ModuleName& root, + bool forAutocomplete, + std::function canSkip +) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -677,14 +824,18 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.hasDirtyModule(forAutocomplete)) + if (!it->second->hasDirtyModule(forAutocomplete)) + continue; + + // This module might already be in the outside build queue + if (canSkip && canSkip(dep)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization // calling getSourceNode twice in succession will reparse the file, since getSourceNode leaves dirty flag set - if (seen.contains(&it->second)) + if (seen.contains(it->second.get())) { - stack.push_back(&it->second); + stack.push_back(it->second.get()); continue; } } @@ -704,86 +855,285 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& return cyclic; } -ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) +void Frontend::addBuildQueueItems( + std::vector& items, + std::vector& buildQueue, + bool cycleDetected, + DenseHashSet& seen, + const FrontendOptions& frontendOptions +) { - ScopePtr result; - if (forAutocomplete) - result = typeCheckerForAutocomplete.globalScope; - else - result = typeChecker.globalScope; + for (const ModuleName& moduleName : buildQueue) + { + if (seen.contains(moduleName)) + continue; + seen.insert(moduleName); - if (module.environmentName) - result = getEnvironmentScope(*module.environmentName); + LUAU_ASSERT(sourceNodes.count(moduleName)); + std::shared_ptr& sourceNode = sourceNodes[moduleName]; - if (!config.globals.empty()) - { - result = std::make_shared(result); + if (!sourceNode->hasDirtyModule(frontendOptions.forAutocomplete)) + continue; - for (const std::string& global : config.globals) - { - AstName name = module.names->get(global.c_str()); + LUAU_ASSERT(sourceModules.count(moduleName)); + std::shared_ptr& sourceModule = sourceModules[moduleName]; - if (name.value) - result->bindings[name].typeId = typeChecker.anyType; + BuildQueueItem data{moduleName, fileResolver->getHumanReadableModuleName(moduleName), sourceNode, sourceModule}; + + data.config = configResolver->getConfig(moduleName); + data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); + data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; + + const Mode mode = sourceModule->mode.value_or(data.config.mode); + + // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself + // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term + // all correct programs must be acyclic so this code triggers rarely + if (cycleDetected) + { + if (FFlag::LuauMoreThoroughCycleDetection) + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), false); + else + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); } + + data.options = frontendOptions; + + // This is used by the type checker to replace the resulting type of cyclic modules with any + sourceModule->cyclic = !data.requireCycles.empty(); + + items.push_back(std::move(data)); } +} - return result; +static void applyInternalLimitScaling(SourceNode& sourceNode, const ModulePtr module, double limit) +{ + if (module->timeout) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + else if (module->checkDurationSec < limit / 2.0) + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); } -LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) +void Frontend::checkBuildQueueItem(BuildQueueItem& item) { - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + SourceNode& sourceNode = *item.sourceNode; + const SourceModule& sourceModule = *item.sourceModule; + const Config& config = item.config; + Mode mode; + if (FFlag::DebugLuauForceStrictMode) + mode = Mode::Strict; + else if (FFlag::DebugLuauForceNonStrictMode) + mode = Mode::Nonstrict; + else + mode = sourceModule.mode.value_or(config.mode); - auto [_sourceNode, sourceModule] = getSourceNode(name); + item.sourceModule->mode = {mode}; + ScopePtr environmentScope = item.environmentScope; + double timestamp = getTimestamp(); + const std::vector& requireCycles = item.requireCycles; - if (!sourceModule) - return LintResult{}; // FIXME: We really should do something a bit more obvious when a file is too broken to lint. + TypeCheckLimits typeCheckLimits; - return lint(*sourceModule, enabledLintWarnings); -} + if (item.options.moduleTimeLimitSec) + typeCheckLimits.finishTime = TimeTrace::getClock() + *item.options.moduleTimeLimitSec; + else + typeCheckLimits.finishTime = std::nullopt; -LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) -{ - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (item.options.applyInternalLimitScaling) + { + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; + } + + typeCheckLimits.cancellationToken = item.options.cancellationToken; + + if (item.options.forAutocomplete) + { + // The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features + ModulePtr moduleForAutocomplete = check( + sourceModule, + Mode::Strict, + requireCycles, + environmentScope, + /*forAutocomplete*/ true, + /*recordJsonLog*/ false, + typeCheckLimits + ); + + double duration = getTimestamp() - timestamp; - const Config& config = configResolver->getConfig(module.name); + moduleForAutocomplete->checkDurationSec = duration; - uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments); + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, moduleForAutocomplete, *item.options.moduleTimeLimitSec); - LintOptions options = enabledLintWarnings.value_or(config.enabledLint); - options.warningMask &= ~ignoreLints; + item.stats.timeCheck += duration; + item.stats.filesStrict += 1; + + if (DFFlag::LuauRunCustomModuleChecks && item.options.customModuleCheck) + item.options.customModuleCheck(sourceModule, *moduleForAutocomplete); + + item.module = moduleForAutocomplete; + return; + } + + ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, typeCheckLimits); + + double duration = getTimestamp() - timestamp; + + module->checkDurationSec = duration; + + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, module, *item.options.moduleTimeLimitSec); + + item.stats.timeCheck += duration; + item.stats.filesStrict += mode == Mode::Strict; + item.stats.filesNonstrict += mode == Mode::Nonstrict; + + if (DFFlag::LuauRunCustomModuleChecks && item.options.customModuleCheck) + item.options.customModuleCheck(sourceModule, *module); + + if (FFlag::LuauSolverV2 && mode == Mode::NoCheck) + module->errors.clear(); + + if (item.options.runLintChecks) + { + LUAU_TIMETRACE_SCOPE("lint", "Frontend"); + + LintOptions lintOptions = item.options.enabledLintWarnings.value_or(config.enabledLint); + filterLintOptions(lintOptions, sourceModule.hotcomments, mode); + + double timestamp = getTimestamp(); + + std::vector warnings = + Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); + + item.stats.timeLint += getTimestamp() - timestamp; + + module->lintResult = classifyLints(warnings, config); + } + + if (!item.options.retainFullTypeGraphs) + { + // copyErrors needs to allocate into interfaceTypes as it copies + // types out of internalTypes, so we unfreeze it here. + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); + freeze(module->interfaceTypes); + + module->internalTypes.clear(); + + module->astTypes.clear(); + module->astTypePacks.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astOverloadResolvedTypes.clear(); + module->astForInNextTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->astCompoundAssignResultTypes.clear(); + module->astScopes.clear(); + module->upperBoundContributors.clear(); + module->scopes.clear(); + } - Mode mode = module.mode.value_or(config.mode); if (mode != Mode::NoCheck) { - options.disableWarning(Luau::LintWarning::Code_UnknownGlobal); + for (const RequireCycle& cyc : requireCycles) + { + TypeError te{cyc.location, item.name, ModuleHasCyclicDependency{cyc.path}}; + + module->errors.push_back(te); + } } - if (mode == Mode::Strict) + ErrorVec parseErrors; + + for (const ParseError& pe : sourceModule.parseErrors) + parseErrors.push_back(TypeError{pe.getLocation(), item.name, SyntaxError{pe.what()}}); + + module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); + + item.module = module; +} + +void Frontend::checkBuildQueueItems(std::vector& items) +{ + for (BuildQueueItem& item : items) { - options.disableWarning(Luau::LintWarning::Code_ImplicitReturn); + checkBuildQueueItem(item); + + if (item.module && item.module->cancelled) + break; + + recordItemResult(item); } +} - ScopePtr environmentScope = getModuleEnvironment(module, config, /*forAutocomplete*/ false); +void Frontend::recordItemResult(const BuildQueueItem& item) +{ + if (item.exception) + std::rethrow_exception(item.exception); - ModulePtr modulePtr = moduleResolver.getModule(module.name); + if (item.options.forAutocomplete) + { + moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } - double timestamp = getTimestamp(); + stats.timeCheck += item.stats.timeCheck; + stats.timeLint += item.stats.timeLint; - std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options); + stats.filesStrict += item.stats.filesStrict; + stats.filesNonstrict += item.stats.filesNonstrict; +} - stats.timeLint += getTimestamp() - timestamp; +ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const +{ + ScopePtr result; + if (forAutocomplete) + result = globalsForAutocomplete.globalScope; + else + result = globals.globalScope; + + if (module.environmentName) + result = getEnvironmentScope(*module.environmentName); + + if (!config.globals.empty()) + { + result = std::make_shared(result); + + for (const std::string& global : config.globals) + { + AstName name = module.names->get(global.c_str()); - return classifyLints(warnings, config); + if (name.value) + result->bindings[name].typeId = builtinTypes->anyType; + } + } + + return result; } bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); - return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); + return it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete); } /* @@ -794,13 +1144,13 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) + if (sourceNodes.count(name) == 0) return; std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requireSet) + for (const auto& dep : module.second->requireSet) reverseDeps[dep].push_back(module.first); } @@ -812,7 +1162,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked queue.pop_back(); LUAU_ASSERT(sourceNodes.count(next) > 0); - SourceNode& sourceNode = sourceNodes[next]; + SourceNode& sourceNode = *sourceNodes[next]; if (markedDirty) markedDirty->push_back(next); @@ -838,7 +1188,7 @@ SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) { auto it = sourceModules.find(moduleName); if (it != sourceModules.end()) - return &it->second; + return it->second.get(); else return nullptr; } @@ -848,126 +1198,403 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } -ScopePtr Frontend::getGlobalScope() +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& parentScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits, + std::function writeJsonLog +) +{ + const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; + return check( + sourceModule, + mode, + requireCycles, + builtinTypes, + iceHandler, + moduleResolver, + fileResolver, + parentScope, + std::move(prepareModuleScope), + options, + limits, + recordJsonLog, + writeJsonLog + ); +} + +struct InternalTypeFinder : TypeOnceVisitor { - if (!globalScope) + bool visit(TypeId, const ClassType&) override { - globalScope = typeChecker.globalScope; + return false; } - return globalScope; -} + bool visit(TypeId, const BlockedType&) override + { + LUAU_ASSERT(false); + return false; + } -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options) -{ - const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; - return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, globalScope, options, recordJsonLog); -} + bool visit(TypeId, const FreeType&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypePackId, const BlockedTypePack&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypePackId, const FreeTypePack&) override + { + LUAU_ASSERT(false); + return false; + } -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog) + bool visit(TypePackId, const TypeFunctionInstanceTypePack&) override + { + LUAU_ASSERT(false); + return false; + } +}; + +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& parentScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits, + bool recordJsonLog, + std::function writeJsonLog +) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Typechecking"); + LUAU_TIMETRACE_ARGUMENT("module", sourceModule.name.c_str()); + LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str()); + ModulePtr result = std::make_shared(); - result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); + result->name = sourceModule.name; + result->humanReadableName = sourceModule.humanReadableName; + result->mode = mode; + result->internalTypes.owningModule = result.get(); + result->interfaceTypes.owningModule = result.get(); + + iceHandler->moduleName = sourceModule.name; std::unique_ptr logger; if (recordJsonLog) { logger = std::make_unique(); - std::optional source = fileResolver->readSource(sourceModule.name); + std::optional source = fileResolver->readSource(result->name); if (source) { logger->captureSource(source->source); } } - DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); + DataFlowGraph oldDfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); + DataFlowGraph* dfgForConstraintGeneration = nullptr; + if (FFlag::LuauStoreDFGOnModule2) + { + auto [dfg, scopes] = DataFlowGraphBuilder::buildShared(sourceModule.root, iceHandler); + result->dataFlowGraph = std::move(dfg); + result->dfgScopes = std::move(scopes); + dfgForConstraintGeneration = result->dataFlowGraph.get(); + } + else + { + dfgForConstraintGeneration = &oldDfg; + } UnifierSharedState unifierState{iceHandler}; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; - unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; + TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}}; + + if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) + typeFunctionRuntime.allowEvaluation = sourceModule.parseErrors.empty(); - ConstraintGraphBuilder cgb{ - sourceModule.name, + ConstraintGenerator cg{ result, - &result->internalTypes, + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, moduleResolver, builtinTypes, iceHandler, - globalScope, + parentScope, + std::move(prepareModuleScope), logger.get(), - NotNull{&dfg}, + NotNull{dfgForConstraintGeneration}, + requireCycles }; - cgb.visit(sourceModule.root); - result->errors = std::move(cgb.errors); + cg.visitModuleRoot(sourceModule.root); + result->errors = std::move(cg.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, - NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()}; + ConstraintSolver cs{ + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull(cg.rootScope), + borrowConstraints(cg.constraints), + result->name, + moduleResolver, + requireCycles, + logger.get(), + NotNull{dfgForConstraintGeneration}, + limits + }; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); - cs.run(); + try + { + cs.run(); + } + catch (const TimeLimitError&) + { + result->timeout = true; + } + catch (const UserCancelError&) + { + result->cancelled = true; + } + + if (recordJsonLog) + { + std::string output = logger->compileOutput(); + if (FFlag::DebugLuauLogSolverToJsonFile && writeJsonLog) + writeJsonLog(sourceModule.name, std::move(output)); + else + printf("%s\n", output.c_str()); + } for (TypeError& e : cs.errors) result->errors.emplace_back(std::move(e)); - result->scopes = std::move(cgb.scopes); + result->scopes = std::move(cg.scopes); result->type = sourceModule.type; + result->upperBoundContributors = std::move(cs.upperBoundContributors); + + if (result->timeout || result->cancelled) + { + // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending + // types + ScopePtr moduleScope = result->getModuleScope(); + moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); + for (auto& [name, ty] : result->declaredGlobals) + ty = builtinTypes->errorRecoveryType(); + + for (auto& [name, tf] : result->exportedTypeBindings) + tf.type = builtinTypes->errorRecoveryType(); + } + else + { + switch (mode) + { + case Mode::Nonstrict: + if (FFlag::LuauStoreDFGOnModule2) + { + Luau::checkNonStrict( + builtinTypes, + NotNull{&typeFunctionRuntime}, + iceHandler, + NotNull{&unifierState}, + NotNull{dfgForConstraintGeneration}, + NotNull{&limits}, + sourceModule, + result.get() + ); + } + else + { + Luau::checkNonStrict( + builtinTypes, + NotNull{&typeFunctionRuntime}, + iceHandler, + NotNull{&unifierState}, + NotNull{&oldDfg}, + NotNull{&limits}, + sourceModule, + result.get() + ); + } + break; + case Mode::Definition: + // fallthrough intentional + case Mode::Strict: + Luau::check( + builtinTypes, NotNull{&typeFunctionRuntime}, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get() + ); + break; + case Mode::NoCheck: + break; + }; + } + + unfreeze(result->interfaceTypes); result->clonePublicInterface(builtinTypes, *iceHandler); - Luau::check(builtinTypes, logger.get(), sourceModule, result.get()); + if (FFlag::DebugLuauForbidInternalTypes) + { + InternalTypeFinder finder; + + finder.traverse(result->returnType); - // Ideally we freeze the arenas before the call into Luau::check, but TypeReduction - // needs to allocate new types while Luau::check is in progress, so here we are. + for (const auto& [_, binding] : result->exportedTypeBindings) + finder.traverse(binding.type); + + for (const auto& [_, ty] : result->astTypes) + finder.traverse(ty); + + for (const auto& [_, ty] : result->astExpectedTypes) + finder.traverse(ty); + + for (const auto& [_, tp] : result->astTypePacks) + finder.traverse(tp); + + for (const auto& [_, ty] : result->astResolvedTypes) + finder.traverse(ty); + + for (const auto& [_, ty] : result->astOverloadResolvedTypes) + finder.traverse(ty); + + for (const auto& [_, tp] : result->astResolvedTypePacks) + finder.traverse(tp); + } + + // It would be nice if we could freeze the arenas before doing type + // checking, but we'll have to do some work to get there. // - // It does mean that mutations to the type graph can happen after the constraints - // have been solved, which will cause hard-to-debug problems. We should revisit this. + // TypeChecker2 sometimes needs to allocate TypePacks via extendTypePack() + // in order to do its thing. We can rework that code to instead allocate + // into a temporary arena as long as we can prove that the allocated types + // and packs can never find their way into an error. + // + // Notably, we would first need to get to a place where TypeChecker2 is + // never in the position of dealing with a FreeType. They should all be + // bound to something by the time constraints are solved. freeze(result->internalTypes); freeze(result->interfaceTypes); - if (recordJsonLog) - { - std::string output = logger->compileOutput(); - printf("%s\n", output.c_str()); - } - return result; } -ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete, bool recordJsonLog) +ModulePtr Frontend::check( + const SourceModule& sourceModule, + Mode mode, + std::vector requireCycles, + std::optional environmentScope, + bool forAutocomplete, + bool recordJsonLog, + TypeCheckLimits typeCheckLimits +) { - return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, - NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, - forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, options, recordJsonLog); + if (FFlag::LuauSolverV2) + { + auto prepareModuleScopeWrap = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) + { + if (prepareModuleScope) + prepareModuleScope(name, scope, forAutocomplete); + }; + + try + { + return Luau::check( + sourceModule, + mode, + requireCycles, + builtinTypes, + NotNull{&iceHandler}, + NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, + NotNull{fileResolver}, + environmentScope ? *environmentScope : globals.globalScope, + prepareModuleScopeWrap, + options, + typeCheckLimits, + recordJsonLog, + writeJsonLog + ); + } + catch (const InternalCompilerError& err) + { + InternalCompilerError augmented = err.location.has_value() ? InternalCompilerError{err.message, sourceModule.name, *err.location} + : InternalCompilerError{err.message, sourceModule.name}; + throw augmented; + } + } + else + { + TypeChecker typeChecker( + forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope, + forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, + builtinTypes, + &iceHandler + ); + + if (prepareModuleScope) + { + typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) + { + prepareModuleScope(name, scope, forAutocomplete); + }; + } + + typeChecker.requireCycles = requireCycles; + typeChecker.finishTime = typeCheckLimits.finishTime; + typeChecker.instantiationChildLimit = typeCheckLimits.instantiationChildLimit; + typeChecker.unifierIterationLimit = typeCheckLimits.unifierIterationLimit; + typeChecker.cancellationToken = typeCheckLimits.cancellationToken; + + return typeChecker.check(sourceModule, mode, environmentScope); + } } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(const ModuleName& name) { - LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) + if (it != sourceNodes.end() && !it->second->hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) - return {&it->second, &moduleIt->second}; + return {it->second.get(), moduleIt->second.get()}; else { LUAU_ASSERT(!"Everything in sourceNodes should also be in sourceModules"); - return {&it->second, nullptr}; + return {it->second.get(), nullptr}; } } + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + double timestamp = getTimestamp(); std::optional source = fileResolver->readSource(name); @@ -990,29 +1617,37 @@ std::pair Frontend::getSourceNode(const ModuleName& RequireTraceResult& require = requireTrace[name]; require = traceRequires(fileResolver, result.root, name); - SourceNode& sourceNode = sourceNodes[name]; - SourceModule& sourceModule = sourceModules[name]; + std::shared_ptr& sourceNode = sourceNodes[name]; - sourceModule = std::move(result); - sourceModule.environmentName = environmentName; + if (!sourceNode) + sourceNode = std::make_shared(); - sourceNode.name = name; - sourceNode.requireSet.clear(); - sourceNode.requireLocations.clear(); - sourceNode.dirtySourceModule = false; + std::shared_ptr& sourceModule = sourceModules[name]; + + if (!sourceModule) + sourceModule = std::make_shared(); + + *sourceModule = std::move(result); + sourceModule->environmentName = environmentName; + + sourceNode->name = sourceModule->name; + sourceNode->humanReadableName = sourceModule->humanReadableName; + sourceNode->requireSet.clear(); + sourceNode->requireLocations.clear(); + sourceNode->dirtySourceModule = false; if (it == sourceNodes.end()) { - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; + sourceNode->dirtyModule = true; + sourceNode->dirtyModuleForAutocomplete = true; } for (const auto& [moduleName, location] : require.requireList) - sourceNode.requireSet.insert(moduleName); + sourceNode->requireSet.insert(moduleName); - sourceNode.requireLocations = require.requireList; + sourceNode->requireLocations = require.requireList; - return {&sourceNode, &sourceModule}; + return {sourceNode.get(), sourceModule.get()}; } /** Try to parse a source file into a SourceModule. @@ -1058,6 +1693,7 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const } sourceModule.name = name; + sourceModule.humanReadableName = fileResolver->getHumanReadableModuleName(name); if (parseOptions.captureComments) { @@ -1068,6 +1704,12 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const return sourceModule; } + +FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) + : frontend(frontend) +{ +} + std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { // FIXME I think this can be pushed into the FileResolver. @@ -1092,6 +1734,8 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const { + std::scoped_lock lock(moduleMutex); + auto it = modules.find(moduleName); if (it != modules.end()) return it->second; @@ -1109,13 +1753,27 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& return frontend->fileResolver->getHumanReadableModuleName(moduleName); } +void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) +{ + std::scoped_lock lock(moduleMutex); + + modules[moduleName] = std::move(module); +} + +void FrontendModuleResolver::clearModules() +{ + std::scoped_lock lock(moduleMutex); + + modules.clear(); +} + ScopePtr Frontend::addEnvironment(const std::string& environmentName) { LUAU_ASSERT(environments.count(environmentName) == 0); if (environments.count(environmentName) == 0) { - ScopePtr scope = std::make_shared(typeChecker.globalScope); + ScopePtr scope = std::make_shared(globals.globalScope); environments[environmentName] = scope; return scope; } @@ -1123,14 +1781,16 @@ ScopePtr Frontend::addEnvironment(const std::string& environmentName) return environments[environmentName]; } -ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) +ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) const { - LUAU_ASSERT(environments.count(environmentName) > 0); + if (auto it = environments.find(environmentName); it != environments.end()) + return it->second; - return environments[environmentName]; + LUAU_ASSERT(!"environment doesn't exist"); + return {}; } -void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) +void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) { LUAU_ASSERT(builtinDefinitions.count(name) == 0); @@ -1143,7 +1803,7 @@ void Frontend::applyBuiltinDefinitionToEnvironment(const std::string& environmen LUAU_ASSERT(builtinDefinitions.count(definitionName) > 0); if (builtinDefinitions.count(definitionName) > 0) - builtinDefinitions[definitionName](typeChecker, getEnvironmentScope(environmentName)); + builtinDefinitions[definitionName](*this, globals, getEnvironmentScope(environmentName)); } LintResult Frontend::classifyLints(const std::vector& warnings, const Config& config) @@ -1169,8 +1829,8 @@ void Frontend::clear() { sourceNodes.clear(); sourceModules.clear(); - moduleResolver.modules.clear(); - moduleResolverForAutocomplete.modules.clear(); + moduleResolver.clearModules(); + moduleResolverForAutocomplete.clearModules(); requireTrace.clear(); } diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp new file mode 100644 index 000000000..8dce95f9e --- /dev/null +++ b/Analysis/src/Generalization.cpp @@ -0,0 +1,1054 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Generalization.h" + +#include "Luau/Scope.h" +#include "Luau/Type.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePack.h" +#include "Luau/VisitType.h" + +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + +namespace Luau +{ + +struct MutatingGeneralizer : TypeOnceVisitor +{ + NotNull builtinTypes; + + NotNull scope; + NotNull> cachedTypes; + DenseHashMap positiveTypes; + DenseHashMap negativeTypes; + std::vector generics; + std::vector genericPacks; + + bool isWithinFunction = false; + bool avoidSealingTables = false; + + MutatingGeneralizer( + NotNull builtinTypes, + NotNull scope, + NotNull> cachedTypes, + DenseHashMap positiveTypes, + DenseHashMap negativeTypes, + bool avoidSealingTables + ) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , builtinTypes(builtinTypes) + , scope(scope) + , cachedTypes(cachedTypes) + , positiveTypes(std::move(positiveTypes)) + , negativeTypes(std::move(negativeTypes)) + , avoidSealingTables(avoidSealingTables) + { + } + + static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) + { + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + seen.insert(haystack); + + if (UnionType* ut = getMutable(haystack)) + { + for (auto iter = ut->options.begin(); iter != ut->options.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId option = follow(*iter); + + if (option == needle && get(replacement)) + { + iter = ut->options.erase(iter); + continue; + } + + if (option == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. + iter++; + + if (seen.find(option)) + continue; + seen.insert(option); + + if (get(option)) + replace(seen, option, needle, haystack); + else if (get(option)) + replace(seen, option, needle, haystack); + } + + if (ut->options.size() == 1) + { + TypeId onlyType = ut->options[0]; + LUAU_ASSERT(onlyType != haystack); + emplaceType(asMutable(haystack), onlyType); + } + + return; + } + + if (IntersectionType* it = getMutable(needle)) + { + for (auto iter = it->parts.begin(); iter != it->parts.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId part = follow(*iter); + + if (part == needle && get(replacement)) + { + iter = it->parts.erase(iter); + continue; + } + + if (part == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. + iter++; + + if (seen.find(part)) + continue; + seen.insert(part); + + if (get(part)) + replace(seen, part, needle, haystack); + else if (get(part)) + replace(seen, part, needle, haystack); + } + + if (it->parts.size() == 1) + { + TypeId onlyType = it->parts[0]; + LUAU_ASSERT(onlyType != needle); + emplaceType(asMutable(needle), onlyType); + } + + return; + } + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (cachedTypes->contains(ty)) + return false; + + const bool oldValue = isWithinFunction; + + isWithinFunction = true; + + traverse(ft.argTypes); + traverse(ft.retTypes); + + isWithinFunction = oldValue; + + return false; + } + + bool visit(TypeId ty, const FreeType&) override + { + LUAU_ASSERT(!cachedTypes->contains(ty)); + + const FreeType* ft = get(ty); + LUAU_ASSERT(ft); + + traverse(ft->lowerBound); + traverse(ft->upperBound); + + // It is possible for the above traverse() calls to cause ty to be + // transmuted. We must reacquire ft if this happens. + ty = follow(ty); + ft = get(ty); + if (!ft) + return false; + + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + if (!positiveCount && !negativeCount) + return false; + + const bool hasLowerBound = !get(follow(ft->lowerBound)); + const bool hasUpperBound = !get(follow(ft->upperBound)); + + DenseHashSet seen{nullptr}; + seen.insert(ty); + + if (!hasLowerBound && !hasUpperBound) + { + if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + // It is possible that this free type has other free types in its upper + // or lower bounds. If this is the case, we must replace those + // references with never (for the lower bound) or unknown (for the upper + // bound). + // + // If we do not do this, we get tautological bounds like a <: a <: unknown. + else if (positiveCount && !hasUpperBound) + { + TypeId lb = follow(ft->lowerBound); + if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) + lowerFree->upperBound = builtinTypes->unknownType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, lb, ty, builtinTypes->unknownType); + } + + if (lb != ty) + emplaceType(asMutable(ty), lb); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the lower bound is the type in question, we don't actually have a lower bound. + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + else + { + TypeId ub = follow(ft->upperBound); + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) + upperFree->lowerBound = builtinTypes->neverType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, ub, ty, builtinTypes->neverType); + } + + if (ub != ty) + emplaceType(asMutable(ty), ub); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the upper bound is the type in question, we don't actually have an upper bound. + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + return false; + } + + size_t getCount(const DenseHashMap& map, const void* ty) + { + if (const size_t* count = map.find(ty)) + return *count; + else + return 0; + } + + bool visit(TypeId ty, const TableType&) override + { + if (cachedTypes->contains(ty)) + return false; + + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + // FIXME: Free tables should probably just be replaced by upper bounds on free types. + // + // eg never <: 'a <: {x: number} & {z: boolean} + + if (!positiveCount && !negativeCount) + return true; + + TableType* tt = getMutable(ty); + LUAU_ASSERT(tt); + + if (!avoidSealingTables) + tt->state = TableState::Sealed; + + return true; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (!subsumes(scope, ftp.scope)) + return true; + + tp = follow(tp); + + const size_t positiveCount = getCount(positiveTypes, tp); + const size_t negativeCount = getCount(negativeTypes, tp); + + if (1 == positiveCount + negativeCount) + emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); + else + { + emplaceTypePack(asMutable(tp), scope); + genericPacks.push_back(tp); + } + + return true; + } +}; + +struct FreeTypeSearcher : TypeVisitor +{ + NotNull scope; + NotNull> cachedTypes; + + explicit FreeTypeSearcher(NotNull scope, NotNull> cachedTypes) + : TypeVisitor(/*skipBoundTypes*/ true) + , scope(scope) + , cachedTypes(cachedTypes) + { + } + + enum Polarity + { + Positive, + Negative, + Both, + }; + + Polarity polarity = Positive; + + void flip() + { + switch (polarity) + { + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; + case Both: + break; + } + } + + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + + // The keys in these maps are either TypeIds or TypePackIds. It's safe to + // mix them because we only use these pointers as unique keys. We never + // indirect them. + DenseHashMap negativeTypes{0}; + DenseHashMap positiveTypes{0}; + + bool visit(TypeId ty) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + LUAU_ASSERT(ty); + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + if (!subsumes(scope, ft.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + + return true; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + { + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + } + + for (const auto& [_name, prop] : tt.props) + { + if (prop.isReadOnly()) + traverse(*prop.readTy); + else + { + LUAU_ASSERT(prop.isShared()); + + Polarity p = polarity; + polarity = Both; + traverse(prop.type()); + polarity = p; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } + + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + flip(); + traverse(ft.argTypes); + flip(); + + traverse(ft.retTypes); + + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (seenWithPolarity(tp)) + return false; + + if (!subsumes(scope, ftp.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[tp]++; + break; + case Negative: + negativeTypes[tp]++; + break; + case Both: + positiveTypes[tp]++; + negativeTypes[tp]++; + break; + } + + return true; + } +}; + +// We keep a running set of types that will not change under generalization and +// only have outgoing references to types that are the same. We use this to +// short circuit generalization. It improves performance quite a lot. +// +// We do this by tracing through the type and searching for types that are +// uncacheable. If a type has a reference to an uncacheable type, it is itself +// uncacheable. +// +// If a type has no outbound references to uncacheable types, we add it to the +// cache. +struct TypeCacher : TypeOnceVisitor +{ + NotNull> cachedTypes; + + DenseHashSet uncacheable{nullptr}; + DenseHashSet uncacheablePacks{nullptr}; + + explicit TypeCacher(NotNull> cachedTypes) + // CLI-120975: once we roll out release 646, we _want_ to visit bound + // types to ensure they're marked as uncacheable if the types they are + // bound to are also uncacheable. Hence: if LuauTypeSolverRelease is + // less than 646, skip bound types (the prior behavior). Otherwise, + // do not skip bound types. + : TypeOnceVisitor(/* skipBoundTypes */ DFInt::LuauTypeSolverRelease < 646) + , cachedTypes(cachedTypes) + { + } + + void cache(TypeId ty) + { + cachedTypes->insert(ty); + } + + bool isCached(TypeId ty) const + { + return cachedTypes->contains(ty); + } + + void markUncacheable(TypeId ty) + { + uncacheable.insert(ty); + } + + void markUncacheable(TypePackId tp) + { + uncacheablePacks.insert(tp); + } + + bool isUncacheable(TypeId ty) const + { + return uncacheable.contains(ty); + } + + bool isUncacheable(TypePackId tp) const + { + return uncacheablePacks.contains(tp); + } + + bool visit(TypeId ty) override + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs, + // otherwise it's prone to marking types that cannot be cached as + // cacheable. + LUAU_ASSERT(false); + LUAU_UNREACHABLE(); + } + else + { + return true; + } + } + + bool visit(TypeId ty, const BoundType& btv) override + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + traverse(btv.boundTo); + if (isUncacheable(btv.boundTo)) + markUncacheable(ty); + return false; + } + else + { + return true; + } + } + + bool visit(TypeId ty, const FreeType& ft) override + { + // Free types are never cacheable. + LUAU_ASSERT(!isCached(ty)); + + if (!isUncacheable(ty)) + { + traverse(ft.lowerBound); + traverse(ft.upperBound); + + markUncacheable(ty); + } + + return false; + } + + bool visit(TypeId ty, const GenericType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const ErrorType&) override + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + cache(ty); + return false; + } + else + { + return true; + } + } + + bool visit(TypeId ty, const PrimitiveType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const SingletonType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const BlockedType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + traverse(ft.argTypes); + traverse(ft.retTypes); + for (TypeId gen : ft.generics) + traverse(gen); + + bool uncacheable = false; + + if (isUncacheable(ft.argTypes)) + uncacheable = true; + + else if (isUncacheable(ft.retTypes)) + uncacheable = true; + + for (TypeId argTy : ft.argTypes) + { + if (isUncacheable(argTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId retTy : ft.retTypes) + { + if (isUncacheable(retTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId g : ft.generics) + { + if (isUncacheable(g)) + { + uncacheable = true; + break; + } + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + if (tt.boundTo) + { + traverse(*tt.boundTo); + if (isUncacheable(*tt.boundTo)) + { + markUncacheable(ty); + return false; + } + } + + bool uncacheable = false; + + // This logic runs immediately after generalization, so any remaining + // unsealed tables are assuredly not cacheable. They may yet have + // properties added to them. + if (tt.state == TableState::Free || tt.state == TableState::Unsealed) + uncacheable = true; + + for (const auto& [_name, prop] : tt.props) + { + if (prop.readTy) + { + traverse(*prop.readTy); + + if (isUncacheable(*prop.readTy)) + uncacheable = true; + } + if (prop.writeTy && prop.writeTy != prop.readTy) + { + traverse(*prop.writeTy); + + if (isUncacheable(*prop.writeTy)) + uncacheable = true; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + if (isUncacheable(tt.indexer->indexType)) + uncacheable = true; + + traverse(tt.indexer->indexResultType); + if (isUncacheable(tt.indexer->indexResultType)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const MetatableType& mtv) override + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + traverse(mtv.table); + traverse(mtv.metatable); + if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable)) + markUncacheable(ty); + else + cache(ty); + return false; + } + else + { + return true; + } + } + + bool visit(TypeId ty, const ClassType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const AnyType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NoRefineType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const UnionType& ut) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : ut.options) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const IntersectionType& it) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : it.parts) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const UnknownType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NeverType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NegationType& nt) override + { + if (!isCached(ty) && !isUncacheable(ty)) + { + traverse(nt.ty); + + if (isUncacheable(nt.ty)) + markUncacheable(ty); + else + cache(ty); + } + + return false; + } + + bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + bool uncacheable = false; + + for (TypeId argTy : tfit.typeArguments) + { + traverse(argTy); + + if (isUncacheable(argTy)) + uncacheable = true; + } + + for (TypePackId argPack : tfit.packArguments) + { + traverse(argPack); + + if (isUncacheable(argPack)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypePackId tp) override + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs, + // otherwise it's prone to marking types that cannot be cached as + // cacheable, which will segfault down the line. + LUAU_ASSERT(false); + LUAU_UNREACHABLE(); + } + else + { + return true; + } + } + + bool visit(TypePackId tp, const FreeTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const GenericTypePack& gtp) override + { + return true; + } + + bool visit(TypePackId tp, const Unifiable::Error& etp) override + { + return true; + } + + bool visit(TypePackId tp, const VariadicTypePack& vtp) override + { + if (isUncacheable(tp)) + return false; + + traverse(vtp.ty); + + if (isUncacheable(vtp.ty)) + markUncacheable(tp); + + return false; + } + + bool visit(TypePackId tp, const BlockedTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const BoundTypePack& btp) override { + if (DFInt::LuauTypeSolverRelease >= 645) { + traverse(btp.boundTo); + if (isUncacheable(btp.boundTo)) + markUncacheable(tp); + return false; + } + return true; + } + + bool visit(TypePackId tp, const TypePack& typ) override + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + bool uncacheable = false; + for (TypeId ty : typ.head) + { + traverse(ty); + uncacheable |= isUncacheable(ty); + } + if (typ.tail) + { + traverse(*typ.tail); + uncacheable |= isUncacheable(*typ.tail); + } + if (uncacheable) + markUncacheable(tp); + return false; + } + return true; + } +}; + +std::optional generalize( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull> cachedTypes, + TypeId ty, + bool avoidSealingTables +) +{ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + FreeTypeSearcher fts{scope, cachedTypes}; + fts.traverse(ty); + + MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; + + gen.traverse(ty); + + /* MutatingGeneralizer mutates types in place, so it is possible that ty has + * been transmuted to a BoundType. We must follow it again and verify that + * we are allowed to mutate it before we attach generics to it. + */ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + TypeCacher cacher{cachedTypes}; + cacher.traverse(ty); + + FunctionType* ftv = getMutable(ty); + if (ftv) + { + // If we're generalizing a function type, add any of the newly inferred + // generics to the list of existing generic types. + for (const auto g : std::move(gen.generics)) + { + ftv->generics.push_back(g); + } + // Ditto for generic packs. + for (const auto gp : std::move(gen.genericPacks)) + { + ftv->genericPacks.push_back(gp); + } + } + + return ty; +} + +} // namespace Luau diff --git a/Analysis/src/GlobalTypes.cpp b/Analysis/src/GlobalTypes.cpp new file mode 100644 index 000000000..9dd60caad --- /dev/null +++ b/Analysis/src/GlobalTypes.cpp @@ -0,0 +1,30 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/GlobalTypes.h" + +namespace Luau +{ + +GlobalTypes::GlobalTypes(NotNull builtinTypes) + : builtinTypes(builtinTypes) +{ + globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); + + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); + globalScope->addBuiltinTypeBinding("buffer", TypeFun{{}, builtinTypes->bufferType}); + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); + + unfreeze(*builtinTypes->arena); + TypeId stringMetatableTy = makeStringMetatable(builtinTypes); + asMutable(builtinTypes->stringType)->ty.emplace(PrimitiveType::String, stringMetatableTy); + persist(stringMetatableTy); + freeze(*builtinTypes->arena); +} + +} // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 9c3ae0771..4b6d11154 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -1,19 +1,35 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Common.h" #include "Luau/Instantiation.h" + +#include "Luau/Common.h" +#include "Luau/Instantiation2.h" // including for `Replacer` which was stolen since it will be kept in the new solver +#include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" +#include "Luau/TypeCheckLimits.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) +#include + +LUAU_FASTFLAG(LuauSolverV2) namespace Luau { +void Instantiation::resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) +{ + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; +} + bool Instantiation::isDirty(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return false; return true; @@ -33,7 +49,7 @@ bool Instantiation::ignoreChildren(TypeId ty) { if (log->getMutable(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else return false; @@ -54,11 +70,11 @@ TypeId Instantiation::clean(TypeId ty) // Annoyingly, we have to do this even if there are no generics, // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, level, scope, ftv->generics, ftv->genericPacks}; + reusableReplaceGenerics.resetState(log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks); // TODO: What to do if this returns nullopt? // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); + result = reusableReplaceGenerics.substitute(result).value_or(result); asMutable(result)->documentationSymbol = ty->documentationSymbol; return result; @@ -70,11 +86,32 @@ TypePackId Instantiation::clean(TypePackId tp) return tp; } +void ReplaceGenerics::resetState( + const TxnLog* log, + TypeArena* arena, + NotNull builtinTypes, + TypeLevel level, + Scope* scope, + const std::vector& generics, + const std::vector& genericPacks +) +{ + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; + + this->generics = generics; + this->genericPacks = genericPacks; +} + bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return true; // We aren't recursing in the case of a generic function which @@ -84,7 +121,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) // whenever we quantify, so the vectors overlap if and only if they are equal. return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); } - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else { @@ -120,14 +157,66 @@ TypeId ReplaceGenerics::clean(TypeId ty) clone.definitionLocation = ttv->definitionLocation; return addType(std::move(clone)); } + else if (FFlag::LuauSolverV2) + { + TypeId res = freshType(NotNull{arena}, builtinTypes, scope); + getMutable(res)->level = level; + return res; + } else + { return addType(FreeType{scope, level}); + } } TypePackId ReplaceGenerics::clean(TypePackId tp) { LUAU_ASSERT(isDirty(tp)); - return addTypePack(TypePackVar(FreeTypePack{level})); + return addTypePack(TypePackVar(FreeTypePack{scope, level})); +} + +std::optional instantiate( + NotNull builtinTypes, + NotNull arena, + NotNull limits, + NotNull scope, + TypeId ty +) +{ + ty = follow(ty); + + const FunctionType* ft = get(ty); + if (!ft) + return ty; + + if (ft->generics.empty() && ft->genericPacks.empty()) + return ty; + + DenseHashMap replacements{nullptr}; + DenseHashMap replacementPacks{nullptr}; + + for (TypeId g : ft->generics) + replacements[g] = freshType(arena, builtinTypes, scope); + + for (TypePackId g : ft->genericPacks) + replacementPacks[g] = arena->freshTypePack(scope); + + Replacer r{arena, std::move(replacements), std::move(replacementPacks)}; + + if (limits->instantiationChildLimit) + r.childLimit = *limits->instantiationChildLimit; + + std::optional res = r.substitute(ty); + if (!res) + return res; + + FunctionType* ft2 = getMutable(*res); + LUAU_ASSERT(ft != ft2); + + ft2->generics.clear(); + ft2->genericPacks.clear(); + + return res; } } // namespace Luau diff --git a/Analysis/src/Instantiation2.cpp b/Analysis/src/Instantiation2.cpp new file mode 100644 index 000000000..106ad8700 --- /dev/null +++ b/Analysis/src/Instantiation2.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Instantiation2.h" + +namespace Luau +{ + +bool Instantiation2::ignoreChildren(TypeId ty) +{ + if (get(ty)) + return true; + + if (auto ftv = get(ty)) + { + if (ftv->hasNoFreeOrGenericTypes) + return false; + + // If this function type quantifies over these generics, we don't want substitution to + // go any further into them because it's being shadowed in this case. + for (auto generic : ftv->generics) + if (genericSubstitutions.contains(generic)) + return true; + + for (auto generic : ftv->genericPacks) + if (genericPackSubstitutions.contains(generic)) + return true; + } + + return false; +} + +bool Instantiation2::isDirty(TypeId ty) +{ + return get(ty) && genericSubstitutions.contains(ty); +} + +bool Instantiation2::isDirty(TypePackId tp) +{ + return get(tp) && genericPackSubstitutions.contains(tp); +} + +TypeId Instantiation2::clean(TypeId ty) +{ + TypeId substTy = follow(genericSubstitutions[ty]); + const FreeType* ft = get(substTy); + + // violation of the substitution invariant if this is not a free type. + LUAU_ASSERT(ft); + + // if we didn't learn anything about the lower bound, we pick the upper bound instead. + // we default to the lower bound which represents the most specific type for the free type. + TypeId res = get(ft->lowerBound) ? ft->upperBound : ft->lowerBound; + + // Instantiation should not traverse into the type that we are substituting for. + dontTraverseInto(res); + + return res; +} + +TypePackId Instantiation2::clean(TypePackId tp) +{ + TypePackId res = genericPackSubstitutions[tp]; + dontTraverseInto(res); + return res; +} + +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypeId ty +) +{ + Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; + return instantiation.substitute(ty); +} + +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypePackId tp +) +{ + Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; + return instantiation.substitute(tp); +} + +} // namespace Luau diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 43580da4d..64e059933 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" +#include "Luau/TypePath.h" namespace Luau { @@ -113,6 +114,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "GenericError { " << err.message << " }"; else if constexpr (std::is_same_v) stream << "InternalError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "ConstraintSolvingIncompleteError {}"; else if constexpr (std::is_same_v) stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; else if constexpr (std::is_same_v) @@ -192,13 +195,76 @@ static void errorToString(std::ostream& stream, const T& err) stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; else if constexpr (std::is_same_v) stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "UninhabitedTypeFunction { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + { + std::string recArgs = "["; + for (auto [s, t] : err.recommendedArgs) + recArgs += " " + s + ": " + toString(t); + recArgs += " ]"; + stream << "ExplicitFunctionAnnotationRecommended { recommmendedReturn = '" + toString(err.recommendedReturn) + + "', recommmendedArgs = " + recArgs + "}"; + } + else if constexpr (std::is_same_v) + stream << "UninhabitedTypePackFunction { " << toString(err.tp) << " }"; + else if constexpr (std::is_same_v) + stream << "WhereClauseNeeded { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "PackWhereClauseNeeded { " << toString(err.tp) << " }"; + else if constexpr (std::is_same_v) + stream << "CheckedFunctionCallError { expected = '" << toString(err.expected) << "', passed = '" << toString(err.passed) + << "', checkedFunctionName = " << err.checkedFunctionName << ", argumentIndex = " << std::to_string(err.argumentIndex) << " }"; + else if constexpr (std::is_same_v) + stream << "NonStrictFunctionDefinitionError { functionName = '" + err.functionName + "', argument = '" + err.argument + + "', argumentType = '" + toString(err.argumentType) + "' }"; + else if constexpr (std::is_same_v) + stream << "PropertyAccessViolation { table = " << toString(err.table) << ", prop = '" << err.key << "', context = " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "CheckedFunction { functionName = '" + err.functionName + ", expected = " + std::to_string(err.expected) + + ", actual = " + std::to_string(err.actual) + "}"; + else if constexpr (std::is_same_v) + stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; + else if constexpr (std::is_same_v) + stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; + else if constexpr (std::is_same_v) + stream << "UserDefinedTypeFunctionError { " << err.message << " }"; + else if constexpr (std::is_same_v) + { + stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; + + bool first = true; + for (TypeId ty : err.cause) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + stream << " } } "; + } else static_assert(always_false_v, "Non-exhaustive type switch"); } +std::ostream& operator<<(std::ostream& stream, const CannotAssignToNever::Reason& reason) +{ + switch (reason) + { + case CannotAssignToNever::Reason::PropertyNarrowed: + return stream << "PropertyNarrowed"; + default: + return stream << "UnknownReason"; + } +} + std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) { - auto cb = [&](const auto& e) { + auto cb = [&](const auto& e) + { return errorToString(stream, e); }; visit(cb, data); @@ -225,4 +291,34 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv) return stream << toString(tv); } +std::ostream& operator<<(std::ostream& stream, TypeId ty) +{ + // we commonly use a null pointer when a type may not be present; we need to + // account for that here. + if (!ty) + return stream << ""; + + return stream << toString(ty); +} + +std::ostream& operator<<(std::ostream& stream, TypePackId tp) +{ + // we commonly use a null pointer when a type may not be present; we need to + // account for that here. + if (!tp) + return stream << ""; + + return stream << toString(tp); +} + +namespace TypePath +{ + +std::ostream& operator<<(std::ostream& stream, const Path& path) +{ + return stream << toString(path); +} + +} // namespace TypePath + } // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f850bd3d1..c4f46c84f 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,48 +14,15 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) -LUAU_FASTFLAGVARIABLE(LuauImproveDeprecatedApiLint, false) +LUAU_FASTFLAG(LuauSolverV2) + +LUAU_FASTFLAG(LuauAttribute) +LUAU_FASTFLAG(LuauNativeAttribute) +LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false) namespace Luau { -// clang-format off -static const char* kWarningNames[] = { - "Unknown", - - "UnknownGlobal", - "DeprecatedGlobal", - "GlobalUsedAsLocal", - "LocalShadow", - "SameLineStatement", - "MultiLineStatement", - "LocalUnused", - "FunctionUnused", - "ImportUnused", - "BuiltinGlobalWrite", - "PlaceholderRead", - "UnreachableCode", - "UnknownType", - "ForRange", - "UnbalancedAssignment", - "ImplicitReturn", - "DuplicateLocal", - "FormatString", - "TableLiteral", - "UninitializedLocal", - "DuplicateFunction", - "DeprecatedApi", - "TableOperations", - "DuplicateCondition", - "MisleadingAndOr", - "CommentDirective", - "IntegerParsing", - "ComparisonPrecedence", -}; -// clang-format on - -static_assert(std::size(kWarningNames) == unsigned(LintWarning::Code__Count), "did you forget to add warning to the list?"); - struct LintContext { struct Global @@ -308,8 +275,14 @@ class LintGlobalLocal : AstVisitor else if (g->deprecated) { if (const char* replacement = *g->deprecated; replacement && strlen(replacement)) - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", - gv->name.value, replacement); + emitWarning( + *context, + LintWarning::Code_DeprecatedGlobal, + gv->location, + "Global '%s' is deprecated, use '%s' instead", + gv->name.value, + replacement + ); else emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); } @@ -324,18 +297,33 @@ class LintGlobalLocal : AstVisitor AstExprFunction* top = g.functionRef.back(); if (top->debugname.value) - emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, - "Global '%s' is only used in the enclosing function '%s'; consider changing it to local", g.firstRef->name.value, - top->debugname.value); + emitWarning( + *context, + LintWarning::Code_GlobalUsedAsLocal, + g.firstRef->location, + "Global '%s' is only used in the enclosing function '%s'; consider changing it to local", + g.firstRef->name.value, + top->debugname.value + ); else - emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, + emitWarning( + *context, + LintWarning::Code_GlobalUsedAsLocal, + g.firstRef->location, "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", - g.firstRef->name.value, top->location.begin.line + 1); + g.firstRef->name.value, + top->location.begin.line + 1 + ); } else if (g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && g.firstRef->name != context->placeholder) { - emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, - "Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value); + emitWarning( + *context, + LintWarning::Code_GlobalUsedAsLocal, + g.firstRef->location, + "Global '%s' is never read before being written. Consider changing it to local", + g.firstRef->name.value + ); } } } @@ -362,7 +350,8 @@ class LintGlobalLocal : AstVisitor if (node->name == context->placeholder) emitWarning( - *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); + *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable" + ); return true; } @@ -371,7 +360,8 @@ class LintGlobalLocal : AstVisitor { if (node->local->name == context->placeholder) emitWarning( - *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); + *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable" + ); return true; } @@ -399,8 +389,13 @@ class LintGlobalLocal : AstVisitor } if (g.builtin) - emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, - "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); + emitWarning( + *context, + LintWarning::Code_BuiltinGlobalWrite, + gv->location, + "Built-in global '%s' is overwritten here; consider using a local or changing the name", + gv->name.value + ); else g.assigned = true; @@ -429,8 +424,13 @@ class LintGlobalLocal : AstVisitor Global& g = globals[gv->name]; if (g.builtin) - emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, - "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); + emitWarning( + *context, + LintWarning::Code_BuiltinGlobalWrite, + gv->location, + "Built-in global '%s' is overwritten here; consider using a local or changing the name", + gv->name.value + ); else { g.assigned = true; @@ -598,8 +598,12 @@ class LintSameLineStatement : AstVisitor if (node->body.data[i - 1]->hasSemicolon) continue; - emitWarning(*context, LintWarning::Code_SameLineStatement, location, - "A new statement is on the same line; add semi-colon on previous statement to silence"); + emitWarning( + *context, + LintWarning::Code_SameLineStatement, + location, + "A new statement is on the same line; add semi-colon on previous statement to silence" + ); lastLine = location.begin.line; } @@ -646,7 +650,8 @@ class LintMultiLineStatement : AstVisitor if (location.begin.column <= top.start.begin.column) { emitWarning( - *context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence"); + *context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence" + ); top.flagged = true; } @@ -760,8 +765,14 @@ class LintLocalHygiene : AstVisitor // don't warn on inter-function shadowing since it is much more fragile wrt refactoring if (shadow->functionDepth == local->functionDepth) - emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows previous declaration at line %d", - local->name.value, shadow->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_LocalShadow, + local->location, + "Variable '%s' shadows previous declaration at line %d", + local->name.value, + shadow->location.begin.line + 1 + ); } else if (Global* global = globals.find(local->name)) { @@ -769,8 +780,14 @@ class LintLocalHygiene : AstVisitor ; // there are many builtins with common names like 'table'; some of them are deprecated as well else if (global->firstRef) { - emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows a global variable used at line %d", - local->name.value, global->firstRef->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_LocalShadow, + local->location, + "Variable '%s' shadows a global variable used at line %d", + local->name.value, + global->firstRef->location.begin.line + 1 + ); } else { @@ -785,14 +802,21 @@ class LintLocalHygiene : AstVisitor return; if (info.function) - emitWarning(*context, LintWarning::Code_FunctionUnused, local->location, "Function '%s' is never used; prefix with '_' to silence", - local->name.value); + emitWarning( + *context, + LintWarning::Code_FunctionUnused, + local->location, + "Function '%s' is never used; prefix with '_' to silence", + local->name.value + ); else if (info.import) - emitWarning(*context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", - local->name.value); + emitWarning( + *context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", local->name.value + ); else - emitWarning(*context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", - local->name.value); + emitWarning( + *context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", local->name.value + ); } bool isRequireCall(AstExpr* expr) @@ -946,8 +970,13 @@ class LintUnusedFunction : AstVisitor for (auto& g : globals) { if (g.second.function && !g.second.used && g.first.value[0] != '_') - emitWarning(*context, LintWarning::Code_FunctionUnused, g.second.location, "Function '%s' is never used; prefix with '_' to silence", - g.first.value); + emitWarning( + *context, + LintWarning::Code_FunctionUnused, + g.second.location, + "Function '%s' is never used; prefix with '_' to silence", + g.first.value + ); } } @@ -1046,8 +1075,13 @@ class LintUnreachableCode : AstVisitor if (step == Error && si->is() && next->is() && i + 2 == stat->body.size) return Error; - emitWarning(*context, LintWarning::Code_UnreachableCode, next->location, "Unreachable code (previous statement always %ss)", - getReason(step)); + emitWarning( + *context, + LintWarning::Code_UnreachableCode, + next->location, + "Unreachable code (previous statement always %ss)", + getReason(step) + ); return step; } } @@ -1144,7 +1178,7 @@ class LintUnknownType : AstVisitor TypeKind getTypeKind(const std::string& name) { if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || - name == "function" || name == "thread") + name == "function" || name == "thread" || name == "buffer") return Kind_Primitive; if (name == "vector") @@ -1242,22 +1276,34 @@ class LintForRange : AstVisitor // for i=#t,1 do if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 1.0) emitWarning( - *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); + *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?" + ); // for i=8,1 do else if (fc && tc && fc->value > tc->value) emitWarning( - *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); + *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?" + ); // for i=1,8.75 do else if (fc && tc && getLoopEnd(fc->value, tc->value) != tc->value) - emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop ends at %g instead of %g; did you forget to specify step?", - getLoopEnd(fc->value, tc->value), tc->value); + emitWarning( + *context, + LintWarning::Code_ForRange, + rangeLocation, + "For loop ends at %g instead of %g; did you forget to specify step?", + getLoopEnd(fc->value, tc->value), + tc->value + ); // for i=0,#t do else if (fc && tu && fc->value == 0.0 && tu->op == AstExprUnary::Len) emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop starts at 0, but arrays start at 1"); // for i=#t,0 do else if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 0.0) - emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, - "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"); + emitWarning( + *context, + LintWarning::Code_ForRange, + rangeLocation, + "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1" + ); } return true; @@ -1285,16 +1331,27 @@ class LintUnbalancedAssignment : AstVisitor AstExpr* last = values.data[values.size - 1]; if (vars < values.size) - emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, - "Assigning %d values to %d variables leaves some values unused", int(values.size), int(vars)); + emitWarning( + *context, + LintWarning::Code_UnbalancedAssignment, + location, + "Assigning %d values to %d variables leaves some values unused", + int(values.size), + int(vars) + ); else if (last->is() || last->is()) ; // we don't know how many values the last expression returns else if (last->is()) ; // last expression is nil which explicitly silences the nil-init warning else - emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, - "Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence", int(values.size), - int(vars)); + emitWarning( + *context, + LintWarning::Code_UnbalancedAssignment, + location, + "Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence", + int(values.size), + int(vars) + ); } } @@ -1377,13 +1434,22 @@ class LintImplicitReturn : AstVisitor Location location = getEndLocation(bodyf); if (node->debugname.value) - emitWarning(*context, LintWarning::Code_ImplicitReturn, location, + emitWarning( + *context, + LintWarning::Code_ImplicitReturn, + location, "Function '%s' can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", - node->debugname.value, vret->location.begin.line + 1); + node->debugname.value, + vret->location.begin.line + 1 + ); else - emitWarning(*context, LintWarning::Code_ImplicitReturn, location, + emitWarning( + *context, + LintWarning::Code_ImplicitReturn, + location, "Function can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", - vret->location.begin.line + 1); + vret->location.begin.line + 1 + ); } return true; @@ -1854,23 +1920,41 @@ class LintTableLiteral : AstVisitor int& line = names[&expr->value]; if (line) - emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, - "Table field '%.*s' is a duplicate; previously defined at line %d", int(expr->value.size), expr->value.data, line); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + expr->location, + "Table field '%.*s' is a duplicate; previously defined at line %d", + int(expr->value.size), + expr->value.data, + line + ); else line = expr->location.begin.line + 1; } else if (AstExprConstantNumber* expr = item.key->as()) { if (expr->value >= 1 && expr->value <= double(count) && double(int(expr->value)) == expr->value) - emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, - "Table index %d is a duplicate; previously defined as a list entry", int(expr->value)); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + expr->location, + "Table index %d is a duplicate; previously defined as a list entry", + int(expr->value) + ); else if (expr->value >= 0 && expr->value <= double(INT_MAX) && double(int(expr->value)) == expr->value) { int& line = indices[int(expr->value)]; if (line) - emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, - "Table index %d is a duplicate; previously defined at line %d", int(expr->value), line); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + expr->location, + "Table index %d is a duplicate; previously defined at line %d", + int(expr->value), + line + ); else line = expr->location.begin.line + 1; } @@ -1887,6 +1971,72 @@ class LintTableLiteral : AstVisitor bool visit(AstTypeTable* node) override { + if (FFlag::LuauSolverV2) + { + struct Rec + { + AstTableAccess access; + Location location; + }; + DenseHashMap names(AstName{}); + + for (const AstTableProp& item : node->props) + { + Rec* rec = names.find(item.name); + if (!rec) + { + names[item.name] = Rec{item.access, item.location}; + continue; + } + + if (int(rec->access) & int(item.access)) + { + if (rec->access == item.access) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + item.location, + "Table type field '%s' is a duplicate; previously defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else if (rec->access == AstTableAccess::ReadWrite) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + item.location, + "Table type field '%s' is already read-write; previously defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else if (rec->access == AstTableAccess::Read) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + rec->location, + "Table type field '%s' already has a read type defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else if (rec->access == AstTableAccess::Write) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + rec->location, + "Table type field '%s' already has a write type defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else + LUAU_ASSERT(!"Unreachable"); + } + else + rec->access = AstTableAccess(int(rec->access) | int(item.access)); + } + + return true; + } + DenseHashMap names(AstName{}); for (const AstTableProp& item : node->props) @@ -1894,8 +2044,14 @@ class LintTableLiteral : AstVisitor int& line = names[item.name]; if (line) - emitWarning(*context, LintWarning::Code_TableLiteral, item.location, - "Table type field '%s' is a duplicate; previously defined at line %d", item.name.value, line); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + item.location, + "Table type field '%s' is a duplicate; previously defined at line %d", + item.name.value, + line + ); else line = item.location.begin.line + 1; } @@ -1956,9 +2112,14 @@ class LintUninitializedLocal : AstVisitor if (l.defined && !l.initialized && !l.assigned && l.firstUse) { - emitWarning(*context, LintWarning::Code_UninitializedLocal, l.firstUse->location, - "Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence", local->name.value, - local->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_UninitializedLocal, + l.firstUse->location, + "Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence", + local->name.value, + local->location.begin.line + 1 + ); } } } @@ -2092,8 +2253,14 @@ class LintDuplicateFunction : AstVisitor void report(const std::string& name, Location location, Location otherLocation) { - emitWarning(*context, LintWarning::Code_DuplicateFunction, location, "Duplicate function definition: '%s' also defined on line %d", - name.c_str(), otherLocation.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateFunction, + location, + "Duplicate function definition: '%s' also defined on line %d", + name.c_str(), + otherLocation.begin.line + 1 + ); } }; @@ -2102,9 +2269,6 @@ class LintDeprecatedApi : AstVisitor public: LUAU_NOINLINE static void process(LintContext& context) { - if (!FFlag::LuauImproveDeprecatedApiLint && !context.module) - return; - LintDeprecatedApi pass{&context}; context.root->visit(&pass); } @@ -2122,8 +2286,34 @@ class LintDeprecatedApi : AstVisitor if (std::optional ty = context->getType(node->expr)) check(node, follow(*ty)); else if (AstExprGlobal* global = node->expr->as()) - if (FFlag::LuauImproveDeprecatedApiLint) - check(node->location, global->name, node->index); + check(node->location, global->name, node->index); + + return true; + } + + bool visit(AstExprCall* node) override + { + // getfenv/setfenv are deprecated, however they are still used in some test frameworks and don't have a great general replacement + // for now we warn about the deprecation only when they are used with a numeric first argument; this produces fewer warnings and makes use + // of getfenv/setfenv a little more localized + if (!node->self && node->args.size >= 1) + { + if (AstExprGlobal* fenv = node->func->as(); fenv && (fenv->name == "getfenv" || fenv->name == "setfenv")) + { + AstExpr* level = node->args.data[0]; + std::optional ty = context->getType(level); + + if ((ty && isNumber(*ty)) || level->is()) + { + // some common uses of getfenv(n) can be replaced by debug.info if the goal is to get the caller's identity + const char* suggestion = (fenv->name == "getfenv") ? "; consider using 'debug.info' instead" : ""; + + emitWarning( + *context, LintWarning::Code_DeprecatedApi, node->location, "Function '%s' is deprecated%s", fenv->name.value, suggestion + ); + } + } + } return true; } @@ -2144,7 +2334,7 @@ class LintDeprecatedApi : AstVisitor if (prop != tty->props.end() && prop->second.deprecated) { // strip synthetic typeof() for builtin tables - if (FFlag::LuauImproveDeprecatedApiLint && tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') + if (tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value); else report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); @@ -2197,16 +2387,50 @@ class LintTableOperations : AstVisitor { } + bool visit(AstExprUnary* node) override + { + if (node->op == AstExprUnary::Len) + checkIndexer(node, node->expr, "#"); + + return true; + } + bool visit(AstExprCall* node) override { - AstExprIndexName* func = node->func->as(); - if (!func) - return true; + if (AstExprGlobal* func = node->func->as()) + { + if (func->name == "ipairs" && node->args.size == 1) + checkIndexer(node, node->args.data[0], "ipairs"); + } + else if (AstExprIndexName* func = node->func->as()) + { + if (AstExprGlobal* tablib = func->expr->as(); tablib && tablib->name == "table") + checkTableCall(node, func); + } - AstExprGlobal* tablib = func->expr->as(); - if (!tablib || tablib->name != "table") - return true; + return true; + } + + void checkIndexer(AstExpr* node, AstExpr* expr, const char* op) + { + std::optional ty = context->getType(expr); + if (!ty) + return; + + const TableType* tty = get(follow(*ty)); + if (!tty) + return; + if (!tty->indexer && !tty->props.empty() && tty->state != TableState::Generic) + emitWarning( + *context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op + ); + else if (tty->indexer && isString(tty->indexer->indexType)) // note: to avoid complexity of subtype tests we just check if the key is a string + emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table with string keys is likely a bug", op); + } + + void checkTableCall(AstExprCall* node, AstExprIndexName* func) + { AstExpr** args = node->args.data; if (func->index == "insert" && node->args.size == 2) @@ -2218,9 +2442,13 @@ class LintTableOperations : AstVisitor size_t ret = getReturnCount(follow(*funty)); if (ret > 1) - emitWarning(*context, LintWarning::Code_TableOperations, tail->location, + emitWarning( + *context, + LintWarning::Code_TableOperations, + tail->location, "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second " - "argument"); + "argument" + ); } } } @@ -2229,28 +2457,44 @@ class LintTableOperations : AstVisitor { // table.insert(t, 0, ?) if (isConstant(args[1], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); // table.insert(t, #t, ?) if (isLength(args[1], args[0])) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, "table.insert will insert the value before the last element, which is likely a bug; consider removing the second argument or " - "wrap it in parentheses to silence"); + "wrap it in parentheses to silence" + ); // table.insert(t, #t+1, ?) if (AstExprBinary* add = args[1]->as(); add && add->op == AstExprBinary::Add && isLength(add->left, args[0]) && isConstant(add->right, 1.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.insert will append the value to the table; consider removing the second argument for efficiency"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.insert will append the value to the table; consider removing the second argument for efficiency" + ); } if (func->index == "remove" && node->args.size >= 2) { // table.remove(t, 0) if (isConstant(args[1], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); // note: it's tempting to check for table.remove(t, #t), which is equivalent to table.remove(t), but it's correct, occurs frequently, // and also reads better. @@ -2258,38 +2502,56 @@ class LintTableOperations : AstVisitor // table.remove(t, #t-1) if (AstExprBinary* sub = args[1]->as(); sub && sub->op == AstExprBinary::Sub && isLength(sub->left, args[0]) && isConstant(sub->right, 1.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, "table.remove will remove the value before the last element, which is likely a bug; consider removing the second argument or " - "wrap it in parentheses to silence"); + "wrap it in parentheses to silence" + ); } if (func->index == "move" && node->args.size >= 4) { // table.move(t, 0, _, _) if (isConstant(args[1], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); // table.move(t, _, _, 0) else if (isConstant(args[3], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, - "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[3]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); } if (func->index == "create" && node->args.size == 2) { // table.create(n, {...}) if (args[1]->is()) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead" + ); // table.create(n, {...} :: ?) if (AstExprTypeAssertion* as = args[1]->as(); as && as->expr->is()) - emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location, - "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + as->expr->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead" + ); } - - return true; } bool isConstant(AstExpr* expr, double value) @@ -2480,11 +2742,21 @@ class LintDuplicateCondition : AstVisitor if (similar(conditions[j], conditions[i])) { if (conditions[i]->location.begin.line == conditions[j]->location.begin.line) - emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, - "Condition has already been checked on column %d", conditions[j]->location.begin.column + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateCondition, + conditions[i]->location, + "Condition has already been checked on column %d", + conditions[j]->location.begin.column + 1 + ); else - emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, - "Condition has already been checked on line %d", conditions[j]->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateCondition, + conditions[i]->location, + "Condition has already been checked on line %d", + conditions[j]->location.begin.line + 1 + ); break; } } @@ -2529,11 +2801,23 @@ class LintDuplicateLocal : AstVisitor if (local->shadow && locals[local->shadow] == node && !ignoreDuplicate(local)) { if (local->shadow->location.begin.line == local->location.begin.line) - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on column %d", - local->name.value, local->shadow->location.begin.column + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Variable '%s' already defined on column %d", + local->name.value, + local->shadow->location.begin.column + 1 + ); else - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on line %d", - local->name.value, local->shadow->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Variable '%s' already defined on line %d", + local->name.value, + local->shadow->location.begin.line + 1 + ); } } @@ -2557,11 +2841,23 @@ class LintDuplicateLocal : AstVisitor if (local->shadow == node->self) emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter 'self' already defined implicitly"); else if (local->shadow->location.begin.line == local->location.begin.line) - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on column %d", - local->name.value, local->shadow->location.begin.column + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Function parameter '%s' already defined on column %d", + local->name.value, + local->shadow->location.begin.column + 1 + ); else - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on line %d", - local->name.value, local->shadow->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Function parameter '%s' already defined on line %d", + local->name.value, + local->shadow->location.begin.line + 1 + ); } } @@ -2605,10 +2901,14 @@ class LintMisleadingAndOr : AstVisitor alt = "false"; if (alt) - emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location, + emitWarning( + *context, + LintWarning::Code_MisleadingAndOr, + node->location, "The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else " "expression instead", - alt); + alt + ); return true; } @@ -2635,13 +2935,29 @@ class LintIntegerParsing : AstVisitor case ConstantNumberParseResult::Ok: case ConstantNumberParseResult::Malformed: break; + case ConstantNumberParseResult::Imprecise: + emitWarning( + *context, + LintWarning::Code_IntegerParsing, + node->location, + "Number literal exceeded available precision and was truncated to closest representable number" + ); + break; case ConstantNumberParseResult::BinOverflow: - emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, - "Binary number literal exceeded available precision and has been truncated to 2^64"); + emitWarning( + *context, + LintWarning::Code_IntegerParsing, + node->location, + "Binary number literal exceeded available precision and was truncated to 2^64" + ); break; case ConstantNumberParseResult::HexOverflow: - emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, - "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); + emitWarning( + *context, + LintWarning::Code_IntegerParsing, + node->location, + "Hexadecimal number literal exceeded available precision and was truncated to 2^64" + ); break; } @@ -2692,12 +3008,24 @@ class LintComparisonPrecedence : AstVisitor std::string op = toString(node->op); if (isEquality(node->op)) - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", op.c_str(), op.c_str(), - node->op == AstExprBinary::CompareEq ? "~=" : "=="); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", + op.c_str(), + op.c_str(), + node->op == AstExprBinary::CompareEq ? "~=" : "==" + ); else - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", op.c_str(), op.c_str()); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", + op.c_str(), + op.c_str() + ); } else if (AstExprBinary* left = node->left->as(); left && isComparison(left->op)) { @@ -2705,12 +3033,29 @@ class LintComparisonPrecedence : AstVisitor std::string rop = toString(node->op); if (isEquality(left->op) || isEquality(node->op)) - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str()); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", + lop.c_str(), + rop.c_str(), + lop.c_str(), + rop.c_str() + ); else - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str(), - lop.c_str(), rop.c_str()); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", + lop.c_str(), + rop.c_str(), + lop.c_str(), + rop.c_str(), + lop.c_str(), + rop.c_str() + ); } return true; @@ -2776,8 +3121,12 @@ static void lintComments(LintContext& context, const std::vector& ho if (!hc.header) { - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "Comment directive is ignored because it is placed after the first non-comment token"); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Comment directive is ignored because it is placed after the first non-comment token" + ); } else { @@ -2798,21 +3147,36 @@ static void lintComments(LintContext& context, const std::vector& ho // skip Unknown if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1)) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", + rule, + suggestion + ); else emitWarning( - context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule); + context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule + ); } } else if (first == "nocheck" || first == "nonstrict" || first == "strict") { if (space != std::string::npos) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "Comment directive with the type checking mode has extra symbols at the end of the line"); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Comment directive with the type checking mode has extra symbols at the end of the line" + ); else if (seenMode) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "Comment directive with the type checking mode has already been used"); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Comment directive with the type checking mode has already been used" + ); else seenMode = true; } @@ -2827,10 +3191,22 @@ static void lintComments(LintContext& context, const std::vector& ho const char* level = hc.content.c_str() + notspace; if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2")) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "optimize directive uses unknown optimization level '%s', 0..2 expected", level); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "optimize directive uses unknown optimization level '%s', 0..2 expected", + level + ); } } + else if (first == "native") + { + if (space != std::string::npos) + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line" + ); + } else { static const char* kHotComments[] = { @@ -2839,27 +3215,96 @@ static void lintComments(LintContext& context, const std::vector& ho "nonstrict", "strict", "optimize", + "native", }; if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?", - int(first.size()), first.data(), suggestion); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Unknown comment directive '%.*s'; did you mean '%s'?", + int(first.size()), + first.data(), + suggestion + ); else - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), - first.data()); + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), first.data() + ); } } } } -void LintOptions::setDefaults() +static bool hasNativeCommentDirective(const std::vector& hotcomments) { - // By default, we enable all warnings - warningMask = ~0ull; + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + for (const HotComment& hc : hotcomments) + { + if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t') + continue; + + if (hc.header) + { + size_t space = hc.content.find_first_of(" \t"); + std::string_view first = std::string_view(hc.content).substr(0, space); + + if (first == "native") + return true; + } + } + + return false; } -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, - const std::vector& hotcomments, const LintOptions& options) +struct LintRedundantNativeAttribute : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + LintRedundantNativeAttribute pass; + pass.context = &context; + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprFunction* node) override + { + node->body->visit(this); + + for (const auto attribute : node->attributes) + { + if (attribute->type == AstAttr::Type::Native) + { + emitWarning( + *context, + LintWarning::Code_RedundantNativeAttribute, + attribute->location, + "native attribute on a function is redundant in a native module; consider removing it" + ); + } + } + + return false; + } +}; + +std::vector lint( + AstStat* root, + const AstNameTable& names, + const ScopePtr& env, + const Module* module, + const std::vector& hotcomments, + const LintOptions& options +) { LintContext context; @@ -2944,57 +3389,15 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) LintComparisonPrecedence::process(context); - std::sort(context.result.begin(), context.result.end(), WarningComparator()); - - return context.result; -} - -const char* LintWarning::getName(Code code) -{ - LUAU_ASSERT(unsigned(code) < Code__Count); - - return kWarningNames[code]; -} - -LintWarning::Code LintWarning::parseName(const char* name) -{ - for (int code = Code_Unknown; code < Code__Count; ++code) - if (strcmp(name, getName(Code(code))) == 0) - return Code(code); - - return Code_Unknown; -} - -uint64_t LintWarning::parseMask(const std::vector& hotcomments) -{ - uint64_t result = 0; - - for (const HotComment& hc : hotcomments) + if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute)) { - if (!hc.header) - continue; - - if (hc.content.compare(0, 6, "nolint") != 0) - continue; - - size_t name = hc.content.find_first_not_of(" \t", 6); - - // --!nolint disables everything - if (name == std::string::npos) - return ~0ull; - - // --!nolint needs to be followed by a whitespace character - if (name == 6) - continue; - - // --!nolint name disables the specific lint - LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name); - - if (code != LintWarning::Code_Unknown) - result |= 1ull << int(code); + if (hasNativeCommentDirective(hotcomments)) + LintRedundantNativeAttribute::process(context); } - return result; + std::sort(context.result.begin(), context.result.end(), WarningComparator()); + + return context.result; } std::vector getDeprecatedGlobals(const AstNameTable& names) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b51b7c9a6..0c5b361cf 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -3,23 +3,19 @@ #include "Luau/Clone.h" #include "Luau/Common.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintGenerator.h" #include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Type.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); -LUAU_FASTFLAG(LuauSubstitutionReentrant); -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); -LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); +LUAU_FASTFLAG(LuauSolverV2); +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau { @@ -28,8 +24,8 @@ static bool contains(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; - else if (comment.type == Lexeme::BrokenComment && - comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't + // have an end return true; else if (comment.type == Lexeme::Comment && comment.location.end == pos) return true; @@ -37,14 +33,19 @@ static bool contains(Position pos, Comment comment) return false; } -bool isWithinComment(const SourceModule& sourceModule, Position pos) +static bool isWithinComment(const std::vector& commentLocations, Position pos) { - auto iter = std::lower_bound(sourceModule.commentLocations.begin(), sourceModule.commentLocations.end(), - Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { + auto iter = std::lower_bound( + commentLocations.begin(), + commentLocations.end(), + Comment{Lexeme::Comment, Location{pos, pos}}, + [](const Comment& a, const Comment& b) + { return a.location.end < b.location.end; - }); + } + ); - if (iter == sourceModule.commentLocations.end()) + if (iter == commentLocations.end()) return false; if (contains(pos, *iter)) @@ -53,12 +54,22 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) // Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends // at pos. We'll try the next comment, if it exists. ++iter; - if (iter == sourceModule.commentLocations.end()) + if (iter == commentLocations.end()) return false; return contains(pos, *iter); } +bool isWithinComment(const SourceModule& sourceModule, Position pos) +{ + return isWithinComment(sourceModule.commentLocations, pos); +} + +bool isWithinComment(const ParseResult& result, Position pos) +{ + return isWithinComment(result.commentLocations, pos); +} + struct ClonePublicInterface : Substitution { NotNull builtinTypes; @@ -89,27 +100,111 @@ struct ClonePublicInterface : Substitution return tp->owningArena == &module->internalTypes; } + bool ignoreChildrenVisit(TypeId ty) override + { + if (ty->owningArena != &module->internalTypes) + return true; + + return false; + } + + bool ignoreChildrenVisit(TypePackId tp) override + { + if (tp->owningArena != &module->internalTypes) + return true; + + return false; + } + TypeId clean(TypeId ty) override { TypeId result = clone(ty); if (FunctionType* ftv = getMutable(result)) + { + if (ftv->generics.empty() && ftv->genericPacks.empty()) + { + GenericTypeFinder marker; + marker.traverse(result); + + if (!marker.found) + ftv->hasNoFreeOrGenericTypes = true; + } + ftv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + ftv->scope = nullptr; + } else if (TableType* ttv = getMutable(result)) + { ttv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + ttv->scope = nullptr; + } + + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + { + if (auto freety = getMutable(result)) + { + if (DFInt::LuauTypeSolverRelease >= 646) + { + module->errors.emplace_back( + freety->scope->location, + module->name, + InternalError{"Free type is escaping its module; please report this bug at " + "https://github.com/luau-lang/luau/issues"} + ); + result = builtinTypes->errorRecoveryType(); + } + else + { + freety->scope = nullptr; + } + } + else if (auto genericty = getMutable(result)) + { + genericty->scope = nullptr; + } + } return result; } TypePackId clean(TypePackId tp) override { - return clone(tp); + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + { + auto clonedTp = clone(tp); + if (auto ftp = getMutable(clonedTp)) + { + + if (DFInt::LuauTypeSolverRelease >= 646) + { + module->errors.emplace_back( + ftp->scope->location, + module->name, + InternalError{"Free type pack is escaping its module; please report this bug at " + "https://github.com/luau-lang/luau/issues"} + ); + clonedTp = builtinTypes->errorRecoveryTypePack(); + } + else + { + ftp->scope = nullptr; + } + } + else if (auto gtp = getMutable(clonedTp)) + gtp->scope = nullptr; + return clonedTp; + } + else + { + return clone(tp); + } } TypeId cloneType(TypeId ty) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::optional result = substitute(ty); if (result) { @@ -124,8 +219,6 @@ struct ClonePublicInterface : Substitution TypePackId cloneTypePack(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::optional result = substitute(tp); if (result) { @@ -140,8 +233,6 @@ struct ClonePublicInterface : Substitution TypeFun cloneTypeFun(const TypeFun& tf) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::vector typeParams; std::vector typePackParams; @@ -181,56 +272,38 @@ Module::~Module() void Module::clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice) { - LUAU_ASSERT(interfaceTypes.types.empty()); - LUAU_ASSERT(interfaceTypes.typePacks.empty()); - - CloneState cloneState; + CloneState cloneState{builtinTypes}; ScopePtr moduleScope = getModuleScope(); TypePackId returnType = moduleScope->returnType; - std::optional varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack; + std::optional varargPack = FFlag::LuauSolverV2 ? std::nullopt : moduleScope->varargPack; TxnLog log; ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; - if (FFlag::LuauClonePublicInterfaceLess) - returnType = clonePublicInterface.cloneTypePack(returnType); - else - returnType = clone(returnType, interfaceTypes, cloneState); + returnType = clonePublicInterface.cloneTypePack(returnType); moduleScope->returnType = returnType; if (varargPack) { - if (FFlag::LuauClonePublicInterfaceLess) - varargPack = clonePublicInterface.cloneTypePack(*varargPack); - else - varargPack = clone(*varargPack, interfaceTypes, cloneState); + varargPack = clonePublicInterface.cloneTypePack(*varargPack); moduleScope->varargPack = varargPack; } for (auto& [name, tf] : moduleScope->exportedTypeBindings) { - if (FFlag::LuauClonePublicInterfaceLess) - tf = clonePublicInterface.cloneTypeFun(tf); - else - tf = clone(tf, interfaceTypes, cloneState); + tf = clonePublicInterface.cloneTypeFun(tf); } for (auto& [name, ty] : declaredGlobals) { - if (FFlag::LuauClonePublicInterfaceLess) - ty = clonePublicInterface.cloneType(ty); - else - ty = clone(ty, interfaceTypes, cloneState); + ty = clonePublicInterface.cloneType(ty); } // Copy external stuff over to Module itself this->returnType = moduleScope->returnType; - if (FFlag::DebugLuauDeferredConstraintResolution) - this->exportedTypeBindings = moduleScope->exportedTypeBindings; - else - this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); + this->exportedTypeBindings = moduleScope->exportedTypeBindings; } bool Module::hasModuleScope() const diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp new file mode 100644 index 000000000..0ebc573d1 --- /dev/null +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -0,0 +1,781 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/NonStrictTypeChecker.h" + +#include "Luau/Ast.h" +#include "Luau/Common.h" +#include "Luau/Simplify.h" +#include "Luau/Type.h" +#include "Luau/Simplify.h" +#include "Luau/Subtyping.h" +#include "Luau/Normalize.h" +#include "Luau/Error.h" +#include "Luau/TimeTrace.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFunction.h" +#include "Luau/Def.h" +#include "Luau/ToString.h" +#include "Luau/TypeFwd.h" + +#include +#include + +namespace Luau +{ + +/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. + * NonStrictTypeChecker uses this to maintain knowledge about which scope encloses every + * given AstNode. + */ +struct StackPusher +{ + std::vector>* stack; + NotNull scope; + + explicit StackPusher(std::vector>& stack, Scope* scope) + : stack(&stack) + , scope(scope) + { + stack.push_back(NotNull{scope}); + } + + ~StackPusher() + { + if (stack) + { + LUAU_ASSERT(stack->back() == scope); + stack->pop_back(); + } + } + + StackPusher(const StackPusher&) = delete; + StackPusher&& operator=(const StackPusher&) = delete; + + StackPusher(StackPusher&& other) + : stack(std::exchange(other.stack, nullptr)) + , scope(other.scope) + { + } +}; + + +struct NonStrictContext +{ + NonStrictContext() = default; + + NonStrictContext(const NonStrictContext&) = delete; + NonStrictContext& operator=(const NonStrictContext&) = delete; + + NonStrictContext(NonStrictContext&&) = default; + NonStrictContext& operator=(NonStrictContext&&) = default; + + static NonStrictContext disjunction( + NotNull builtinTypes, + NotNull arena, + const NonStrictContext& left, + const NonStrictContext& right + ) + { + // disjunction implements union over the domain of keys + // if the default value for a defId not in the map is `never` + // then never | T is T + NonStrictContext disj{}; + + for (auto [def, leftTy] : left.context) + { + if (std::optional rightTy = right.find(def)) + disj.context[def] = simplifyUnion(builtinTypes, arena, leftTy, *rightTy).result; + else + disj.context[def] = leftTy; + } + + for (auto [def, rightTy] : right.context) + { + if (!left.find(def).has_value()) + disj.context[def] = rightTy; + } + + return disj; + } + + static NonStrictContext conjunction( + NotNull builtins, + NotNull arena, + const NonStrictContext& left, + const NonStrictContext& right + ) + { + NonStrictContext conj{}; + + for (auto [def, leftTy] : left.context) + { + if (std::optional rightTy = right.find(def)) + conj.context[def] = simplifyIntersection(builtins, arena, leftTy, *rightTy).result; + } + + return conj; + } + + // Returns true if the removal was successful + bool remove(const DefId& def) + { + std::vector defs; + collectOperands(def, &defs); + bool result = true; + for (DefId def : defs) + result = result && context.erase(def.get()) == 1; + return result; + } + + std::optional find(const DefId& def) const + { + const Def* d = def.get(); + return find(d); + } + + void addContext(const DefId& def, TypeId ty) + { + std::vector defs; + collectOperands(def, &defs); + for (DefId def : defs) + context[def.get()] = ty; + } + +private: + std::optional find(const Def* d) const + { + auto it = context.find(d); + if (it != context.end()) + return {it->second}; + return {}; + } + + std::unordered_map context; +}; + +struct NonStrictTypeChecker +{ + NotNull builtinTypes; + NotNull typeFunctionRuntime; + const NotNull ice; + NotNull arena; + Module* module; + Normalizer normalizer; + Subtyping subtyping; + NotNull dfg; + DenseHashSet noTypeFunctionErrors{nullptr}; + std::vector> stack; + DenseHashMap cachedNegations{nullptr}; + + const NotNull limits; + + NonStrictTypeChecker( + NotNull arena, + NotNull builtinTypes, + NotNull typeFunctionRuntime, + const NotNull ice, + NotNull unifierState, + NotNull dfg, + NotNull limits, + Module* module + ) + : builtinTypes(builtinTypes) + , typeFunctionRuntime(typeFunctionRuntime) + , ice(ice) + , arena(arena) + , module(module) + , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} + , subtyping{builtinTypes, arena, NotNull(&normalizer), typeFunctionRuntime, ice} + , dfg(dfg) + , limits(limits) + { + } + + std::optional pushStack(AstNode* node) + { + if (Scope** scope = module->astScopes.find(node)) + return StackPusher{stack, *scope}; + else + return std::nullopt; + } + + TypeId flattenPack(TypePackId pack) + { + pack = follow(pack); + + if (auto fst = first(pack, /*ignoreHiddenVariadics*/ false)) + return *fst; + else if (auto ftp = get(pack)) + { + TypeId result = arena->addType(FreeType{ftp->scope}); + TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope}); + + TypePack* resultPack = emplaceTypePack(asMutable(pack)); + resultPack->head.assign(1, result); + resultPack->tail = freeTail; + + return result; + } + else if (get(pack)) + return builtinTypes->errorRecoveryType(); + else if (finite(pack) && size(pack) == 0) + return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` + else + ice->ice("flattenPack got a weird pack!"); + } + + + TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location) + { + if (noTypeFunctionErrors.find(instance)) + return instance; + + ErrorVec errors = reduceTypeFunctions( + instance, + location, + TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, + true + ) + .errors; + + if (errors.empty()) + noTypeFunctionErrors.insert(instance); + // TODO?? + // if (!isErrorSuppressing(location, instance)) + // reportErrors(std::move(errors)); + return instance; + } + + + TypeId lookupType(AstExpr* expr) + { + TypeId* ty = module->astTypes.find(expr); + if (ty) + return checkForTypeFunctionInhabitance(follow(*ty), expr->location); + + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return checkForTypeFunctionInhabitance(flattenPack(*tp), expr->location); + return builtinTypes->anyType; + } + + NonStrictContext visit(AstStat* stat) + { + auto pusher = pushStack(stat); + if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto f = stat->as()) + return visit(f); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else + { + LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown statement type"); + ice->ice("NonStrictTypeChecker encountered an unknown statement type"); + } + } + + NonStrictContext visit(AstStatBlock* block) + { + auto StackPusher = pushStack(block); + NonStrictContext ctx; + + + for (auto it = block->body.rbegin(); it != block->body.rend(); it++) + { + AstStat* stat = *it; + if (AstStatLocal* local = stat->as()) + { + // Iterating in reverse order + // local x ; B generates the context of B without x + visit(local); + for (auto local : local->vars) + ctx.remove(dfg->getDef(local)); + } + else + ctx = NonStrictContext::disjunction(builtinTypes, arena, visit(stat), ctx); + } + return ctx; + } + + NonStrictContext visit(AstStatIf* ifStatement) + { + NonStrictContext condB = visit(ifStatement->condition); + NonStrictContext branchContext; + // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error + if (ifStatement->elsebody) + { + NonStrictContext thenBody = visit(ifStatement->thenbody); + NonStrictContext elseBody = visit(ifStatement->elsebody); + branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); + } + return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext); + } + + NonStrictContext visit(AstStatWhile* whileStatement) + { + return {}; + } + + NonStrictContext visit(AstStatRepeat* repeatStatement) + { + return {}; + } + + NonStrictContext visit(AstStatBreak* breakStatement) + { + return {}; + } + + NonStrictContext visit(AstStatContinue* continueStatement) + { + return {}; + } + + NonStrictContext visit(AstStatReturn* returnStatement) + { + return {}; + } + + NonStrictContext visit(AstStatExpr* expr) + { + return visit(expr->expr); + } + + NonStrictContext visit(AstStatLocal* local) + { + for (AstExpr* rhs : local->values) + visit(rhs); + return {}; + } + + NonStrictContext visit(AstStatFor* forStatement) + { + return {}; + } + + NonStrictContext visit(AstStatForIn* forInStatement) + { + return {}; + } + + NonStrictContext visit(AstStatAssign* assign) + { + return {}; + } + + NonStrictContext visit(AstStatCompoundAssign* compoundAssign) + { + return {}; + } + + NonStrictContext visit(AstStatFunction* statFn) + { + return visit(statFn->func); + } + + NonStrictContext visit(AstStatLocalFunction* localFn) + { + return visit(localFn->func); + } + + NonStrictContext visit(AstStatTypeAlias* typeAlias) + { + return {}; + } + + NonStrictContext visit(AstStatTypeFunction* typeFunc) + { + reportError(GenericError{"This syntax is not supported"}, typeFunc->location); + return {}; + } + + NonStrictContext visit(AstStatDeclareFunction* declFn) + { + return {}; + } + + NonStrictContext visit(AstStatDeclareGlobal* declGlobal) + { + return {}; + } + + NonStrictContext visit(AstStatDeclareClass* declClass) + { + return {}; + } + + NonStrictContext visit(AstStatError* error) + { + return {}; + } + + NonStrictContext visit(AstExpr* expr) + { + auto pusher = pushStack(expr); + if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else + { + LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown expression type"); + ice->ice("NonStrictTypeChecker encountered an unknown expression type"); + } + } + + NonStrictContext visit(AstExprGroup* group) + { + return {}; + } + + NonStrictContext visit(AstExprConstantNil* expr) + { + return {}; + } + + NonStrictContext visit(AstExprConstantBool* expr) + { + return {}; + } + + NonStrictContext visit(AstExprConstantNumber* expr) + { + return {}; + } + + NonStrictContext visit(AstExprConstantString* expr) + { + return {}; + } + + NonStrictContext visit(AstExprLocal* local) + { + return {}; + } + + NonStrictContext visit(AstExprGlobal* global) + { + return {}; + } + + NonStrictContext visit(AstExprVarargs* global) + { + return {}; + } + + + NonStrictContext visit(AstExprCall* call) + { + NonStrictContext fresh{}; + TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func); + if (!originalCallTy) + return fresh; + + TypeId fnTy = *originalCallTy; + if (auto fn = get(follow(fnTy))) + { + if (fn->isCheckedFunction) + { + // We know fn is a checked function, which means it looks like: + // (S1, ... SN) -> T & + // (~S1, unknown^N-1) -> error & + // (unknown, ~S2, unknown^N-2) -> error + // ... + // ... + // (unknown^N-1, ~S_N) -> error + std::vector argTypes; + argTypes.reserve(call->args.size); + // Pad out the arg types array with the types you would expect to see + TypePackIterator curr = begin(fn->argTypes); + TypePackIterator fin = end(fn->argTypes); + while (curr != fin) + { + argTypes.push_back(*curr); + ++curr; + } + if (auto argTail = curr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*argTail))) + { + while (argTypes.size() < call->args.size) + { + argTypes.push_back(vtp->ty); + } + } + } + + std::string functionName = getFunctionNameAsString(*call->func).value_or(""); + if (call->args.size > argTypes.size()) + { + // We are passing more arguments than we expect, so we should error + reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); + return fresh; + } + + for (size_t i = 0; i < call->args.size; i++) + { + // For example, if the arg is "hi" + // The actual arg type is string + // The expected arg type is number + // The type of the argument in the overload is ~number + // We will compare arg and ~number + AstExpr* arg = call->args.data[i]; + TypeId expectedArgType = argTypes[i]; + std::shared_ptr norm = normalizer.normalize(expectedArgType); + DefId def = dfg->getDef(arg); + TypeId runTimeErrorTy; + // If we're dealing with any, negating any will cause all subtype tests to fail + // However, when someone calls this function, they're going to want to be able to pass it anything, + // for that reason, we manually inject never into the context so that the runtime test will always pass. + if (!norm) + reportError(NormalizationTooComplex{}, arg->location); + + if (norm && get(norm->tops)) + runTimeErrorTy = builtinTypes->neverType; + else + runTimeErrorTy = getOrCreateNegation(expectedArgType); + fresh.addContext(def, runTimeErrorTy); + } + + // Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types + for (size_t i = 0; i < call->args.size; i++) + { + AstExpr* arg = call->args.data[i]; + if (auto runTimeFailureType = willRunTimeError(arg, fresh)) + reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); + } + + if (call->args.size < argTypes.size()) + { + // We are passing fewer arguments than we expect + // so we need to ensure that the rest of the args are optional. + bool remainingArgsOptional = true; + for (size_t i = call->args.size; i < argTypes.size(); i++) + remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]); + if (!remainingArgsOptional) + { + reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); + return fresh; + } + } + } + } + + return fresh; + } + + NonStrictContext visit(AstExprIndexName* indexName) + { + return {}; + } + + NonStrictContext visit(AstExprIndexExpr* indexExpr) + { + return {}; + } + + NonStrictContext visit(AstExprFunction* exprFn) + { + // TODO: should a function being used as an expression generate a context without the arguments? + auto pusher = pushStack(exprFn); + NonStrictContext remainder = visit(exprFn->body); + for (AstLocal* local : exprFn->args) + { + if (std::optional ty = willRunTimeErrorFunctionDefinition(local, remainder)) + reportError(NonStrictFunctionDefinitionError{exprFn->debugname.value, local->name.value, *ty}, local->location); + remainder.remove(dfg->getDef(local)); + } + return remainder; + } + + NonStrictContext visit(AstExprTable* table) + { + return {}; + } + + NonStrictContext visit(AstExprUnary* unary) + { + return {}; + } + + NonStrictContext visit(AstExprBinary* binary) + { + return {}; + } + + NonStrictContext visit(AstExprTypeAssertion* typeAssertion) + { + return {}; + } + + NonStrictContext visit(AstExprIfElse* ifElse) + { + NonStrictContext condB = visit(ifElse->condition); + NonStrictContext thenB = visit(ifElse->trueExpr); + NonStrictContext elseB = visit(ifElse->falseExpr); + return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB)); + } + + NonStrictContext visit(AstExprInterpString* interpString) + { + return {}; + } + + NonStrictContext visit(AstExprError* error) + { + return {}; + } + + void reportError(TypeErrorData data, const Location& location) + { + module->errors.emplace_back(location, module->name, std::move(data)); + // TODO: weave in logger here? + } + + // If this fragment of the ast will run time error, return the type that causes this + std::optional willRunTimeError(AstExpr* fragment, const NonStrictContext& context) + { + NotNull scope{Luau::findScopeAtPosition(*module, fragment->location.end).get()}; + DefId def = dfg->getDef(fragment); + std::vector defs; + collectOperands(def, &defs); + for (DefId def : defs) + { + if (std::optional contextTy = context.find(def)) + { + + TypeId actualType = lookupType(fragment); + SubtypingResult r = subtyping.isSubtype(actualType, *contextTy, scope); + if (r.normalizationTooComplex) + reportError(NormalizationTooComplex{}, fragment->location); + if (r.isSubtype) + return {actualType}; + } + } + + return {}; + } + + std::optional willRunTimeErrorFunctionDefinition(AstLocal* fragment, const NonStrictContext& context) + { + NotNull scope{Luau::findScopeAtPosition(*module, fragment->location.end).get()}; + DefId def = dfg->getDef(fragment); + std::vector defs; + collectOperands(def, &defs); + for (DefId def : defs) + { + if (std::optional contextTy = context.find(def)) + { + SubtypingResult r1 = subtyping.isSubtype(builtinTypes->unknownType, *contextTy, scope); + SubtypingResult r2 = subtyping.isSubtype(*contextTy, builtinTypes->unknownType, scope); + if (r1.normalizationTooComplex || r2.normalizationTooComplex) + reportError(NormalizationTooComplex{}, fragment->location); + bool isUnknown = r1.isSubtype && r2.isSubtype; + if (isUnknown) + return {builtinTypes->unknownType}; + } + } + return {}; + } + +private: + TypeId getOrCreateNegation(TypeId baseType) + { + TypeId& cachedResult = cachedNegations[baseType]; + if (!cachedResult) + cachedResult = arena->addType(NegationType{baseType}); + return cachedResult; + } +}; + +void checkNonStrict( + NotNull builtinTypes, + NotNull typeFunctionRuntime, + NotNull ice, + NotNull unifierState, + NotNull dfg, + NotNull limits, + const SourceModule& sourceModule, + Module* module +) +{ + LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); + + NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, typeFunctionRuntime, ice, unifierState, dfg, limits, module}; + typeChecker.visit(sourceModule.root); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); + freeze(module->interfaceTypes); +} + +} // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 0552bec03..1480d2635 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -8,29 +8,50 @@ #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/RecursionCounter.h" +#include "Luau/Set.h" +#include "Luau/Simplify.h" +#include "Luau/Subtyping.h" #include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include "Luau/Unifier.h" -LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) -// This could theoretically be 2000 on amd64, but x86 requires this. -LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauUninhabitedSubAnything2) +LUAU_FASTFLAG(LuauSolverV2); + +LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200) +LUAU_FASTFLAGVARIABLE(LuauNormalizationTracksCyclicPairsThroughInhabitance, false); +LUAU_FASTFLAGVARIABLE(LuauIntersectNormalsNeedsToTrackResourceLimits, false); namespace Luau { + +static bool shouldEarlyExit(NormalizationResult res) +{ + // if res is hit limits, return control flow + if (res == NormalizationResult::HitLimits || res == NormalizationResult::False) + return true; + return false; +} + +TypeIds::TypeIds(std::initializer_list tys) +{ + for (TypeId ty : tys) + insert(ty); +} + void TypeIds::insert(TypeId ty) { ty = follow(ty); - auto [_, fresh] = types.insert(ty); - if (fresh) + + // get a reference to the slot for `ty` in `types` + bool& entry = types[ty]; + + // if `ty` is fresh, we can set it to `true`, add it to the order and hash and be done. + if (!entry) { + entry = true; order.push_back(ty); hash ^= std::hash{}(ty); } @@ -71,25 +92,35 @@ TypeIds::const_iterator TypeIds::end() const TypeIds::iterator TypeIds::erase(TypeIds::const_iterator it) { TypeId ty = *it; - types.erase(ty); + types[ty] = false; hash ^= std::hash{}(ty); return order.erase(it); } +void TypeIds::erase(TypeId ty) +{ + const_iterator it = std::find(order.begin(), order.end(), ty); + if (it == order.end()) + return; + + erase(it); +} + size_t TypeIds::size() const { - return types.size(); + return order.size(); } bool TypeIds::empty() const { - return types.empty(); + return order.empty(); } size_t TypeIds::count(TypeId ty) const { ty = follow(ty); - return types.count(ty); + const bool* val = types.find(ty); + return (val && *val) ? 1 : 0; } void TypeIds::retain(const TypeIds& there) @@ -108,9 +139,44 @@ size_t TypeIds::getHash() const return hash; } +bool TypeIds::isNever() const +{ + return std::all_of( + begin(), + end(), + [&](TypeId i) + { + // If each typeid is never, then I guess typeid's is also never? + return get(i) != nullptr; + } + ); +} + bool TypeIds::operator==(const TypeIds& there) const { - return hash == there.hash && types == there.types; + // we can early return if the hashes don't match. + if (hash != there.hash) + return false; + + // we have to check equality of the sets themselves if not. + + // if the sets are unequal sizes, then they cannot possibly be equal. + // it is important to use `order` here and not `types` since the mappings + // may have different sizes since removal is not possible, and so erase + // simply writes `false` into the map. + if (order.size() != there.order.size()) + return false; + + // otherwise, we'll need to check that every element we have here is in `there`. + for (auto ty : order) + { + // if it's not, we'll return `false` + if (there.count(ty) == 0) + return false; + } + + // otherwise, we've proven the two equal! + return true; } NormalizedStringType::NormalizedStringType() {} @@ -169,7 +235,7 @@ const NormalizedStringType NormalizedStringType::never; bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) { - if (subStr.isUnion() && superStr.isUnion()) + if (subStr.isUnion() && (superStr.isUnion() && !superStr.isNever())) { for (auto [name, ty] : subStr.singletons) { @@ -185,8 +251,10 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s void NormalizedClassType::pushPair(TypeId ty, TypeIds negations) { - ordering.push_back(ty); - classes.insert(std::make_pair(ty, std::move(negations))); + auto result = classes.insert(std::make_pair(ty, std::move(negations))); + if (result.second) + ordering.push_back(ty); + LUAU_ASSERT(ordering.size() == classes.size()); } void NormalizedClassType::resetToNever() @@ -200,26 +268,21 @@ bool NormalizedClassType::isNever() const return classes.empty(); } -NormalizedFunctionType::NormalizedFunctionType() - : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) -{ -} - void NormalizedFunctionType::resetToTop() { isTop = true; - parts.emplace(); + parts.clear(); } void NormalizedFunctionType::resetToNever() { isTop = false; - parts.emplace(); + parts.clear(); } bool NormalizedFunctionType::isNever() const { - return !isTop && (!parts || parts->empty()); + return !isTop && parts.empty(); } NormalizedType::NormalizedType(NotNull builtinTypes) @@ -230,75 +293,241 @@ NormalizedType::NormalizedType(NotNull builtinTypes) , numbers(builtinTypes->neverType) , strings{NormalizedStringType::never} , threads(builtinTypes->neverType) + , buffers(builtinTypes->neverType) { } -static bool isShallowInhabited(const NormalizedType& norm) +bool NormalizedType::isUnknown() const { - bool inhabitedClasses; + if (get(tops)) + return true; - if (FFlag::LuauNegatedClassTypes) - inhabitedClasses = !norm.classes.isNever(); - else - inhabitedClasses = !norm.DEPRECATED_classes.empty(); + // Otherwise, we can still be unknown! + bool hasAllPrimitives = isPrim(booleans, PrimitiveType::Boolean) && isPrim(nils, PrimitiveType::NilType) && isNumber(numbers) && + strings.isString() && isPrim(threads, PrimitiveType::Thread) && isThread(threads); + + // Check is class + bool isTopClass = false; + for (auto [t, disj] : classes.classes) + { + if (auto ct = get(t)) + { + if (ct->name == "class" && disj.empty()) + { + isTopClass = true; + break; + } + } + } + // Check is table + bool isTopTable = false; + for (auto t : tables) + { + if (isPrim(t, PrimitiveType::Table)) + { + isTopTable = true; + break; + } + } + // any = unknown or error ==> we need to make sure we have all the unknown components, but not errors + return get(errors) && hasAllPrimitives && isTopClass && isTopTable && functions.isTop; +} + +bool NormalizedType::isExactlyNumber() const +{ + return hasNumbers() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); +} + +bool NormalizedType::isSubtypeOfString() const +{ + return hasStrings() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); +} + +bool NormalizedType::isSubtypeOfBooleans() const +{ + return hasBooleans() && !hasTops() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); +} + +bool NormalizedType::shouldSuppressErrors() const +{ + return hasErrors() || get(tops); +} + +bool NormalizedType::hasTopTable() const +{ + return hasTables() && std::any_of( + tables.begin(), + tables.end(), + [&](TypeId ty) + { + auto primTy = get(ty); + return primTy && primTy->type == PrimitiveType::Type::Table; + } + ); +} + +bool NormalizedType::hasTops() const +{ + return !get(tops); +} + + +bool NormalizedType::hasBooleans() const +{ + return !get(booleans); +} + +bool NormalizedType::hasClasses() const +{ + return !classes.isNever(); +} + +bool NormalizedType::hasErrors() const +{ + return !get(errors); +} + +bool NormalizedType::hasNils() const +{ + return !get(nils); +} + +bool NormalizedType::hasNumbers() const +{ + return !get(numbers); +} + +bool NormalizedType::hasStrings() const +{ + return !strings.isNever(); +} + +bool NormalizedType::hasThreads() const +{ + return !get(threads); +} + +bool NormalizedType::hasBuffers() const +{ + return !get(buffers); +} +bool NormalizedType::hasTables() const +{ + return !tables.isNever(); +} + +bool NormalizedType::hasFunctions() const +{ + return !functions.isNever(); +} + +bool NormalizedType::hasTyvars() const +{ + return !tyvars.empty(); +} + +bool NormalizedType::isFalsy() const +{ + + bool hasAFalse = false; + if (auto singleton = get(booleans)) + { + if (auto bs = singleton->variant.get_if()) + hasAFalse = !bs->value; + } + + return (hasAFalse || hasNils()) && (!hasTops() && !hasClasses() && !hasErrors() && !hasNumbers() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars()); +} + +bool NormalizedType::isTruthy() const +{ + return !isFalsy(); +} + +static bool isShallowInhabited(const NormalizedType& norm) +{ // This test is just a shallow check, for example it returns `true` for `{ p : never }` - return !get(norm.tops) || !get(norm.booleans) || inhabitedClasses || !get(norm.errors) || + return !get(norm.tops) || !get(norm.booleans) || !norm.classes.isNever() || !get(norm.errors) || !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || - !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); + !get(norm.buffers) || !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } -bool isInhabited_DEPRECATED(const NormalizedType& norm) +NormalizationResult Normalizer::isInhabited(const NormalizedType* norm) { - LUAU_ASSERT(!FFlag::LuauUninhabitedSubAnything2); - return isShallowInhabited(norm); + Set seen{nullptr}; + + return isInhabited(norm, seen); } -bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set seen) +NormalizationResult Normalizer::isInhabited(const NormalizedType* norm, Set& seen) { - // If normalization failed, the type is complex, and so is more likely than not to be inhabited. - if (!norm) - return true; - - bool inhabitedClasses; - if (FFlag::LuauNegatedClassTypes) - inhabitedClasses = !norm->classes.isNever(); - else - inhabitedClasses = !norm->DEPRECATED_classes.empty(); + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits() || !norm) + return NormalizationResult::HitLimits; if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || !get(norm->nils) || - !get(norm->numbers) || !get(norm->threads) || inhabitedClasses || !norm->strings.isNever() || - !norm->functions.isNever()) - return true; + !get(norm->numbers) || !get(norm->threads) || !get(norm->buffers) || !norm->classes.isNever() || + !norm->strings.isNever() || !norm->functions.isNever()) + return NormalizationResult::True; for (const auto& [_, intersect] : norm->tyvars) { - if (isInhabited(intersect.get(), seen)) - return true; + NormalizationResult res = isInhabited(intersect.get(), seen); + if (res != NormalizationResult::False) + return res; } for (TypeId table : norm->tables) { - if (isInhabited(table, seen)) - return true; + NormalizationResult res = isInhabited(table, seen); + if (res != NormalizationResult::False) + return res; } - return false; + return NormalizationResult::False; +} + +NormalizationResult Normalizer::isInhabited(TypeId ty) +{ + if (cacheInhabitance) + { + if (bool* result = cachedIsInhabited.find(ty)) + return *result ? NormalizationResult::True : NormalizationResult::False; + } + + Set seen{nullptr}; + NormalizationResult result = isInhabited(ty, seen); + + if (cacheInhabitance && result == NormalizationResult::True) + cachedIsInhabited[ty] = true; + else if (cacheInhabitance && result == NormalizationResult::False) + cachedIsInhabited[ty] = false; + + return result; } -bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) +NormalizationResult Normalizer::isInhabited(TypeId ty, Set& seen) { + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return NormalizationResult::HitLimits; + // TODO: use log.follow(ty), CLI-64291 ty = follow(ty); if (get(ty)) - return false; + return NormalizationResult::False; if (!get(ty) && !get(ty) && !get(ty) && !get(ty)) - return true; + return NormalizationResult::True; if (seen.count(ty)) - return true; + return NormalizationResult::True; seen.insert(ty); @@ -306,17 +535,76 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) { for (const auto& [_, prop] : ttv->props) { - if (!isInhabited(prop.type, seen)) - return false; + if (FFlag::LuauSolverV2) + { + // A table enclosing a read property whose type is uninhabitable is also itself uninhabitable, + // but not its write property. That just means the write property doesn't exist, and so is readonly. + if (auto ty = prop.readTy) + { + NormalizationResult res = isInhabited(*ty, seen); + if (res != NormalizationResult::True) + return res; + } + } + else + { + NormalizationResult res = isInhabited(prop.type(), seen); + if (res != NormalizationResult::True) + return res; + } } - return true; + return NormalizationResult::True; } if (const MetatableType* mtv = get(ty)) - return isInhabited(mtv->table, seen) && isInhabited(mtv->metatable, seen); + { + NormalizationResult res = isInhabited(mtv->table, seen); + if (res != NormalizationResult::True) + return res; + return isInhabited(mtv->metatable, seen); + } - const NormalizedType* norm = normalize(ty); - return isInhabited(norm, seen); + std::shared_ptr norm = normalize(ty); + return isInhabited(norm.get(), seen); +} + +NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) +{ + Set seen{nullptr}; + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + return isIntersectionInhabited(left, right, seenTablePropPairs, seen); +} + +NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set& seenSet) +{ + left = follow(left); + right = follow(right); + // We're asking if intersection is inahbited between left and right but we've already seen them .... + + if (cacheInhabitance) + { + if (bool* result = cachedIsInhabitedIntersection.find({left, right})) + return *result ? NormalizationResult::True : NormalizationResult::False; + } + + NormalizedType norm{builtinTypes}; + NormalizationResult res = normalizeIntersections({left, right}, norm, seenTablePropPairs, seenSet); + if (res != NormalizationResult::True) + { + if (cacheInhabitance && res == NormalizationResult::False) + cachedIsInhabitedIntersection[{left, right}] = false; + + return res; + } + + NormalizationResult result = isInhabited(&norm, seenSet); + + if (cacheInhabitance && result == NormalizationResult::True) + cachedIsInhabitedIntersection[{left, right}] = true; + else if (cacheInhabitance && result == NormalizationResult::False) + cachedIsInhabitedIntersection[{left, right}] = false; + + return result; } static int tyvarIndex(TypeId ty) @@ -325,6 +613,8 @@ static int tyvarIndex(TypeId ty) return gtv->index; else if (const FreeType* ftv = get(ty)) return ftv->index; + else if (const BlockedType* btv = get(ty)) + return btv->index; else return 0; } @@ -432,15 +722,22 @@ static bool isNormalizedThread(TypeId ty) return false; } +static bool isNormalizedBuffer(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveType* ptv = get(ty)) + return ptv->type == PrimitiveType::Buffer; + else + return false; +} + static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { - if (tys.parts) + for (TypeId ty : tys.parts) { - for (TypeId ty : *tys.parts) - { - if (!get(ty) && !get(ty)) - return false; - } + if (!get(ty) && !get(ty)) + return false; } return true; } @@ -456,7 +753,7 @@ static bool areNormalizedTables(const TypeIds& tys) if (!pt) return false; - if (pt->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + if (pt->type == PrimitiveType::Table) continue; return false; @@ -465,14 +762,6 @@ static bool areNormalizedTables(const TypeIds& tys) return true; } -static bool areNormalizedClasses(const TypeIds& tys) -{ - for (TypeId ty : tys) - if (!get(ty)) - return false; - return true; -} - static bool areNormalizedClasses(const NormalizedClassType& tys) { for (const auto& [ty, negations] : tys.classes) @@ -510,7 +799,8 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) if (isSubclass(ctv, octv)) { - auto iss = [ctv](TypeId t) { + auto iss = [ctv](TypeId t) + { const ClassType* c = get(t); if (!c) return false; @@ -529,7 +819,7 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty)); + return (get(ty) || get(ty) || get(ty) || get(ty) || get(ty)); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -557,13 +847,13 @@ static void assertInvariant(const NormalizedType& norm) LUAU_ASSERT(isNormalizedTop(norm.tops)); LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); - LUAU_ASSERT(areNormalizedClasses(norm.DEPRECATED_classes)); LUAU_ASSERT(areNormalizedClasses(norm.classes)); LUAU_ASSERT(isNormalizedError(norm.errors)); LUAU_ASSERT(isNormalizedNil(norm.nils)); LUAU_ASSERT(isNormalizedNumber(norm.numbers)); LUAU_ASSERT(isNormalizedString(norm.strings)); LUAU_ASSERT(isNormalizedThread(norm.threads)); + LUAU_ASSERT(isNormalizedBuffer(norm.buffers)); LUAU_ASSERT(areNormalizedFunctions(norm.functions)); LUAU_ASSERT(areNormalizedTables(norm.tables)); LUAU_ASSERT(isNormalizedTyvar(norm.tyvars)); @@ -572,29 +862,126 @@ static void assertInvariant(const NormalizedType& norm) #endif } -Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState) +Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState, bool cacheInhabitance) : arena(arena) , builtinTypes(builtinTypes) , sharedState(sharedState) + , cacheInhabitance(cacheInhabitance) +{ +} + +static bool isCacheable(TypeId ty, Set& seen); + +static bool isCacheable(TypePackId tp, Set& seen) +{ + tp = follow(tp); + + auto it = begin(tp); + auto endIt = end(tp); + for (; it != endIt; ++it) + { + if (!isCacheable(*it, seen)) + return false; + } + + if (auto tail = it.tail()) + { + if (get(*tail) || get(*tail) || get(*tail)) + return false; + } + + return true; +} + +static bool isCacheable(TypeId ty, Set& seen) +{ + if (seen.contains(ty)) + return true; + seen.insert(ty); + + ty = follow(ty); + + if (get(ty) || get(ty) || get(ty)) + return false; + + if (auto tfi = get(ty)) + { + for (TypeId t : tfi->typeArguments) + { + if (!isCacheable(t, seen)) + return false; + } + + for (TypePackId tp : tfi->packArguments) + { + if (!isCacheable(tp, seen)) + return false; + } + } + + return true; +} + +static bool isCacheable(TypeId ty) { + Set seen{nullptr}; + return isCacheable(ty, seen); } -const NormalizedType* Normalizer::normalize(TypeId ty) +std::shared_ptr Normalizer::normalize(TypeId ty) { if (!arena) sharedState->iceHandler->ice("Normalizing types outside a module"); auto found = cachedNormals.find(ty); if (found != cachedNormals.end()) - return found->second.get(); + return found->second; NormalizedType norm{builtinTypes}; - if (!unionNormalWithTy(norm, ty)) + Set seenSetTypes{nullptr}; + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + NormalizationResult res = unionNormalWithTy(norm, ty, seenTablePropPairs, seenSetTypes); + if (res != NormalizationResult::True) return nullptr; - std::unique_ptr uniq = std::make_unique(std::move(norm)); - const NormalizedType* result = uniq.get(); - cachedNormals[ty] = std::move(uniq); - return result; + + if (norm.isUnknown()) + { + clearNormal(norm); + norm.tops = builtinTypes->unknownType; + } + + std::shared_ptr shared = std::make_shared(std::move(norm)); + + if (shared->isCacheable) + cachedNormals[ty] = shared; + + return shared; +} + +NormalizationResult Normalizer::normalizeIntersections( + const std::vector& intersections, + NormalizedType& outType, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSet +) +{ + if (!arena) + sharedState->iceHandler->ice("Normalizing types outside a module"); + NormalizedType norm{builtinTypes}; + norm.tops = builtinTypes->anyType; + // Now we need to intersect the two types + for (auto ty : intersections) + { + NormalizationResult res = intersectNormalWithTy(norm, ty, seenTablePropPairs, seenSet); + if (res != NormalizationResult::True) + return res; + } + + NormalizationResult res = unionNormals(outType, norm); + if (res != NormalizationResult::True) + return res; + + return NormalizationResult::True; } void Normalizer::clearNormal(NormalizedType& norm) @@ -602,12 +989,12 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.tops = builtinTypes->neverType; norm.booleans = builtinTypes->neverType; norm.classes.resetToNever(); - norm.DEPRECATED_classes.clear(); norm.errors = builtinTypes->neverType; norm.nils = builtinTypes->neverType; norm.numbers = builtinTypes->neverType; norm.strings.resetToNever(); norm.threads = builtinTypes->neverType; + norm.buffers = builtinTypes->neverType; norm.tables.clear(); norm.functions.resetToNever(); norm.tyvars.clear(); @@ -1029,8 +1416,9 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack itt++; } - auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, - bool& thereSubHere) { + auto dealWithDifferentArities = + [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { if (ith != end(here)) { TypeId tty = builtinTypes->nilType; @@ -1166,13 +1554,10 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (FFlag::LuauNegatedFunctionTypes) - { - if (heres.isTop) - return; - if (theres.isTop) - heres.resetToTop(); - } + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); if (theres.isNever()) return; @@ -1181,13 +1566,13 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF if (heres.isNever()) { - tmps.insert(theres.parts->begin(), theres.parts->end()); + tmps.insert(theres.parts.begin(), theres.parts.end()); heres.parts = std::move(tmps); return; } - for (TypeId here : *heres.parts) - for (TypeId there : *theres.parts) + for (TypeId here : heres.parts) + for (TypeId there : theres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -1209,7 +1594,7 @@ void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeI } TypeIds tmps; - for (TypeId here : *heres.parts) + for (TypeId here : heres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -1222,6 +1607,11 @@ void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeI void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) { // TODO: remove unions of tables where possible + + // we can always skip `never` + if (get(there)) + return; + heres.insert(there); } @@ -1229,18 +1619,11 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) { for (TypeId there : theres) { - if (FFlag::LuauNegatedTableTypes) + if (there == builtinTypes->tableType) { - if (there == builtinTypes->tableType) - { - heres.clear(); - heres.insert(there); - return; - } - else - { - unionTablesWithTable(heres, there); - } + heres.clear(); + heres.insert(there); + return; } else { @@ -1268,14 +1651,18 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) // // And yes, this is essentially a SAT solver hidden inside a typechecker. // That's what you get for having a type system with generics, intersection and union types. -bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +NormalizationResult Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { + here.isCacheable &= there.isCacheable; + TypeId tops = unionOfTops(here.tops, there.tops); + if (get(tops) && (get(here.errors) || get(there.errors))) + tops = builtinTypes->anyType; if (!get(tops)) { clearNormal(here); here.tops = tops; - return true; + return NormalizationResult::True; } for (auto it = there.tyvars.begin(); it != there.tyvars.end(); it++) @@ -1287,26 +1674,29 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, continue; auto [emplaced, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{builtinTypes})); if (fresh) - if (!unionNormals(*emplaced->second, here, index)) - return false; - if (!unionNormals(*emplaced->second, inter, index)) - return false; + { + NormalizationResult res = unionNormals(*emplaced->second, here, index); + if (res != NormalizationResult::True) + return res; + } + + NormalizationResult res = unionNormals(*emplaced->second, inter, index); + if (res != NormalizationResult::True) + return res; } here.booleans = unionOfBools(here.booleans, there.booleans); - if (FFlag::LuauNegatedClassTypes) - unionClasses(here.classes, there.classes); - else - unionClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); + unionClasses(here.classes, there.classes); here.errors = (get(there.errors) ? here.errors : there.errors); here.nils = (get(there.nils) ? here.nils : there.nils); here.numbers = (get(there.numbers) ? here.numbers : there.numbers); unionStrings(here.strings, there.strings); here.threads = (get(there.threads) ? here.threads : there.threads); + here.buffers = (get(there.buffers) ? here.buffers : there.buffers); unionFunctions(here.functions, there.functions); unionTables(here.tables, there.tables); - return true; + return NormalizationResult::True; } bool Normalizer::withinResourceLimits() @@ -1314,7 +1704,8 @@ bool Normalizer::withinResourceLimits() // If cache is too large, clear it if (FInt::LuauNormalizeCacheLimit > 0) { - size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size(); + size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size() + + cachedIsInhabited.size() + cachedIsInhabitedIntersection.size(); if (cacheUsage > size_t(FInt::LuauNormalizeCacheLimit)) { clearCaches(); @@ -1330,62 +1721,111 @@ bool Normalizer::withinResourceLimits() return true; } +NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect) +{ + + std::optional negated; + + std::shared_ptr normal = normalize(toNegate); + negated = negateNormal(*normal); + + if (!negated) + return NormalizationResult::False; + intersectNormals(intersect, *negated); + return NormalizationResult::True; +} + // See above for an explaination of `ignoreSmallerTyvars`. -bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars) +NormalizationResult Normalizer::unionNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes, + int ignoreSmallerTyvars +) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) - return false; + return NormalizationResult::HitLimits; there = follow(there); + if (get(there) || get(there)) { TypeId tops = unionOfTops(here.tops, there); + if (get(tops) && get(here.errors)) + tops = builtinTypes->anyType; clearNormal(here); here.tops = tops; - return true; + return NormalizationResult::True; + } + else if (get(there) || get(here.tops)) + return NormalizationResult::True; + else if (get(there) && get(here.tops)) + { + here.tops = builtinTypes->anyType; + return NormalizationResult::True; } - else if (get(there) || !get(here.tops)) - return true; else if (const UnionType* utv = get(there)) { + if (seenSetTypes.count(there)) + return NormalizationResult::True; + seenSetTypes.insert(there); + for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) - if (!unionNormalWithTy(here, *it)) - return false; - return true; + { + NormalizationResult res = unionNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes); + if (res != NormalizationResult::True) + { + seenSetTypes.erase(there); + return res; + } + } + + seenSetTypes.erase(there); + return NormalizationResult::True; } else if (const IntersectionType* itv = get(there)) { + if (seenSetTypes.count(there)) + return NormalizationResult::True; + seenSetTypes.insert(there); + NormalizedType norm{builtinTypes}; norm.tops = builtinTypes->anyType; for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) - if (!intersectNormalWithTy(norm, *it)) - return false; + { + NormalizationResult res = intersectNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes); + if (res != NormalizationResult::True) + { + seenSetTypes.erase(there); + return res; + } + } + + seenSetTypes.erase(there); + return unionNormals(here, norm); } - else if (get(there) || get(there)) + else if (get(here.tops)) + return NormalizationResult::True; + else if (get(there) || get(there) || get(there) || get(there) || get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) - return true; + return NormalizationResult::True; NormalizedType inter{builtinTypes}; inter.tops = builtinTypes->unknownType; here.tyvars.insert_or_assign(there, std::make_unique(std::move(inter))); + + if (!isCacheable(there)) + here.isCacheable = false; } else if (get(there)) unionFunctionsWithFunction(here.functions, there); else if (get(there) || get(there)) unionTablesWithTable(here.tables, there); else if (get(there)) - { - if (FFlag::LuauNegatedClassTypes) - { - unionClassesWithClass(here.classes, there); - } - else - { - unionClassesWithClass(here.DEPRECATED_classes, there); - } - } + unionClassesWithClass(here.classes, there); else if (get(there)) here.errors = there; else if (const PrimitiveType* ptv = get(there)) @@ -1400,12 +1840,13 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.strings.resetToString(); else if (ptv->type == PrimitiveType::Thread) here.threads = there; + else if (ptv->type == PrimitiveType::Buffer) + here.buffers = there; else if (ptv->type == PrimitiveType::Function) { - LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions.resetToTop(); } - else if (ptv->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + else if (ptv->type == PrimitiveType::Table) { here.tables.clear(); here.tables.insert(there); @@ -1433,25 +1874,34 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (const NegationType* ntv = get(there)) { - const NormalizedType* thereNormal = normalize(ntv->ty); - std::optional tn = negateNormal(*thereNormal); + std::optional tn; + + std::shared_ptr thereNormal = normalize(ntv->ty); + tn = negateNormal(*thereNormal); + if (!tn) - return false; + return NormalizationResult::False; - if (!unionNormals(here, *tn)) - return false; + NormalizationResult res = unionNormals(here, *tn); + if (res != NormalizationResult::True) + return res; + } + else if (get(there) || get(there) || get(there)) + { + // nothing } - else if (get(there)) - LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); else LUAU_ASSERT(!"Unreachable"); for (auto& [tyvar, intersect] : here.tyvars) - if (!unionNormalWithTy(*intersect, there, tyvarIndex(tyvar))) - return false; + { + NormalizationResult res = unionNormalWithTy(*intersect, there, seenTablePropPairs, seenSetTypes, tyvarIndex(tyvar)); + if (res != NormalizationResult::True) + return res; + } assertInvariant(here); - return true; + return NormalizationResult::True; } // ------- Negations @@ -1459,6 +1909,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor std::optional Normalizer::negateNormal(const NormalizedType& here) { NormalizedType result{builtinTypes}; + result.isCacheable = here.isCacheable; + if (!get(here.tops)) { // The negation of unknown or any is never. Easy. @@ -1486,36 +1938,29 @@ std::optional Normalizer::negateNormal(const NormalizedType& her result.booleans = builtinTypes->trueType; } - if (FFlag::LuauNegatedClassTypes) + if (here.classes.isNever()) { - if (here.classes.isNever()) - { - resetToTop(builtinTypes, result.classes); - } - else if (isTop(builtinTypes, result.classes)) - { - result.classes.resetToNever(); - } - else - { - TypeIds rootNegations{}; - - for (const auto& [hereParent, hereNegations] : here.classes.classes) - { - if (hereParent != builtinTypes->classType) - rootNegations.insert(hereParent); - - for (TypeId hereNegation : hereNegations) - unionClassesWithClass(result.classes, hereNegation); - } - - if (!rootNegations.empty()) - result.classes.pushPair(builtinTypes->classType, rootNegations); - } + resetToTop(builtinTypes, result.classes); + } + else if (isTop(builtinTypes, result.classes)) + { + result.classes.resetToNever(); } else { - result.DEPRECATED_classes = negateAll(here.DEPRECATED_classes); + TypeIds rootNegations{}; + + for (const auto& [hereParent, hereNegations] : here.classes.classes) + { + if (hereParent != builtinTypes->classType) + rootNegations.insert(hereParent); + + for (TypeId hereNegation : hereNegations) + unionClassesWithClass(result.classes, hereNegation); + } + + if (!rootNegations.empty()) + result.classes.pushPair(builtinTypes->classType, rootNegations); } result.nils = get(here.nils) ? builtinTypes->nilType : builtinTypes->neverType; @@ -1525,36 +1970,31 @@ std::optional Normalizer::negateNormal(const NormalizedType& her result.strings.isCofinite = !result.strings.isCofinite; result.threads = get(here.threads) ? builtinTypes->threadType : builtinTypes->neverType; + result.buffers = get(here.buffers) ? builtinTypes->bufferType : builtinTypes->neverType; /* * Things get weird and so, so complicated if we allow negations of * arbitrary function types. Ordinary code can never form these kinds of * types, so we decline to negate them. */ - if (FFlag::LuauNegatedFunctionTypes) - { - if (here.functions.isNever()) - result.functions.resetToTop(); - else if (here.functions.isTop) - result.functions.resetToNever(); - else - return std::nullopt; - } + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; /* * It is not possible to negate an arbitrary table type, because function * types are not runtime-testable. Thus, we prohibit negation of anything * other than `table` and `never`. */ - if (FFlag::LuauNegatedTableTypes) - { - if (here.tables.empty()) - result.tables.insert(builtinTypes->tableType); - else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) - result.tables.clear(); - else - return std::nullopt; - } + if (here.tables.empty()) + result.tables.insert(builtinTypes->tableType); + else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) + result.tables.clear(); + else + return std::nullopt; // TODO: negating tables // TODO: negating tyvars? @@ -1620,11 +2060,13 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) case PrimitiveType::Thread: here.threads = builtinTypes->neverType; break; + case PrimitiveType::Buffer: + here.buffers = builtinTypes->neverType; + break; case PrimitiveType::Function: here.functions.resetToNever(); break; case PrimitiveType::Table: - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables.clear(); break; } @@ -1696,64 +2138,6 @@ TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) return there; } -void Normalizer::DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres) -{ - TypeIds tmp; - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - bool keep = false; - for (TypeId there : theres) - { - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - if (isSubclass(hctv, tctv)) - { - keep = true; - break; - } - else if (isSubclass(tctv, hctv)) - { - keep = false; - tmp.insert(there); - break; - } - } - if (keep) - it++; - else - it = heres.erase(it); - } - heres.insert(tmp.begin(), tmp.end()); -} - -void Normalizer::DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there) -{ - bool foundSuper = false; - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - if (isSubclass(hctv, tctv)) - it++; - else if (isSubclass(tctv, hctv)) - { - foundSuper = true; - break; - } - else - it = heres.erase(it); - } - if (foundSuper) - { - heres.clear(); - heres.insert(there); - } -} - void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres) { if (theres.isNever()) @@ -1797,6 +2181,11 @@ void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedCl if (isSubclass(thereTy, hereTy)) { + // If thereTy is a subtype of hereTy, we need to replace hereTy + // by thereTy and combine their negation lists. + // + // If any types in the negation list are not subtypes of + // thereTy, they need to be removed from the negation list. TypeIds negations = std::move(hereNegations); for (auto nIt = negations.begin(); nIt != negations.end();) @@ -1820,22 +2209,45 @@ void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedCl } else if (isSubclass(hereTy, thereTy)) { + // If thereTy is a supertype of hereTy, we need to extend the + // negation list of hereTy by that of thereTy. + // + // If any of the types of thereTy's negations are not subtypes + // of hereTy, they must not be added to hereTy's negation list. + // + // If any of the types of thereTy's negations are supertypes of + // hereTy, then hereTy must be removed entirely. + // + // If any of the types of thereTy's negations are supertypes of + // the negations of herety, the former must supplant the latter. TypeIds negations = thereNegations; + bool erasedHere = false; + for (auto nIt = negations.begin(); nIt != negations.end();) { - if (!isSubclass(*nIt, hereTy)) + if (isSubclass(hereTy, *nIt)) { - nIt = negations.erase(nIt); + // eg SomeClass & (class & ~SomeClass) + // or SomeClass & (class & ~ParentClass) + heres.classes.erase(hereTy); + it = heres.ordering.erase(it); + erasedHere = true; + break; } + + // eg SomeClass & (class & ~Unrelated) + if (!isSubclass(*nIt, hereTy)) + nIt = negations.erase(nIt); else - { ++nIt; - } } - unionClasses(hereNegations, negations); - break; + if (!erasedHere) + { + unionClasses(hereNegations, negations); + ++it; + } } else if (hereTy == thereTy) { @@ -1908,18 +2320,68 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { + /* There are 9 cases to worry about here + Normalized Left | Normalized Right + C1 string | string ===> trivial + C2 string - {u_1,..} | string ===> trivial + C3 {u_1, ..} | string ===> trivial + C4 string | string - {v_1, ..} ===> string - {v_1, ..} + C5 string - {u_1,..} | string - {v_1, ..} ===> string - ({u_s} U {v_s}) + C6 {u_1, ..} | string - {v_1, ..} ===> {u_s} - {v_s} + C7 string | {v_1, ..} ===> {v_s} + C8 string - {u_1,..} | {v_1, ..} ===> {v_s} - {u_s} + C9 {u_1, ..} | {v_1, ..} ===> {u_s} ∩ {v_s} + */ + // Case 1,2,3 if (there.isString()) return; - if (here.isString()) - here.resetToNever(); - - for (auto it = here.singletons.begin(); it != here.singletons.end();) + // Case 4, Case 7 + else if (here.isString()) { - if (there.singletons.count(it->first)) - it++; - else - it = here.singletons.erase(it); + here.singletons.clear(); + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; + here.isCofinite = here.isCofinite && there.isCofinite; + } + // Case 5 + else if (here.isIntersection() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; + } + // Case 6 + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = false; + for (const auto& [key, _] : there.singletons) + here.singletons.erase(key); + } + // Case 8 + else if (here.isIntersection() && there.isUnion()) + { + here.isCofinite = false; + std::map result(there.singletons); + for (const auto& [key, _] : here.singletons) + result.erase(key); + here.singletons = result; + } + // Case 9 + else if (here.isUnion() && there.isUnion()) + { + here.isCofinite = false; + std::map result; + result.insert(here.singletons.begin(), here.singletons.end()); + result.insert(there.singletons.begin(), there.singletons.end()); + for (auto it = result.begin(); it != result.end();) + if (!here.singletons.count(it->first) || !there.singletons.count(it->first)) + it = result.erase(it); + else + ++it; + here.singletons = result; } + else + LUAU_ASSERT(0 && "Internal Error - unrecognized case"); } std::optional Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) @@ -1950,8 +2412,9 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T itt++; } - auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, - bool& thereSubHere) { + auto dealWithDifferentArities = + [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { if (ith != end(here)) { TypeId tty = builtinTypes->nilType; @@ -2042,7 +2505,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T return arena->addTypePack({}); } -std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there) +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSet) { if (here == there) return here; @@ -2056,25 +2519,38 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (isPrim(there, PrimitiveType::Table)) return here; + if (get(here)) + return there; + else if (get(there)) + return here; + else if (get(here)) + return there; + else if (get(there)) + return here; + TypeId htable = here; TypeId hmtable = nullptr; if (const MetatableType* hmtv = get(here)) { - htable = hmtv->table; - hmtable = hmtv->metatable; + htable = follow(hmtv->table); + hmtable = follow(hmtv->metatable); } TypeId ttable = there; TypeId tmtable = nullptr; if (const MetatableType* tmtv = get(there)) { - ttable = tmtv->table; - tmtable = tmtv->metatable; + ttable = follow(tmtv->table); + tmtable = follow(tmtv->metatable); } const TableType* httv = get(htable); - LUAU_ASSERT(httv); + if (!httv) + return std::nullopt; + const TableType* tttv = get(ttable); - LUAU_ASSERT(tttv); + if (!tttv) + return std::nullopt; + if (httv->state == TableState::Free || tttv->state == TableState::Free) return std::nullopt; @@ -2086,8 +2562,9 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there state = tttv->state; TypeLevel level = max(httv->level, tttv->level); - TableType result{state, level}; + Scope* scope = max(httv->scope, tttv->scope); + std::unique_ptr result = nullptr; bool hereSubThere = true; bool thereSubHere = true; @@ -2101,19 +2578,131 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { const auto& [_name, tprop] = *tfound; // TODO: variance issues here, which can't be fixed until we have read/write property types - prop.type = intersectionType(hprop.type, tprop.type); - hereSubThere &= (prop.type == hprop.type); - thereSubHere &= (prop.type == tprop.type); + if (FFlag::LuauSolverV2) + { + if (hprop.readTy.has_value()) + { + if (tprop.readTy.has_value()) + { + // if the intersection of the read types of a property is uninhabited, the whole table is `never`. + // We've seen these table prop elements before and we're about to ask if their intersection + // is inhabited + if (FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance) + { + auto pair1 = std::pair{*hprop.readTy, *tprop.readTy}; + auto pair2 = std::pair{*tprop.readTy, *hprop.readTy}; + if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2)) + { + seenTablePropPairs.erase(pair1); + seenTablePropPairs.erase(pair2); + return {builtinTypes->neverType}; + } + else + { + seenTablePropPairs.insert(pair1); + seenTablePropPairs.insert(pair2); + } + + Set seenSet{nullptr}; + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet); + + seenTablePropPairs.erase(pair1); + seenTablePropPairs.erase(pair2); + if (NormalizationResult::True != res) + return {builtinTypes->neverType}; + + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); + } + else + { + + if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + return {builtinTypes->neverType}; + } + else + { + seenSet.insert(*hprop.readTy); + seenSet.insert(*tprop.readTy); + } + + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); + + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + + if (NormalizationResult::True != res) + return {builtinTypes->neverType}; + + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); + } + } + else + { + prop.readTy = *hprop.readTy; + thereSubHere = false; + } + } + else if (tprop.readTy.has_value()) + { + prop.readTy = *tprop.readTy; + hereSubThere = false; + } + + if (hprop.writeTy.has_value()) + { + if (tprop.writeTy.has_value()) + { + prop.writeTy = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.writeTy, *tprop.writeTy).result; + hereSubThere &= (prop.writeTy == hprop.writeTy); + thereSubHere &= (prop.writeTy == tprop.writeTy); + } + else + { + prop.writeTy = *hprop.writeTy; + thereSubHere = false; + } + } + else if (tprop.writeTy.has_value()) + { + prop.writeTy = *tprop.writeTy; + hereSubThere = false; + } + } + else + { + prop.setType(intersectionType(hprop.type(), tprop.type())); + hereSubThere &= (prop.type() == hprop.type()); + thereSubHere &= (prop.type() == tprop.type()); + } } + // TODO: string indexers - result.props[name] = prop; + + if (prop.readTy || prop.writeTy) + { + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->props[name] = prop; + } } for (const auto& [name, tprop] : tttv->props) { if (httv->props.count(name) == 0) { - result.props[name] = tprop; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + + result->props[name] = tprop; hereSubThere = false; } } @@ -2123,18 +2712,24 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // TODO: What should intersection of indexes be? TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType); TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType); - result.indexer = {index, indexResult}; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->indexer = {index, indexResult}; hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult); thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult); } else if (httv->indexer) { - result.indexer = httv->indexer; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->indexer = httv->indexer; thereSubHere = false; } else if (tttv->indexer) { - result.indexer = tttv->indexer; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->indexer = tttv->indexer; hereSubThere = false; } @@ -2144,12 +2739,17 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (thereSubHere) table = ttable; else - table = arena->addType(std::move(result)); + { + if (result.get()) + table = arena->addType(std::move(*result)); + else + table = arena->addType(TableType{state, level, scope}); + } if (tmtable && hmtable) { // NOTE: this assumes metatables are ivariant - if (std::optional mtable = intersectionOfTables(hmtable, tmtable)) + if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenTablePropPairs, seenSet)) { if (table == htable && *mtable == hmtable) return here; @@ -2179,12 +2779,14 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there return table; } -void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there) +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes) { TypeIds tmp; for (TypeId here : heres) - if (std::optional inter = intersectionOfTables(here, there)) + { + if (std::optional inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes)) tmp.insert(*inter); + } heres.retain(tmp); heres.insert(tmp.begin(), tmp.end()); } @@ -2193,9 +2795,16 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) { TypeIds tmp; for (TypeId here : heres) + { for (TypeId there : theres) - if (std::optional inter = intersectionOfTables(here, there)) + { + Set seenSetTypes{nullptr}; + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + if (std::optional inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes)) tmp.insert(*inter); + } + } + heres.retain(tmp); heres.insert(tmp.begin(), tmp.end()); } @@ -2368,15 +2977,15 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T heres.isTop = false; - for (auto it = heres.parts->begin(); it != heres.parts->end();) + for (auto it = heres.parts.begin(); it != heres.parts.end();) { TypeId here = *it; if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { - heres.parts->erase(it); - heres.parts->insert(*tmp); + heres.parts.erase(it); + heres.parts.insert(*tmp); return; } else @@ -2384,13 +2993,13 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T } TypeIds tmps; - for (TypeId here : *heres.parts) + for (TypeId here : heres.parts) { if (std::optional tmp = unionSaturatedFunctions(here, there)) tmps.insert(*tmp); } - heres.parts->insert(there); - heres.parts->insert(tmps.begin(), tmps.end()); + heres.parts.insert(there); + heres.parts.insert(tmps.begin(), tmps.end()); } void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) @@ -2404,33 +3013,46 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali } else { - for (TypeId there : *theres.parts) + for (TypeId there : theres.parts) intersectFunctionsWithFunction(heres, there); } } -bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) +NormalizationResult Normalizer::intersectTyvarsWithTy( + NormalizedTyvars& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes +) { for (auto it = here.begin(); it != here.end();) { NormalizedType& inter = *it->second; - if (!intersectNormalWithTy(inter, there)) - return false; + NormalizationResult res = intersectNormalWithTy(inter, there, seenTablePropPairs, seenSetTypes); + if (res != NormalizationResult::True) + return res; if (isShallowInhabited(inter)) ++it; else it = here.erase(it); } - return true; + return NormalizationResult::True; } // See above for an explaination of `ignoreSmallerTyvars`. -bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { + if (FFlag::LuauIntersectNormalsNeedsToTrackResourceLimits) + { + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return NormalizationResult::HitLimits; + } + if (!get(there.tops)) { here.tops = intersectionOfTops(here.tops, there.tops); - return true; + return NormalizationResult::True; } else if (!get(here.tops)) { @@ -2438,22 +3060,20 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th return unionNormals(here, there, ignoreSmallerTyvars); } - here.booleans = intersectionOfBools(here.booleans, there.booleans); + // Limit based on worst-case expansion of the table intersection + // This restriction can be relaxed when table intersection simplification is improved + if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit)) + return NormalizationResult::HitLimits; - if (FFlag::LuauNegatedClassTypes) - { - intersectClasses(here.classes, there.classes); - } - else - { - DEPRECATED_intersectClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); - } + here.booleans = intersectionOfBools(here.booleans, there.booleans); + intersectClasses(here.classes, there.classes); here.errors = (get(there.errors) ? there.errors : here.errors); here.nils = (get(there.nils) ? there.nils : here.nils); here.numbers = (get(there.numbers) ? there.numbers : here.numbers); intersectStrings(here.strings, there.strings); here.threads = (get(there.threads) ? there.threads : here.threads); + here.buffers = (get(there.buffers) ? there.buffers : here.buffers); intersectFunctions(here.functions, there.functions); intersectTables(here.tables, there.tables); @@ -2465,8 +3085,9 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th auto [found, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{builtinTypes})); if (fresh) { - if (!unionNormals(*found->second, here, index)) - return false; + NormalizationResult res = unionNormals(*found->second, here, index); + if (res != NormalizationResult::True) + return res; } } } @@ -2479,60 +3100,75 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th auto found = there.tyvars.find(tyvar); if (found == there.tyvars.end()) { - if (!intersectNormals(inter, there, index)) - return false; + NormalizationResult res = intersectNormals(inter, there, index); + if (res != NormalizationResult::True) + return res; } else { - if (!intersectNormals(inter, *found->second, index)) - return false; + NormalizationResult res = intersectNormals(inter, *found->second, index); + if (res != NormalizationResult::True) + return res; } if (isShallowInhabited(inter)) it++; else it = here.tyvars.erase(it); } - return true; + return NormalizationResult::True; } -bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) +NormalizationResult Normalizer::intersectNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes +) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) - return false; + return NormalizationResult::HitLimits; there = follow(there); + if (get(there) || get(there)) { here.tops = intersectionOfTops(here.tops, there); - return true; + return NormalizationResult::True; } else if (!get(here.tops)) { clearNormal(here); - return unionNormalWithTy(here, there); + return unionNormalWithTy(here, there, seenTablePropPairs, seenSetTypes); } else if (const UnionType* utv = get(there)) { NormalizedType norm{builtinTypes}; for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) - if (!unionNormalWithTy(norm, *it)) - return false; + { + NormalizationResult res = unionNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes); + if (res != NormalizationResult::True) + return res; + } return intersectNormals(here, norm); } else if (const IntersectionType* itv = get(there)) { for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) - if (!intersectNormalWithTy(here, *it)) - return false; - return true; + { + NormalizationResult res = intersectNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes); + if (res != NormalizationResult::True) + return res; + } + return NormalizationResult::True; } - else if (get(there) || get(there)) + else if (get(there) || get(there) || get(there) || get(there) || get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; topNorm.tops = builtinTypes->unknownType; thereNorm.tyvars.insert_or_assign(there, std::make_unique(std::move(topNorm))); + here.isCacheable = false; return intersectNormals(here, thereNorm); } @@ -2549,25 +3185,15 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) { TypeIds tables = std::move(here.tables); clearNormal(here); - intersectTablesWithTable(tables, there); + intersectTablesWithTable(tables, there, seenTablePropPairs, seenSetTypes); here.tables = std::move(tables); } else if (get(there)) { - if (FFlag::LuauNegatedClassTypes) - { - NormalizedClassType nct = std::move(here.classes); - clearNormal(here); - intersectClassesWithClass(nct, there); - here.classes = std::move(nct); - } - else - { - TypeIds classes = std::move(here.DEPRECATED_classes); - clearNormal(here); - DEPRECATED_intersectClassesWithClass(classes, there); - here.DEPRECATED_classes = std::move(classes); - } + NormalizedClassType nct = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(nct, there); + here.classes = std::move(nct); } else if (get(there)) { @@ -2583,6 +3209,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) NormalizedStringType strings = std::move(here.strings); NormalizedFunctionType functions = std::move(here.functions); TypeId threads = here.threads; + TypeId buffers = here.buffers; TypeIds tables = std::move(here.tables); clearNormal(here); @@ -2597,16 +3224,12 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.strings = std::move(strings); else if (ptv->type == PrimitiveType::Thread) here.threads = threads; + else if (ptv->type == PrimitiveType::Buffer) + here.buffers = buffers; else if (ptv->type == PrimitiveType::Function) - { - LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions = std::move(functions); - } else if (ptv->type == PrimitiveType::Table) - { - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables = std::move(tables); - } else LUAU_ASSERT(!"Unreachable"); } @@ -2634,33 +3257,47 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) subtractPrimitive(here, ntv->ty); else if (const SingletonType* stv = get(t)) subtractSingleton(here, follow(ntv->ty)); - else if (get(t) && FFlag::LuauNegatedClassTypes) + else if (get(t)) { - const NormalizedType* normal = normalize(t); - std::optional negated = negateNormal(*normal); - if (!negated) - return false; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(t, here); + if (shouldEarlyExit(res)) + return res; } else if (const UnionType* itv = get(t)) { for (TypeId part : itv->options) { - const NormalizedType* normalPart = normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return false; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(part, here); + if (shouldEarlyExit(res)) + return res; } } else if (get(t)) { // HACK: Refinements sometimes intersect with ~any under the // assumption that it is the same as any. - return true; + return NormalizationResult::True; + } + else if (get(t)) + { + // `*no-refine*` means we will never do anything to affect the intersection. + return NormalizationResult::True; + } + else if (get(t)) + { + // if we're intersecting with `~never`, this is equivalent to intersecting with `unknown` + // this is a noop since an intersection with `unknown` is trivial. + return NormalizationResult::True; + } + else if (get(t)) + { + // if we're intersecting with `~unknown`, this is equivalent to intersecting with `never` + // this means we should clear the type entirely. + clearNormal(here); + return NormalizationResult::True; } else if (auto nt = get(t)) - return intersectNormalWithTy(here, nt->ty); + return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes); else { // TODO negated unions, intersections, table, and function. @@ -2668,18 +3305,39 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) LUAU_ASSERT(!"Unimplemented"); } } - else if (get(there) && FFlag::LuauNegatedClassTypes) + else if (get(there)) { here.classes.resetToNever(); } + else if (get(there)) + { + // `*no-refine*` means we will never do anything to affect the intersection. + return NormalizationResult::True; + } else LUAU_ASSERT(!"Unreachable"); - if (!intersectTyvarsWithTy(tyvars, there)) - return false; + NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenTablePropPairs, seenSetTypes); + if (res != NormalizationResult::True) + return res; here.tyvars = std::move(tyvars); - return true; + return NormalizationResult::True; +} + +void makeTableShared(TypeId ty) +{ + ty = follow(ty); + if (auto tableTy = getMutable(ty)) + { + for (auto& [_, prop] : tableTy->props) + prop.makeShared(); + } + else if (auto metatableTy = get(ty)) + { + makeTableShared(metatableTy->metatable); + makeTableShared(metatableTy->table); + } } // -------- Convert back from a normalized type to a type @@ -2694,67 +3352,60 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) if (!get(norm.booleans)) result.push_back(norm.booleans); - if (FFlag::LuauNegatedClassTypes) + if (isTop(builtinTypes, norm.classes)) { - if (isTop(builtinTypes, norm.classes)) - { - result.push_back(builtinTypes->classType); - } - else if (!norm.classes.isNever()) + result.push_back(builtinTypes->classType); + } + else if (!norm.classes.isNever()) + { + std::vector parts; + parts.reserve(norm.classes.classes.size()); + + for (const TypeId normTy : norm.classes.ordering) { - std::vector parts; - parts.reserve(norm.classes.classes.size()); + const TypeIds& normNegations = norm.classes.classes.at(normTy); - for (const TypeId normTy : norm.classes.ordering) + if (normNegations.empty()) { - const TypeIds& normNegations = norm.classes.classes.at(normTy); + parts.push_back(normTy); + } + else + { + std::vector intersection; + intersection.reserve(normNegations.size() + 1); - if (normNegations.empty()) + intersection.push_back(normTy); + for (TypeId negation : normNegations) { - parts.push_back(normTy); + intersection.push_back(arena->addType(NegationType{negation})); } - else - { - std::vector intersection; - intersection.reserve(normNegations.size() + 1); - intersection.push_back(normTy); - for (TypeId negation : normNegations) - { - intersection.push_back(arena->addType(NegationType{negation})); - } - - parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); - } + parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); } + } - if (parts.size() == 1) - { - result.push_back(parts.at(0)); - } - else if (parts.size() > 1) - { - result.push_back(arena->addType(UnionType{std::move(parts)})); - } + if (parts.size() == 1) + { + result.push_back(parts.at(0)); + } + else if (parts.size() > 1) + { + result.push_back(arena->addType(UnionType{std::move(parts)})); } - } - else - { - result.insert(result.end(), norm.DEPRECATED_classes.begin(), norm.DEPRECATED_classes.end()); } if (!get(norm.errors)) result.push_back(norm.errors); - if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + if (norm.functions.isTop) result.push_back(builtinTypes->functionType); else if (!norm.functions.isNever()) { - if (norm.functions.parts->size() == 1) - result.push_back(*norm.functions.parts->begin()); + if (norm.functions.parts.size() == 1) + result.push_back(*norm.functions.parts.begin()); else { std::vector parts; - parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + parts.insert(parts.end(), norm.functions.parts.begin(), norm.functions.parts.end()); result.push_back(arena->addType(IntersectionType{std::move(parts)})); } } @@ -2780,8 +3431,21 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) } if (!get(norm.threads)) result.push_back(builtinTypes->threadType); + if (!get(norm.buffers)) + result.push_back(builtinTypes->bufferType); + + if (FFlag::LuauSolverV2) + { + result.reserve(result.size() + norm.tables.size()); + for (auto table : norm.tables) + { + makeTableShared(table); + result.push_back(table); + } + } + else + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); - result.insert(result.end(), norm.tables.begin(), norm.tables.end()); for (auto& [tyvar, intersect] : norm.tyvars) { if (get(intersect->tops)) @@ -2806,11 +3470,25 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) @@ -2818,11 +3496,25 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, N UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&ice}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + + // Subtyping under DCR is not implemented using unification! + if (FFlag::LuauSolverV2) + { + Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; - u.tryUnify(subPack, superPack); - const bool ok = u.errors.empty() && u.log.empty(); - return ok; + return subtyping.isSubtype(subPack, superPack, scope).isSubtype; + } + else + { + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; + + u.tryUnify(subPack, superPack); + return !u.failure; + } } } // namespace Luau diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp new file mode 100644 index 000000000..fbcce2b7c --- /dev/null +++ b/Analysis/src/OverloadResolution.cpp @@ -0,0 +1,484 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/OverloadResolution.h" + +#include "Luau/Instantiation2.h" +#include "Luau/Subtyping.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" + +namespace Luau +{ + +OverloadResolver::OverloadResolver( + NotNull builtinTypes, + NotNull arena, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull scope, + NotNull reporter, + NotNull limits, + Location callLocation +) + : builtinTypes(builtinTypes) + , arena(arena) + , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) + , scope(scope) + , ice(reporter) + , limits(limits) + , subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice}) + , callLoc(callLocation) +{ +} + +std::pair OverloadResolver::selectOverload(TypeId ty, TypePackId argsPack) +{ + auto tryOne = [&](TypeId f) + { + if (auto ftv = get(f)) + { + Subtyping::Variance variance = subtyping.variance; + subtyping.variance = Subtyping::Variance::Contravariant; + SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes, scope); + subtyping.variance = variance; + + if (r.isSubtype) + return true; + } + + return false; + }; + + TypeId t = follow(ty); + + if (tryOne(ty)) + return {Analysis::Ok, ty}; + + if (auto it = get(t)) + { + for (TypeId component : it) + { + if (tryOne(component)) + return {Analysis::Ok, component}; + } + } + + return {Analysis::OverloadIsNonviable, ty}; +} + +void OverloadResolver::resolve(TypeId fnTy, const TypePack* args, AstExpr* selfExpr, const std::vector* argExprs) +{ + fnTy = follow(fnTy); + + auto it = get(fnTy); + if (!it) + { + auto [analysis, errors] = checkOverload(fnTy, args, selfExpr, argExprs); + add(analysis, fnTy, std::move(errors)); + return; + } + + for (TypeId ty : it) + { + if (resolution.find(ty) != resolution.end()) + continue; + + auto [analysis, errors] = checkOverload(ty, args, selfExpr, argExprs); + add(analysis, ty, std::move(errors)); + } +} + +std::optional OverloadResolver::testIsSubtype(const Location& location, TypeId subTy, TypeId superTy) +{ + auto r = subtyping.isSubtype(subTy, superTy, scope); + ErrorVec errors; + + if (r.normalizationTooComplex) + errors.emplace_back(location, NormalizationTooComplex{}); + + if (!r.isSubtype) + { + switch (shouldSuppressErrors(normalizer, subTy).orElse(shouldSuppressErrors(normalizer, superTy))) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + errors.emplace_back(location, NormalizationTooComplex{}); + // intentionally fallthrough here since we couldn't prove this was error-suppressing + [[fallthrough]]; + case ErrorSuppression::DoNotSuppress: + errors.emplace_back(location, TypeMismatch{superTy, subTy}); + break; + } + } + + if (errors.empty()) + return std::nullopt; + + return errors; +} + +std::optional OverloadResolver::testIsSubtype(const Location& location, TypePackId subTy, TypePackId superTy) +{ + auto r = subtyping.isSubtype(subTy, superTy, scope); + ErrorVec errors; + + if (r.normalizationTooComplex) + errors.emplace_back(location, NormalizationTooComplex{}); + + if (!r.isSubtype) + { + switch (shouldSuppressErrors(normalizer, subTy).orElse(shouldSuppressErrors(normalizer, superTy))) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + errors.emplace_back(location, NormalizationTooComplex{}); + // intentionally fallthrough here since we couldn't prove this was error-suppressing + [[fallthrough]]; + case ErrorSuppression::DoNotSuppress: + errors.emplace_back(location, TypePackMismatch{superTy, subTy}); + break; + } + } + + if (errors.empty()) + return std::nullopt; + + return errors; +} + +std::pair OverloadResolver::checkOverload( + TypeId fnTy, + const TypePack* args, + AstExpr* fnLoc, + const std::vector* argExprs, + bool callMetamethodOk +) +{ + fnTy = follow(fnTy); + + ErrorVec discard; + if (get(fnTy) || get(fnTy) || get(fnTy)) + return {Ok, {}}; + else if (auto fn = get(fnTy)) + return checkOverload_(fnTy, fn, args, fnLoc, argExprs); // Intentionally split to reduce the stack pressure of this function. + else if (auto callMm = findMetatableEntry(builtinTypes, discard, fnTy, "__call", callLoc); callMm && callMetamethodOk) + { + // Calling a metamethod forwards the `fnTy` as self. + TypePack withSelf = *args; + withSelf.head.insert(withSelf.head.begin(), fnTy); + + std::vector withSelfExprs = *argExprs; + withSelfExprs.insert(withSelfExprs.begin(), fnLoc); + + return checkOverload(*callMm, &withSelf, fnLoc, &withSelfExprs, /*callMetamethodOk=*/false); + } + else + return {TypeIsNotAFunction, {}}; // Intentionally empty. We can just fabricate the type error later on. +} + +bool OverloadResolver::isLiteral(AstExpr* expr) +{ + if (auto group = expr->as()) + return isLiteral(group->expr); + else if (auto assertion = expr->as()) + return isLiteral(assertion->expr); + + return expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() || expr->is(); +} + +std::pair OverloadResolver::checkOverload_( + TypeId fnTy, + const FunctionType* fn, + const TypePack* args, + AstExpr* fnExpr, + const std::vector* argExprs +) +{ + FunctionGraphReductionResult result = reduceTypeFunctions( + fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true + ); + if (!result.errors.empty()) + return {OverloadIsNonviable, result.errors}; + + ErrorVec argumentErrors; + TypePackId typ = arena->addTypePack(*args); + + TypeId prospectiveFunction = arena->addType(FunctionType{typ, builtinTypes->anyTypePack}); + SubtypingResult sr = subtyping.isSubtype(fnTy, prospectiveFunction, scope); + + if (sr.isSubtype) + return {Analysis::Ok, {}}; + + if (1 == sr.reasoning.size()) + { + const SubtypingReasoning& reason = *sr.reasoning.begin(); + + const TypePath::Path justArguments{TypePath::PackField::Arguments}; + + if (reason.subPath == justArguments && reason.superPath == justArguments) + { + // If the subtype test failed only due to an arity mismatch, + // it is still possible that this function call is okay. + // Subtype testing does not know anything about optional + // function arguments. + // + // This can only happen if the actual function call has a + // finite set of arguments which is too short for the + // function being called. If all of those unsatisfied + // function arguments are options, then this function call + // is ok. + + const size_t firstUnsatisfiedArgument = args->head.size(); + const auto [requiredHead, requiredTail] = flatten(fn->argTypes); + + bool isVariadic = requiredTail && Luau::isVariadic(*requiredTail); + + // If too many arguments were supplied, this overload + // definitely does not match. + if (args->head.size() > requiredHead.size()) + { + auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); + + TypeError error{fnExpr->location, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, isVariadic}}; + + return {Analysis::ArityMismatch, {error}}; + } + + // If any of the unsatisfied arguments are not supertypes of + // nil, then this overload does not match. + for (size_t i = firstUnsatisfiedArgument; i < requiredHead.size(); ++i) + { + if (!subtyping.isSubtype(builtinTypes->nilType, requiredHead[i], scope).isSubtype) + { + auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); + TypeError error{fnExpr->location, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, isVariadic}}; + + return {Analysis::ArityMismatch, {error}}; + } + } + + return {Analysis::Ok, {}}; + } + } + + ErrorVec errors; + + for (const SubtypingReasoning& reason : sr.reasoning) + { + /* The return type of our prospective function is always + * any... so any subtype failures here can only arise from + * argument type mismatches. + */ + + Location argLocation; + if (reason.superPath.components.size() <= 1) + break; + + if (const Luau::TypePath::Index* pathIndexComponent = get_if(&reason.superPath.components.at(1))) + { + size_t nthArgument = pathIndexComponent->index; + // if the nth type argument to the function is less than the number of ast expressions we passed to the function + // we should be able to pull out the location of the argument + // If the nth type argument to the function is out of range of the ast expressions we passed to the function + // e.g. table.pack(functionThatReturnsMultipleArguments(arg1, arg2, ....)), default to the location of the last passed expression + // If we passed no expression arguments to the call, default to the location of the function expression. + argLocation = nthArgument < argExprs->size() ? argExprs->at(nthArgument)->location + : argExprs->size() != 0 ? argExprs->back()->location + : fnExpr->location; + + std::optional failedSubTy = traverseForType(fnTy, reason.subPath, builtinTypes); + std::optional failedSuperTy = traverseForType(prospectiveFunction, reason.superPath, builtinTypes); + + if (failedSubTy && failedSuperTy) + { + + switch (shouldSuppressErrors(normalizer, *failedSubTy).orElse(shouldSuppressErrors(normalizer, *failedSuperTy))) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + errors.emplace_back(argLocation, NormalizationTooComplex{}); + // intentionally fallthrough here since we couldn't prove this was error-suppressing + [[fallthrough]]; + case ErrorSuppression::DoNotSuppress: + // TODO extract location from the SubtypingResult path and argExprs + switch (reason.variance) + { + case SubtypingVariance::Covariant: + case SubtypingVariance::Contravariant: + errors.emplace_back(argLocation, TypeMismatch{*failedSubTy, *failedSuperTy, TypeMismatch::CovariantContext}); + break; + case SubtypingVariance::Invariant: + errors.emplace_back(argLocation, TypeMismatch{*failedSubTy, *failedSuperTy, TypeMismatch::InvariantContext}); + break; + default: + LUAU_ASSERT(0); + break; + } + } + } + } + + std::optional failedSubPack = traverseForPack(fnTy, reason.subPath, builtinTypes); + std::optional failedSuperPack = traverseForPack(prospectiveFunction, reason.superPath, builtinTypes); + + if (failedSubPack && failedSuperPack) + { + // If a bug in type inference occurs, we may have a mismatch in the return packs. + // This happens when inference incorrectly leaves the result type of a function free. + // If this happens, we don't want to explode, so we'll use the function's location. + if (argExprs->empty()) + argLocation = fnExpr->location; + else + argLocation = argExprs->at(argExprs->size() - 1)->location; + + // TODO extract location from the SubtypingResult path and argExprs + switch (reason.variance) + { + case SubtypingVariance::Covariant: + errors.emplace_back(argLocation, TypePackMismatch{*failedSubPack, *failedSuperPack}); + break; + case SubtypingVariance::Contravariant: + errors.emplace_back(argLocation, TypePackMismatch{*failedSuperPack, *failedSubPack}); + break; + case SubtypingVariance::Invariant: + errors.emplace_back(argLocation, TypePackMismatch{*failedSubPack, *failedSuperPack}); + break; + default: + LUAU_ASSERT(0); + break; + } + } + } + + return {Analysis::OverloadIsNonviable, std::move(errors)}; +} + +size_t OverloadResolver::indexof(Analysis analysis) +{ + switch (analysis) + { + case Ok: + return ok.size(); + case TypeIsNotAFunction: + return nonFunctions.size(); + case ArityMismatch: + return arityMismatches.size(); + case OverloadIsNonviable: + return nonviableOverloads.size(); + } + + ice->ice("Inexhaustive switch in FunctionCallResolver::indexof"); +} + +void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors) +{ + resolution.insert(ty, {analysis, indexof(analysis)}); + + switch (analysis) + { + case Ok: + LUAU_ASSERT(errors.empty()); + ok.push_back(ty); + break; + case TypeIsNotAFunction: + LUAU_ASSERT(errors.empty()); + nonFunctions.push_back(ty); + break; + case ArityMismatch: + LUAU_ASSERT(!errors.empty()); + arityMismatches.emplace_back(ty, std::move(errors)); + break; + case OverloadIsNonviable: + nonviableOverloads.emplace_back(ty, std::move(errors)); + break; + } +} + +// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. +// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. +std::optional selectOverload( + NotNull builtinTypes, + NotNull arena, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull scope, + NotNull iceReporter, + NotNull limits, + const Location& location, + TypeId fn, + TypePackId argsPack +) +{ + OverloadResolver resolver{builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location}; + auto [status, overload] = resolver.selectOverload(fn, argsPack); + + if (status == OverloadResolver::Analysis::Ok) + return overload; + + if (get(fn) || get(fn)) + return fn; + + return {}; +} + +SolveResult solveFunctionCall( + NotNull arena, + NotNull builtinTypes, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull iceReporter, + NotNull limits, + NotNull scope, + const Location& location, + TypeId fn, + TypePackId argsPack +) +{ + std::optional overloadToUse = + selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); + if (!overloadToUse) + return {SolveResult::NoMatchingOverload}; + + TypePackId resultPack = arena->freshTypePack(scope); + + TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, scope.get(), argsPack, resultPack}); + Unifier2 u2{NotNull{arena}, builtinTypes, scope, iceReporter}; + + const bool occursCheckPassed = u2.unify(*overloadToUse, inferredTy); + + if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) + { + Instantiation2 instantiation{arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)}; + + std::optional subst = instantiation.substitute(resultPack); + + if (!subst) + return {SolveResult::CodeTooComplex}; + else + resultPack = *subst; + } + + if (!occursCheckPassed) + return {SolveResult::OccursCheckFailed}; + + SolveResult result; + result.result = SolveResult::Ok; + result.typePackId = resultPack; + + LUAU_ASSERT(overloadToUse); + result.overloadToUse = *overloadToUse; + result.inferredTy = inferredTy; + result.expandedFreeTypes = std::move(u2.expandedFreeTypes); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 845ae3a36..daa61fd5f 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -8,10 +8,6 @@ #include "Luau/Type.h" #include "Luau/VisitType.h" -LUAU_FASTFLAG(DebugLuauSharedSelf) -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { @@ -27,7 +23,6 @@ struct Quantifier final : TypeOnceVisitor explicit Quantifier(TypeLevel level) : level(level) { - LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } /// @return true if outer encloses inner @@ -103,60 +98,20 @@ struct Quantifier final : TypeOnceVisitor void quantify(TypeId ty, TypeLevel level) { - if (FFlag::DebugLuauSharedSelf) - { - ty = follow(ty); - - if (auto ttv = getTableType(ty); ttv && ttv->selfTy) - { - Quantifier selfQ{level}; - selfQ.traverse(*ttv->selfTy); - - Quantifier q{level}; - q.traverse(ty); - - for (const auto& [_, prop] : ttv->props) - { - auto ftv = getMutable(follow(prop.type)); - if (!ftv || !ftv->hasSelf) - continue; - - if (Luau::first(ftv->argTypes) == ttv->selfTy) - { - ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end()); - } - } - } - else if (auto ftv = getMutable(ty)) - { - Quantifier q{level}; - q.traverse(ty); - - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); - - if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; - } - } - else - { - Quantifier q{level}; - q.traverse(ty); + Quantifier q{level}; + q.traverse(ty); - FunctionType* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); - } + FunctionType* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); } struct PureQuantifier : Substitution { Scope* scope; - std::vector insertedGenerics; - std::vector insertedGenericPacks; + OrderedMap insertedGenerics; + OrderedMap insertedGenericPacks; bool seenMutableType = false; bool seenGenericType = false; @@ -204,7 +159,7 @@ struct PureQuantifier : Substitution if (auto ftv = get(ty)) { TypeId result = arena->addType(GenericType{scope}); - insertedGenerics.push_back(result); + insertedGenerics.push(ty, result); return result; } else if (auto ttv = get(ty)) @@ -218,7 +173,10 @@ struct PureQuantifier : Substitution resultTable->scope = scope; if (ttv->state == TableState::Free) + { resultTable->state = TableState::Generic; + insertedGenerics.push(ty, result); + } else if (ttv->state == TableState::Unsealed) resultTable->state = TableState::Sealed; @@ -232,8 +190,8 @@ struct PureQuantifier : Substitution { if (auto ftp = get(tp)) { - TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); - insertedGenericPacks.push_back(result); + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{scope}}); + insertedGenericPacks.push(tp, result); return result; } @@ -242,7 +200,7 @@ struct PureQuantifier : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return ty->persistent; @@ -253,20 +211,30 @@ struct PureQuantifier : Substitution } }; -TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) { PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); - LUAU_ASSERT(result); + if (!result) + return std::nullopt; FunctionType* ftv = getMutable(*result); LUAU_ASSERT(ftv); ftv->scope = scope; - ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); - ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; - return *result; + for (auto k : quantifier.insertedGenerics.keys) + { + TypeId g = quantifier.insertedGenerics.pairings[k]; + if (get(g)) + ftv->generics.push_back(g); + } + + for (auto k : quantifier.insertedGenericPacks.keys) + ftv->genericPacks.push_back(quantifier.insertedGenericPacks.pairings[k]); + + ftv->hasNoFreeOrGenericTypes = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; + + return std::optional({*result, std::move(quantifier.insertedGenerics), std::move(quantifier.insertedGenericPacks)}); } } // namespace Luau diff --git a/Analysis/src/Refinement.cpp b/Analysis/src/Refinement.cpp index a81063c7b..e98b6e5a9 100644 --- a/Analysis/src/Refinement.cpp +++ b/Analysis/src/Refinement.cpp @@ -1,37 +1,60 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Refinement.h" +#include namespace Luau { RefinementId RefinementArena::variadic(const std::vector& refis) { + bool hasRefinements = false; + for (RefinementId r : refis) + hasRefinements |= bool(r); + + if (!hasRefinements) + return nullptr; + return NotNull{allocator.allocate(Variadic{refis})}; } RefinementId RefinementArena::negation(RefinementId refinement) { + if (!refinement) + return nullptr; + return NotNull{allocator.allocate(Negation{refinement})}; } RefinementId RefinementArena::conjunction(RefinementId lhs, RefinementId rhs) { + if (!lhs && !rhs) + return nullptr; + return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; } RefinementId RefinementArena::disjunction(RefinementId lhs, RefinementId rhs) { + if (!lhs && !rhs) + return nullptr; + return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; } RefinementId RefinementArena::equivalence(RefinementId lhs, RefinementId rhs) { + if (!lhs && !rhs) + return nullptr; + return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; } -RefinementId RefinementArena::proposition(BreadcrumbId breadcrumb, TypeId discriminantTy) +RefinementId RefinementArena::proposition(const RefinementKey* key, TypeId discriminantTy) { - return NotNull{allocator.allocate(Proposition{breadcrumb, discriminantTy})}; + if (!key) + return nullptr; + + return NotNull{allocator.allocate(Proposition{key, discriminantTy})}; } } // namespace Luau diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index cac72124e..27894505f 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,6 +2,8 @@ #include "Luau/Scope.h" +LUAU_FASTFLAG(LuauSolverV2); + namespace Luau { @@ -36,6 +38,24 @@ std::optional Scope::lookup(Symbol sym) const return std::nullopt; } +std::optional> Scope::lookupEx(DefId def) +{ + Scope* s = this; + + while (true) + { + if (TypeId* it = s->lvalueTypes.find(def)) + return std::pair{*it, s}; + else if (TypeId* it = s->rvalueRefinements.find(def)) + return std::pair{*it, s}; + + if (s->parent) + s = s->parent.get(); + else + return std::nullopt; + } +} + std::optional> Scope::lookupEx(Symbol sym) { Scope* s = this; @@ -53,19 +73,31 @@ std::optional> Scope::lookupEx(Symbol sym) } } -// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis. +std::optional Scope::lookupUnrefinedType(DefId def) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (auto ty = current->lvalueTypes.find(def)) + return *ty; + } + + return std::nullopt; +} + std::optional Scope::lookup(DefId def) const { for (const Scope* current = this; current; current = current->parent.get()) { - if (auto ty = current->dcrRefinements.find(def)) + if (auto ty = current->rvalueRefinements.find(def)) + return *ty; + if (auto ty = current->lvalueTypes.find(def)) return *ty; } return std::nullopt; } -std::optional Scope::lookupType(const Name& name) +std::optional Scope::lookupType(const Name& name) const { const Scope* scope = this; while (true) @@ -85,7 +117,7 @@ std::optional Scope::lookupType(const Name& name) } } -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) const { const Scope* scope = this; while (scope) @@ -110,7 +142,7 @@ std::optional Scope::lookupImportedType(const Name& moduleAlias, const return std::nullopt; } -std::optional Scope::lookupPack(const Name& name) +std::optional Scope::lookupPack(const Name& name) const { const Scope* scope = this; while (true) @@ -149,6 +181,36 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } +// Updates the `this` scope with the assignments from the `childScope` including ones that doesn't exist in `this`. +void Scope::inheritAssignments(const ScopePtr& childScope) +{ + if (!FFlag::LuauSolverV2) + return; + + for (const auto& [k, a] : childScope->lvalueTypes) + lvalueTypes[k] = a; +} + +// Updates the `this` scope with the refinements from the `childScope` excluding ones that doesn't exist in `this`. +void Scope::inheritRefinements(const ScopePtr& childScope) +{ + if (FFlag::LuauSolverV2) + { + for (const auto& [k, a] : childScope->rvalueRefinements) + { + if (lookup(NotNull{k})) + rvalueRefinements[k] = a; + } + } + + for (const auto& [k, a] : childScope->refinements) + { + Symbol symbol = getBaseSymbol(k); + if (lookup(symbol)) + refinements[k] = a; + } +} + bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp new file mode 100644 index 000000000..3a1e3bd10 --- /dev/null +++ b/Analysis/src/Simplify.cpp @@ -0,0 +1,1451 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Simplify.h" + +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Set.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePairHash.h" +#include "Luau/TypeUtils.h" + +#include + +LUAU_FASTINT(LuauTypeReductionRecursionLimit) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8); +LUAU_FASTFLAGVARIABLE(LuauFlagBasicIntersectFollows, false); + +namespace Luau +{ + +using SimplifierSeenSet = Set, TypePairHash>; + +struct TypeSimplifier +{ + NotNull builtinTypes; + NotNull arena; + + DenseHashSet blockedTypes{nullptr}; + + int recursionDepth = 0; + + TypeId mkNegation(TypeId ty); + + TypeId intersectFromParts(std::set parts); + + TypeId intersectUnionWithType(TypeId unionTy, TypeId right); + TypeId intersectUnions(TypeId left, TypeId right); + TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); + + TypeId intersectTypeWithNegation(TypeId a, TypeId b); + TypeId intersectNegations(TypeId a, TypeId b); + + TypeId intersectIntersectionWithType(TypeId left, TypeId right); + + // Attempt to intersect the two types. Does not recurse. Does not handle + // unions, intersections, or negations. + std::optional basicIntersect(TypeId left, TypeId right); + + TypeId intersect(TypeId ty, TypeId discriminant); + TypeId union_(TypeId ty, TypeId discriminant); + + TypeId simplify(TypeId ty); + TypeId simplify(TypeId ty, DenseHashSet& seen); +}; + +// Match the exact type false|nil +static bool isFalsyType(TypeId ty) +{ + ty = follow(ty); + const UnionType* ut = get(ty); + if (!ut) + return false; + + bool hasFalse = false; + bool hasNil = false; + + auto it = begin(ut); + if (it == end(ut)) + return false; + + TypeId t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it == end(ut)) + return false; + + t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it != end(ut)) + return false; + + return hasFalse && hasNil; +} + +// Match the exact type ~(false|nil) +bool isTruthyType(TypeId ty) +{ + ty = follow(ty); + + const NegationType* nt = get(ty); + if (!nt) + return false; + + return isFalsyType(nt->ty); +} + +Relation flip(Relation rel) +{ + switch (rel) + { + case Relation::Subset: + return Relation::Superset; + case Relation::Superset: + return Relation::Subset; + default: + return rel; + } +} + +// FIXME: I'm not completely certain that this function is theoretically reasonable. +Relation combine(Relation a, Relation b) +{ + switch (a) + { + case Relation::Disjoint: + switch (b) + { + case Relation::Disjoint: + return Relation::Disjoint; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + break; + case Relation::Coincident: + switch (b) + { + case Relation::Disjoint: + return Relation::Coincident; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Superset; + case Relation::Subset: + return Relation::Coincident; + case Relation::Superset: + return Relation::Intersects; + } + break; + case Relation::Superset: + switch (b) + { + case Relation::Disjoint: + return Relation::Superset; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Superset; + } + break; + case Relation::Subset: + switch (b) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Subset; + case Relation::Superset: + return Relation::Intersects; + } + break; + case Relation::Intersects: + switch (b) + { + case Relation::Disjoint: + return Relation::Intersects; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + break; + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +// Given A & B, what is A & ~B? +Relation invert(Relation r) +{ + switch (r) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Disjoint; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Disjoint; + case Relation::Superset: + return Relation::Intersects; + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +static bool isTypeVariable(TypeId ty) +{ + return get(ty) || get(ty) || get(ty) || get(ty); +} + +Relation relate(TypeId left, TypeId right, SimplifierSeenSet& seen); + +Relation relateTables(TypeId left, TypeId right, SimplifierSeenSet& seen) +{ + NotNull leftTable{get(left)}; + NotNull rightTable{get(right)}; + LUAU_ASSERT(1 == rightTable->props.size()); + // Disjoint props have nothing in common + // t1 with props p1's cannot appear in t2 and t2 with props p2's cannot appear in t1 + bool foundPropFromLeftInRight = std::any_of( + begin(leftTable->props), + end(leftTable->props), + [&](auto prop) + { + return rightTable->props.count(prop.first) > 0; + } + ); + bool foundPropFromRightInLeft = std::any_of( + begin(rightTable->props), + end(rightTable->props), + [&](auto prop) + { + return leftTable->props.count(prop.first) > 0; + } + ); + + if (!foundPropFromLeftInRight && !foundPropFromRightInLeft && leftTable->props.size() >= 1 && rightTable->props.size() >= 1) + return Relation::Disjoint; + + const auto [propName, rightProp] = *begin(rightTable->props); + + auto it = leftTable->props.find(propName); + if (it == leftTable->props.end()) + { + // Every table lacking a property is a supertype of a table having that + // property but the reverse is not true. + return Relation::Superset; + } + + const Property leftProp = it->second; + + if (!leftProp.isShared() || !rightProp.isShared()) + return Relation::Intersects; + + Relation r = relate(leftProp.type(), rightProp.type(), seen); + if (r == Relation::Coincident && 1 != leftTable->props.size()) + { + // eg {tag: "cat", prop: string} & {tag: "cat"} + return Relation::Subset; + } + else + return r; +} + +// A cheap and approximate subtype test +Relation relate(TypeId left, TypeId right, SimplifierSeenSet& seen) +{ + // TODO nice to have: Relate functions of equal argument and return arity + + left = follow(left); + right = follow(right); + + if (left == right) + return Relation::Coincident; + + std::pair typePair{left, right}; + if (!seen.insert(typePair)) + { + // TODO: is this right at all? + // The thinking here is that this is a cycle if we get here, and therefore its coincident. + return Relation::Coincident; + } + + if (get(left)) + { + if (get(right)) + return Relation::Subset; + else if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Disjoint; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left, seen)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left, seen)); + + // Type variables + // * FreeType + // * GenericType + // * BlockedType + // * PendingExpansionType + + // Tops and bottoms + // * ErrorType + // * AnyType + // * NeverType + // * UnknownType + + // Concrete + // * PrimitiveType + // * SingletonType + // * FunctionType + // * TableType + // * MetatableType + // * ClassType + // * UnionType + // * IntersectionType + // * NegationType + + if (isTypeVariable(left) || isTypeVariable(right)) + return Relation::Intersects; + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Subset; + else + return Relation::Disjoint; + } + if (get(right)) + return flip(relate(right, left, seen)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Subset; + } + if (get(right)) + return flip(relate(right, left, seen)); + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + return Relation::Intersects; + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + { + std::vector opts; + for (TypeId part : ut) + { + Relation r = relate(left, part, seen); + + if (r == Relation::Subset || r == Relation::Coincident) + return Relation::Subset; + } + return Relation::Intersects; + } + + if (auto rnt = get(right)) + { + Relation a = relate(left, rnt->ty, seen); + switch (a) + { + case Relation::Coincident: + // number & ~number + return Relation::Disjoint; + case Relation::Disjoint: + if (get(left)) + { + // ~number & ~string + return Relation::Intersects; + } + else + { + // number & ~string + return Relation::Subset; + } + case Relation::Intersects: + // ~(false?) & ~boolean + return Relation::Intersects; + case Relation::Subset: + // "hello" & ~string + return Relation::Disjoint; + case Relation::Superset: + // ~function & ~(false?) -> ~function + // boolean & ~(false?) -> true + // string & ~"hello" -> string & ~"hello" + return Relation::Intersects; + } + } + else if (get(left)) + return flip(relate(right, left, seen)); + + if (auto lp = get(left)) + { + if (auto rp = get(right)) + { + if (lp->type == rp->type) + return Relation::Coincident; + else + return Relation::Disjoint; + } + + if (auto rs = get(right)) + { + if (lp->type == PrimitiveType::String && rs->variant.get_if()) + return Relation::Superset; + else if (lp->type == PrimitiveType::Boolean && rs->variant.get_if()) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (lp->type == PrimitiveType::Function) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + if (lp->type == PrimitiveType::Table) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + } + + if (auto ls = get(left)) + { + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + + if (get(right)) + return flip(relate(right, left, seen)); + if (auto rs = get(right)) + { + if (ls->variant == rs->variant) + return Relation::Coincident; + else + return Relation::Disjoint; + } + } + + if (get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Function) + return Relation::Subset; + else + return Relation::Disjoint; + } + else + return Relation::Intersects; + } + + if (auto lt = get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Table) + return Relation::Subset; + else + return Relation::Disjoint; + } + else if (auto rt = get(right)) + { + // TODO PROBABLY indexers and metatables. + if (1 == rt->props.size()) + { + Relation r = relateTables(left, right, seen); + /* + * A reduction of these intersections is certainly possible, but + * it would require minting new table types. Also, I don't think + * it's super likely for this to arise from a refinement. + * + * Time will tell! + * + * ex we simplify this + * {tag: string} & {tag: "cat"} + * but not this + * {tag: string, prop: number} & {tag: "cat"} + */ + if (lt->props.size() > 1 && r == Relation::Superset) + return Relation::Intersects; + else + return r; + } + else if (1 == lt->props.size()) + return flip(relate(right, left, seen)); + else + return Relation::Intersects; + } + // TODO metatables + + return Relation::Disjoint; + } + + if (auto ct = get(left)) + { + if (auto rct = get(right)) + { + if (isSubclass(ct, rct)) + return Relation::Subset; + else if (isSubclass(rct, ct)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + return Relation::Disjoint; + } + + return Relation::Intersects; +} + +// A cheap and approximate subtype test +Relation relate(TypeId left, TypeId right) +{ + SimplifierSeenSet seen{{}}; + return relate(left, right, seen); +} + +TypeId TypeSimplifier::mkNegation(TypeId ty) +{ + TypeId result = nullptr; + + if (ty == builtinTypes->truthyType) + result = builtinTypes->falsyType; + else if (ty == builtinTypes->falsyType) + result = builtinTypes->truthyType; + else if (auto ntv = get(ty)) + result = follow(ntv->ty); + else + result = arena->addType(NegationType{ty}); + + return result; +} + +TypeId TypeSimplifier::intersectFromParts(std::set parts) +{ + if (0 == parts.size()) + return builtinTypes->neverType; + else if (1 == parts.size()) + return *begin(parts); + + { + auto it = begin(parts); + while (it != end(parts)) + { + TypeId t = follow(*it); + + auto copy = it; + ++it; + + if (auto ut = get(t)) + { + for (TypeId part : ut) + parts.insert(part); + parts.erase(copy); + } + } + } + + std::set newParts; + + /* + * It is possible that the parts of the passed intersection are themselves + * reducable. + * + * eg false & boolean + * + * We do a comparison between each pair of types and look for things that we + * can elide. + */ + for (TypeId part : parts) + { + if (newParts.empty()) + { + newParts.insert(part); + continue; + } + + auto it = begin(newParts); + while (it != end(newParts)) + { + TypeId p = *it; + + switch (relate(part, p)) + { + case Relation::Disjoint: + // eg boolean & string + return builtinTypes->neverType; + case Relation::Subset: + { + /* part is a subset of p. Remove p from the set and replace it + * with part. + * + * eg boolean & true + */ + auto saveIt = it; + ++it; + newParts.erase(saveIt); + continue; + } + case Relation::Coincident: + case Relation::Superset: + { + /* part is coincident or a superset of p. We do not need to + * include part in the final intersection. + * + * ex true & boolean + */ + ++it; + continue; + } + case Relation::Intersects: + { + /* It's complicated! A simplification may still be possible, + * but we have to pull the types apart to figure it out. + * + * ex boolean & ~false + */ + std::optional simplified = basicIntersect(part, p); + + auto saveIt = it; + ++it; + + if (simplified) + { + newParts.erase(saveIt); + newParts.insert(*simplified); + } + else + newParts.insert(part); + continue; + } + } + } + } + + if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(IntersectionType{std::vector{begin(newParts), end(newParts)}}); +} + +TypeId TypeSimplifier::intersectUnionWithType(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + bool changed = false; + std::set newParts; + + if (leftUnion->options.size() > (size_t)DFInt::LuauSimplificationComplexityLimit) + return arena->addType(IntersectionType{{left, right}}); + + for (TypeId part : leftUnion) + { + TypeId simplified = intersect(right, part); + changed |= simplified != part; + + if (get(simplified)) + { + changed = true; + continue; + } + + newParts.insert(simplified); + } + + if (!changed) + return left; + else if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectUnions(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + const UnionType* rightUnion = get(right); + LUAU_ASSERT(rightUnion); + + std::set newParts; + + // Combinatorial blowup moment!! + + // combination size + size_t optionSize = (int)leftUnion->options.size() * rightUnion->options.size(); + size_t maxSize = DFInt::LuauSimplificationComplexityLimit; + + if (optionSize > maxSize) + return arena->addType(IntersectionType{{left, right}}); + + for (TypeId leftPart : leftUnion) + { + for (TypeId rightPart : rightUnion) + { + TypeId simplified = intersect(leftPart, rightPart); + if (get(simplified)) + continue; + + newParts.insert(simplified); + } + } + + if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectNegatedUnion(TypeId left, TypeId right) +{ + // ~(A | B) & C + // (~A & C) & (~B & C) + + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + const UnionType* negatedUnion = get(negatedTy); + LUAU_ASSERT(negatedUnion); + + bool changed = false; + std::set newParts; + + for (TypeId part : negatedUnion) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + // If A is disjoint from B, then ~A & B is just B. + // + // ~(false?) & true + // (~false & true) & (~nil & true) + // true & true + newParts.insert(right); + break; + case Relation::Coincident: + // If A is coincident with or a superset of B, then ~A & B is never. + // + // ~(false?) & false + // (~false & false) & (~nil & false) + // never & false + // + // fallthrough + case Relation::Superset: + // If A is a superset of B, then ~A & B is never. + // + // ~(boolean | nil) & true + // (~boolean & true) & (~boolean & nil) + // never & nil + return builtinTypes->neverType; + case Relation::Subset: + case Relation::Intersects: + // If A is a subset of B, then ~A & B is a bit more complicated. We need to think harder. + // + // ~(false?) & boolean + // (~false & boolean) & (~nil & boolean) + // true & boolean + TypeId simplified = intersectTypeWithNegation(mkNegation(part), right); + changed |= simplified != right; + if (get(simplified)) + changed = true; + else + newParts.insert(simplified); + break; + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); +} + +TypeId TypeSimplifier::intersectTypeWithNegation(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + if (negatedTy == right) + return builtinTypes->neverType; + + if (auto ut = get(negatedTy)) + { + // ~(A | B) & C + // (~A & C) & (~B & C) + + bool changed = false; + std::set newParts; + + for (TypeId part : ut) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + // ~(false?) & nil + // (~false & nil) & (~nil & nil) + // nil & never + // + // fallthrough + case Relation::Superset: + // ~(boolean | string) & true + // (~boolean & true) & (~boolean & string) + // never & string + + return builtinTypes->neverType; + + case Relation::Disjoint: + // ~nil & boolean + newParts.insert(right); + break; + + case Relation::Subset: + // ~false & boolean + // fallthrough + case Relation::Intersects: + // FIXME: The mkNegation here is pretty unfortunate. + // Memoizing this will probably be important. + changed = true; + newParts.insert(right); + newParts.insert(mkNegation(part)); + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); + } + + if (auto rightUnion = get(right)) + { + // ~A & (B | C) + bool changed = false; + std::set newParts; + + for (TypeId part : rightUnion) + { + Relation r = relate(negatedTy, part); + switch (r) + { + case Relation::Coincident: + changed = true; + continue; + case Relation::Disjoint: + newParts.insert(part); + break; + case Relation::Superset: + changed = true; + continue; + case Relation::Subset: + // fallthrough + case Relation::Intersects: + changed = true; + newParts.insert(arena->addType(IntersectionType{{left, part}})); + } + } + + if (!changed) + return right; + else if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + + if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(negatedTy)) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else if (st->variant == BooleanSingleton{false}) + return builtinTypes->trueType; + else + // boolean & ~"hello" + return builtinTypes->booleanType; + } + } + + Relation r = relate(negatedTy, right); + + switch (r) + { + case Relation::Disjoint: + // ~boolean & string + return right; + case Relation::Coincident: + // ~string & string + // fallthrough + case Relation::Superset: + // ~string & "hello" + return builtinTypes->neverType; + case Relation::Subset: + // ~string & unknown + // ~"hello" & string + // fallthrough + case Relation::Intersects: + // ~("hello" | boolean) & string + // fallthrough + default: + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectNegations(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + if (get(follow(leftNegation->ty))) + return intersectNegatedUnion(left, right); + + const NegationType* rightNegation = get(right); + LUAU_ASSERT(rightNegation); + + if (get(follow(rightNegation->ty))) + return intersectNegatedUnion(right, left); + + Relation r = relate(leftNegation->ty, rightNegation->ty); + + switch (r) + { + case Relation::Coincident: + // ~true & ~true + return left; + case Relation::Subset: + // ~true & ~boolean + return right; + case Relation::Superset: + // ~boolean & ~true + return left; + case Relation::Intersects: + case Relation::Disjoint: + default: + // ~boolean & ~string + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right) +{ + const IntersectionType* leftIntersection = get(left); + LUAU_ASSERT(leftIntersection); + + if (leftIntersection->parts.size() > (size_t)DFInt::LuauSimplificationComplexityLimit) + return arena->addType(IntersectionType{{left, right}}); + + bool changed = false; + std::set newParts; + + for (TypeId part : leftIntersection) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Coincident: + newParts.insert(part); + continue; + case Relation::Subset: + newParts.insert(part); + continue; + case Relation::Superset: + newParts.insert(right); + changed = true; + continue; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + continue; + } + } + + // It is sometimes the case that an intersection operation will result in + // clipping a free type from the result. + // + // eg (number & 'a) & string --> never + // + // We want to only report the free types that are part of the result. + for (TypeId part : newParts) + { + if (isTypeVariable(part)) + blockedTypes.insert(part); + } + + if (!changed) + return left; + return intersectFromParts(std::move(newParts)); +} + +std::optional TypeSimplifier::basicIntersect(TypeId left, TypeId right) +{ + if (FFlag::LuauFlagBasicIntersectFollows) + { + left = follow(left); + right = follow(right); + } + + if (get(left) && get(right)) + return right; + if (get(right) && get(left)) + return left; + if (get(left)) + return arena->addType(UnionType{{right, builtinTypes->errorType}}); + if (get(right)) + return arena->addType(UnionType{{left, builtinTypes->errorType}}); + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (auto pt = get(left); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(right); st && st->variant.get_if()) + return right; + if (auto nt = get(right)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + else if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(left); st && st->variant.get_if()) + return left; + if (auto nt = get(left)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + + if (const auto [lt, rt] = get2(left, right); lt && rt) + { + if (1 == lt->props.size()) + { + const auto [propName, leftProp] = *begin(lt->props); + + auto it = rt->props.find(propName); + if (it != rt->props.end() && leftProp.isShared() && it->second.isShared()) + { + Relation r = relate(leftProp.type(), it->second.type()); + + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Superset: + case Relation::Coincident: + return right; + case Relation::Subset: + if (1 == rt->props.size()) + return left; + break; + default: + break; + } + } + } + else if (1 == rt->props.size()) + return basicIntersect(right, left); + + // If two tables have disjoint properties and indexers, we can combine them. + if (!lt->indexer && !rt->indexer && lt->state == TableState::Sealed && rt->state == TableState::Sealed) + { + if (rt->props.empty()) + return left; + + bool areDisjoint = true; + for (const auto& [name, leftProp] : lt->props) + { + if (rt->props.count(name)) + { + areDisjoint = false; + break; + } + } + + if (areDisjoint) + { + TableType::Props mergedProps = lt->props; + for (const auto& [name, rightProp] : rt->props) + mergedProps[name] = rightProp; + + return arena->addType(TableType{mergedProps, std::nullopt, TypeLevel{}, lt->scope, TableState::Sealed}); + } + } + + return std::nullopt; + } + + Relation relation = relate(left, right); + if (left == right || Relation::Coincident == relation) + return left; + + if (relation == Relation::Disjoint) + return builtinTypes->neverType; + else if (relation == Relation::Subset) + return left; + else if (relation == Relation::Superset) + return right; + + return std::nullopt; +} + +TypeId TypeSimplifier::intersect(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (left == right) + return left; + + if (get(left) && get(right)) + return right; + if (get(right) && get(left)) + return left; + if (get(left) && !get(right)) + return right; + if (get(right) && !get(left)) + return left; + if (get(left)) + return arena->addType(UnionType{{right, builtinTypes->errorType}}); + if (get(right)) + return arena->addType(UnionType{{left, builtinTypes->errorType}}); + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (auto lf = get(left)) + { + Relation r = relate(lf->upperBound, right); + if (r == Relation::Subset || r == Relation::Coincident) + return left; + } + else if (auto rf = get(right)) + { + Relation r = relate(left, rf->upperBound); + if (r == Relation::Superset || r == Relation::Coincident) + return right; + } + + if (isTypeVariable(left)) + { + blockedTypes.insert(left); + return arena->addType(IntersectionType{{left, right}}); + } + + if (isTypeVariable(right)) + { + blockedTypes.insert(right); + return arena->addType(IntersectionType{{left, right}}); + } + + if (auto ut = get(left)) + { + if (get(right)) + return intersectUnions(left, right); + else + return intersectUnionWithType(left, right); + } + else if (auto ut = get(right)) + return intersectUnionWithType(right, left); + + if (auto it = get(left)) + return intersectIntersectionWithType(left, right); + else if (auto it = get(right)) + return intersectIntersectionWithType(right, left); + + if (get(left)) + { + if (get(right)) + return intersectNegations(left, right); + else + return intersectTypeWithNegation(left, right); + } + else if (get(right)) + return intersectTypeWithNegation(right, left); + + std::optional res = basicIntersect(left, right); + if (res) + return *res; + else + return arena->addType(IntersectionType{{left, right}}); +} + +TypeId TypeSimplifier::union_(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (get(left)) + return right; + if (get(right)) + return left; + + if (auto leftUnion = get(left)) + { + bool changed = false; + std::set newParts; + for (TypeId part : leftUnion) + { + if (get(part)) + { + changed = true; + continue; + } + + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + case Relation::Superset: + return left; + case Relation::Subset: + newParts.insert(right); + changed = true; + break; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + break; + } + } + + if (!changed) + return left; + if (0 == newParts.size()) + { + // If the left-side is changed but has no parts, then the left-side union is uninhabited. + return right; + } + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + else if (get(right)) + return union_(right, left); + + Relation r = relate(left, right); + if (left == right || r == Relation::Coincident || r == Relation::Superset) + return left; + + if (r == Relation::Subset) + return right; + + if (auto as = get(left)) + { + if (auto abs = as->variant.get_if()) + { + if (auto bs = get(right)) + { + if (auto bbs = bs->variant.get_if()) + { + if (abs->value != bbs->value) + return builtinTypes->booleanType; + } + } + } + } + + return arena->addType(UnionType{{left, right}}); +} + +TypeId TypeSimplifier::simplify(TypeId ty) +{ + DenseHashSet seen{nullptr}; + return simplify(ty, seen); +} + +TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet& seen) +{ + RecursionLimiter limiter(&recursionDepth, 60); + + ty = follow(ty); + + if (seen.find(ty)) + return ty; + seen.insert(ty); + + if (auto nt = get(ty)) + { + TypeId negatedTy = follow(nt->ty); + if (get(negatedTy)) + return arena->addType(UnionType{{builtinTypes->neverType, builtinTypes->errorType}}); + else if (get(negatedTy)) + return builtinTypes->neverType; + else if (get(negatedTy)) + return builtinTypes->unknownType; + if (auto nnt = get(negatedTy)) + return simplify(nnt->ty, seen); + } + + // Promote {x: never} to never + if (auto tt = get(ty)) + { + if (1 == tt->props.size()) + { + if (std::optional readTy = begin(tt->props)->second.readTy) + { + TypeId propTy = simplify(*readTy, seen); + if (get(propTy)) + return builtinTypes->neverType; + } + } + } + + return ty; +} + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + + TypeSimplifier s{builtinTypes, arena}; + + // fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str()); + + TypeId res = s.intersect(left, right); + + // fprintf(stderr, "Intersect %s and %s -> %s\n", toString(left).c_str(), toString(right).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.intersectFromParts(std::move(parts)); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.union_(left, right); + + // fprintf(stderr, "Union %s and %s -> %s\n", toString(left).c_str(), toString(right).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +} // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 160647a05..dd5a2f85f 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,20 +8,165 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) -LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256); namespace Luau { +static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) +{ + auto go = [ty, &dest, alwaysClone](auto&& a) + { + using T = std::decay_t; + + // The pointer identities of free and local types is very important. + // We decline to copy them. + if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + { + // This should never happen, but visit() cannot see it. + LUAU_ASSERT(!"shallowClone didn't follow its argument!"); + return dest.addType(BoundType{a.boundTo}); + } + else if constexpr (std::is_same_v) + return dest.addType(a); + else if constexpr (std::is_same_v) + return dest.addType(a); + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); + return ty; + } + else if constexpr (std::is_same_v) + { + PendingExpansionType clone = PendingExpansionType{a.prefix, a.name, a.typeArguments, a.packArguments}; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); + return ty; + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); + return ty; + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); + return ty; + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); + return ty; + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); + return ty; + } + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return dest.addType(a); + else if constexpr (std::is_same_v) + { + FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; + clone.generics = a.generics; + clone.genericPacks = a.genericPacks; + clone.magicFunction = a.magicFunction; + clone.dcrMagicFunction = a.dcrMagicFunction; + clone.dcrMagicRefinement = a.dcrMagicRefinement; + clone.tags = a.tags; + clone.argNames = a.argNames; + clone.isCheckedFunction = a.isCheckedFunction; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(!a.boundTo); + TableType clone = TableType{a.props, a.indexer, a.level, a.scope, a.state}; + clone.definitionModuleName = a.definitionModuleName; + clone.definitionLocation = a.definitionLocation; + clone.name = a.name; + clone.syntheticName = a.syntheticName; + clone.instantiatedTypeParams = a.instantiatedTypeParams; + clone.instantiatedTypePackParams = a.instantiatedTypePackParams; + clone.tags = a.tags; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + MetatableType clone = MetatableType{a.table, a.metatable}; + clone.syntheticName = a.syntheticName; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + UnionType clone; + clone.options = a.options; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + IntersectionType clone; + clone.parts = a.parts; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + if (alwaysClone) + { + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.definitionLocation, a.indexer}; + return dest.addType(std::move(clone)); + } + else + return ty; + } + else if constexpr (std::is_same_v) + return dest.addType(NegationType{a.ty}); + else if constexpr (std::is_same_v) + { + TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName}; + return dest.addType(std::move(clone)); + } + else + static_assert(always_false_v, "Non-exhaustive shallowClone switch"); + }; + + ty = log->follow(ty); + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + TypeId resTy = visit(go, ty->ty); + if (resTy != ty) + asMutable(resTy)->documentationSymbol = ty->documentationSymbol; + + return resTy; +} + +Tarjan::Tarjan() + : typeToIndex(nullptr, FInt::LuauTarjanPreallocationSize) + , packToIndex(nullptr, FInt::LuauTarjanPreallocationSize) +{ + nodes.reserve(FInt::LuauTarjanPreallocationSize); + stack.reserve(FInt::LuauTarjanPreallocationSize); + edgesTy.reserve(FInt::LuauTarjanPreallocationSize); + edgesTp.reserve(FInt::LuauTarjanPreallocationSize); + worklist.reserve(FInt::LuauTarjanPreallocationSize); +} + void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(ty == log->follow(ty)); - if (ignoreChildren(ty)) + if (ignoreChildrenVisit(ty)) return; if (auto pty = log->pending(ty)) @@ -29,13 +174,10 @@ void Tarjan::visitChildren(TypeId ty, int index) if (const FunctionType* ftv = get(ty)) { - if (FFlag::LuauSubstitutionFixMissingFields) - { - for (TypeId generic : ftv->generics) - visitChild(generic); - for (TypePackId genericPack : ftv->genericPacks) - visitChild(genericPack); - } + for (TypeId generic : ftv->generics) + visitChild(generic); + for (TypePackId genericPack : ftv->genericPacks) + visitChild(genericPack); visitChild(ftv->argTypes); visitChild(ftv->retTypes); @@ -44,7 +186,16 @@ void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) - visitChild(prop.type); + { + if (FFlag::LuauSolverV2) + { + visitChild(prop.readTy); + visitChild(prop.writeTy); + } + else + visitChild(prop.type()); + } + if (ttv->indexer) { visitChild(ttv->indexer->indexType); @@ -80,16 +231,30 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId a : petv->packArguments) visitChild(a); } - else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (const TypeFunctionInstanceType* tfit = get(ty)) { - for (auto [name, prop] : ctv->props) - visitChild(prop.type); + for (TypeId a : tfit->typeArguments) + visitChild(a); + + for (TypePackId a : tfit->packArguments) + visitChild(a); + } + else if (const ClassType* ctv = get(ty)) + { + for (const auto& [name, prop] : ctv->props) + visitChild(prop.type()); if (ctv->parent) visitChild(*ctv->parent); if (ctv->metatable) visitChild(*ctv->metatable); + + if (ctv->indexer) + { + visitChild(ctv->indexer->indexType); + visitChild(ctv->indexer->indexResultType); + } } else if (const NegationType* ntv = get(ty)) { @@ -101,7 +266,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) { LUAU_ASSERT(tp == log->follow(tp)); - if (ignoreChildren(tp)) + if (ignoreChildrenVisit(tp)) return; if (auto ptp = log->pending(tp)) @@ -124,17 +289,14 @@ std::pair Tarjan::indexify(TypeId ty) { ty = log->follow(ty); - bool fresh = !typeToIndex.contains(ty); - int& index = typeToIndex[ty]; + auto [index, fresh] = typeToIndex.try_insert(ty, false); if (fresh) { - index = int(indexToType.size()); - indexToType.push_back(ty); - indexToPack.push_back(nullptr); - onStack.push_back(false); - lowlink.push_back(index); + index = int(nodes.size()); + nodes.push_back({ty, nullptr, false, false, index}); } + return {index, fresh}; } @@ -142,17 +304,14 @@ std::pair Tarjan::indexify(TypePackId tp) { tp = log->follow(tp); - bool fresh = !packToIndex.contains(tp); - int& index = packToIndex[tp]; + auto [index, fresh] = packToIndex.try_insert(tp, false); if (fresh) { - index = int(indexToPack.size()); - indexToType.push_back(nullptr); - indexToPack.push_back(tp); - onStack.push_back(false); - lowlink.push_back(index); + index = int(nodes.size()); + nodes.push_back({nullptr, tp, false, false, index}); } + return {index, fresh}; } @@ -187,14 +346,15 @@ TarjanResult Tarjan::loop() return TarjanResult::TooManyChildren; stack.push_back(index); - onStack[index] = true; + + nodes[index].onStack = true; currEdge = int(edgesTy.size()); // Fill in edge list of this vertex - if (TypeId ty = indexToType[index]) + if (TypeId ty = nodes[index].ty) visitChildren(ty, index); - else if (TypePackId tp = indexToPack[index]) + else if (TypePackId tp = nodes[index].tp) visitChildren(tp, index); lastEdge = int(edgesTy.size()); @@ -225,9 +385,9 @@ TarjanResult Tarjan::loop() foundFresh = true; break; } - else if (onStack[childIndex]) + else if (nodes[childIndex].onStack) { - lowlink[index] = std::min(lowlink[index], childIndex); + nodes[index].lowlink = std::min(nodes[index].lowlink, childIndex); } visitEdge(childIndex, index); @@ -236,14 +396,14 @@ TarjanResult Tarjan::loop() if (foundFresh) continue; - if (lowlink[index] == index) + if (nodes[index].lowlink == index) { visitSCC(index); while (!stack.empty()) { int popped = stack.back(); stack.pop_back(); - onStack[popped] = false; + nodes[popped].onStack = false; if (popped == index) break; } @@ -260,7 +420,7 @@ TarjanResult Tarjan::loop() edgesTy.resize(parentEndEdge); edgesTp.resize(parentEndEdge); - lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]); + nodes[parentIndex].lowlink = std::min(nodes[parentIndex].lowlink, nodes[index].lowlink); visitEdge(index, parentIndex); } } @@ -294,54 +454,56 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) return loop(); } -void FindDirty::clearTarjan() +void Tarjan::clearTarjan(const TxnLog* log) { - dirty.clear(); + typeToIndex.clear(~0u); + packToIndex.clear(~0u); - typeToIndex.clear(); - packToIndex.clear(); - indexToType.clear(); - indexToPack.clear(); + nodes.clear(); stack.clear(); - onStack.clear(); - lowlink.clear(); + + childCount = 0; + // childLimit setting stays the same + + this->log = log; edgesTy.clear(); edgesTp.clear(); worklist.clear(); } -bool FindDirty::getDirty(int index) +bool Tarjan::getDirty(int index) { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - return dirty[index]; + LUAU_ASSERT(size_t(index) < nodes.size()); + return nodes[index].dirty; } -void FindDirty::setDirty(int index, bool d) +void Tarjan::setDirty(int index, bool d) { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - dirty[index] = d; + LUAU_ASSERT(size_t(index) < nodes.size()); + nodes[index].dirty = d; } -void FindDirty::visitEdge(int index, int parentIndex) +void Tarjan::visitEdge(int index, int parentIndex) { if (getDirty(index)) setDirty(parentIndex, true); } -void FindDirty::visitSCC(int index) +void Tarjan::visitSCC(int index) { bool d = getDirty(index); for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) { - if (TypeId ty = indexToType[*it]) + TarjanNode& node = nodes[*it]; + + if (TypeId ty = node.ty) d = isDirty(ty); - else if (TypePackId tp = indexToPack[*it]) + else if (TypePackId tp = node.tp) d = isDirty(tp); + if (*it == index) break; } @@ -352,32 +514,52 @@ void FindDirty::visitSCC(int index) for (auto it = stack.rbegin(); it != stack.rend(); it++) { setDirty(*it, true); - if (TypeId ty = indexToType[*it]) + + TarjanNode& node = nodes[*it]; + + if (TypeId ty = node.ty) foundDirty(ty); - else if (TypePackId tp = indexToPack[*it]) + else if (TypePackId tp = node.tp) foundDirty(tp); + if (*it == index) return; } } -TarjanResult FindDirty::findDirty(TypeId ty) +TarjanResult Tarjan::findDirty(TypeId ty) { return visitRoot(ty); } -TarjanResult FindDirty::findDirty(TypePackId tp) +TarjanResult Tarjan::findDirty(TypePackId tp) { return visitRoot(tp); } +Substitution::Substitution(const TxnLog* log_, TypeArena* arena) + : arena(arena) +{ + log = log_; + LUAU_ASSERT(log); +} + +void Substitution::dontTraverseInto(TypeId ty) +{ + noTraverseTypes.insert(ty); +} + +void Substitution::dontTraverseInto(TypePackId tp) +{ + noTraverseTypePacks.insert(tp); +} + std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); // clear algorithm state for reentrancy - if (FFlag::LuauSubstitutionReentrant) - clearTarjan(); + clearTarjan(log); auto result = findDirty(ty); if (result != TarjanResult::Ok) @@ -385,34 +567,20 @@ std::optional Substitution::substitute(TypeId ty) for (auto [oldTy, newTy] : newTypes) { - if (FFlag::LuauSubstitutionReentrant) + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) - { - replaceChildren(newTy); - replacedTypes.insert(newTy); - } - } - else - { - if (!ignoreChildren(oldTy)) + if (!noTraverseTypes.contains(newTy)) replaceChildren(newTy); + replacedTypes.insert(newTy); } } for (auto [oldTp, newTp] : newPacks) { - if (FFlag::LuauSubstitutionReentrant) + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) - { - replaceChildren(newTp); - replacedTypePacks.insert(newTp); - } - } - else - { - if (!ignoreChildren(oldTp)) + if (!noTraverseTypePacks.contains(newTp)) replaceChildren(newTp); + replacedTypePacks.insert(newTp); } } TypeId newTy = replace(ty); @@ -424,8 +592,7 @@ std::optional Substitution::substitute(TypePackId tp) tp = log->follow(tp); // clear algorithm state for reentrancy - if (FFlag::LuauSubstitutionReentrant) - clearTarjan(); + clearTarjan(log); auto result = findDirty(tp); if (result != TarjanResult::Ok) @@ -433,43 +600,44 @@ std::optional Substitution::substitute(TypePackId tp) for (auto [oldTy, newTy] : newTypes) { - if (FFlag::LuauSubstitutionReentrant) - { - if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) - { - replaceChildren(newTy); - replacedTypes.insert(newTy); - } - } - else + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - if (!ignoreChildren(oldTy)) + if (!noTraverseTypes.contains(newTy)) replaceChildren(newTy); + replacedTypes.insert(newTy); } } for (auto [oldTp, newTp] : newPacks) { - if (FFlag::LuauSubstitutionReentrant) + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) - { - replaceChildren(newTp); - replacedTypePacks.insert(newTp); - } - } - else - { - if (!ignoreChildren(oldTp)) + if (!noTraverseTypePacks.contains(newTp)) replaceChildren(newTp); + replacedTypePacks.insert(newTp); } } TypePackId newTp = replace(tp); return newTp; } +void Substitution::resetState(const TxnLog* log, TypeArena* arena) +{ + clearTarjan(log); + + this->arena = arena; + + newTypes.clear(); + newPacks.clear(); + replacedTypes.clear(); + replacedTypePacks.clear(); + + noTraverseTypes.clear(); + noTraverseTypePacks.clear(); +} + TypeId Substitution::clone(TypeId ty) { - return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess); + return shallowClone(ty, *arena, log, /* alwaysClone */ true); } TypePackId Substitution::clone(TypePackId tp) @@ -490,23 +658,27 @@ TypePackId Substitution::clone(TypePackId tp) { VariadicTypePack clone; clone.ty = vtp->ty; - if (FFlag::LuauSubstitutionFixMissingFields) - clone.hidden = vtp->hidden; + clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } - else if (FFlag::LuauClonePublicInterfaceLess) + else if (const TypeFunctionInstanceTypePack* tfitp = get(tp)) { - return addTypePack(*tp); + TypeFunctionInstanceTypePack clone{ + tfitp->function, std::vector(tfitp->typeArguments.size()), std::vector(tfitp->packArguments.size()) + }; + clone.typeArguments.assign(tfitp->typeArguments.begin(), tfitp->typeArguments.end()); + clone.packArguments.assign(tfitp->packArguments.begin(), tfitp->packArguments.end()); + return addTypePack(std::move(clone)); } else - return tp; + return addTypePack(*tp); } void Substitution::foundDirty(TypeId ty) { ty = log->follow(ty); - if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty)) + if (newTypes.contains(ty)) return; if (isDirty(ty)) @@ -519,7 +691,7 @@ void Substitution::foundDirty(TypePackId tp) { tp = log->follow(tp); - if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp)) + if (newPacks.contains(tp)) return; if (isDirty(tp)) @@ -560,13 +732,10 @@ void Substitution::replaceChildren(TypeId ty) if (FunctionType* ftv = getMutable(ty)) { - if (FFlag::LuauSubstitutionFixMissingFields) - { - for (TypeId& generic : ftv->generics) - generic = replace(generic); - for (TypePackId& genericPack : ftv->genericPacks) - genericPack = replace(genericPack); - } + for (TypeId& generic : ftv->generics) + generic = replace(generic); + for (TypePackId& genericPack : ftv->genericPacks) + genericPack = replace(genericPack); ftv->argTypes = replace(ftv->argTypes); ftv->retTypes = replace(ftv->retTypes); @@ -575,7 +744,18 @@ void Substitution::replaceChildren(TypeId ty) { LUAU_ASSERT(!ttv->boundTo); for (auto& [name, prop] : ttv->props) - prop.type = replace(prop.type); + { + if (FFlag::LuauSolverV2) + { + if (prop.readTy) + prop.readTy = replace(prop.readTy); + if (prop.writeTy) + prop.writeTy = replace(prop.writeTy); + } + else + prop.setType(replace(prop.type())); + } + if (ttv->indexer) { ttv->indexer->indexType = replace(ttv->indexer->indexType); @@ -611,16 +791,30 @@ void Substitution::replaceChildren(TypeId ty) for (TypePackId& a : petv->packArguments) a = replace(a); } - else if (ClassType* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (TypeFunctionInstanceType* tfit = getMutable(ty)) + { + for (TypeId& a : tfit->typeArguments) + a = replace(a); + + for (TypePackId& a : tfit->packArguments) + a = replace(a); + } + else if (ClassType* ctv = getMutable(ty)) { for (auto& [name, prop] : ctv->props) - prop.type = replace(prop.type); + prop.setType(replace(prop.type())); if (ctv->parent) ctv->parent = replace(*ctv->parent); if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); + + if (ctv->indexer) + { + ctv->indexer->indexType = replace(ctv->indexer->indexType); + ctv->indexer->indexResultType = replace(ctv->indexer->indexResultType); + } } else if (NegationType* ntv = getMutable(ty)) { @@ -649,6 +843,14 @@ void Substitution::replaceChildren(TypePackId tp) { vtp->ty = replace(vtp->ty); } + else if (TypeFunctionInstanceTypePack* tfitp = getMutable(tp)) + { + for (TypeId& t : tfitp->typeArguments) + t = replace(t); + + for (TypePackId& t : tfitp->packArguments) + t = replace(t); + } } } // namespace Luau diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp new file mode 100644 index 000000000..6c84c9333 --- /dev/null +++ b/Analysis/src/Subtyping.cpp @@ -0,0 +1,1872 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Subtyping.h" + +#include "Luau/Common.h" +#include "Luau/Error.h" +#include "Luau/Normalize.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" +#include "Luau/StringUtils.h" +#include "Luau/Substitution.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypePack.h" +#include "Luau/TypePath.h" +#include "Luau/TypeUtils.h" + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false); +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + +namespace Luau +{ + +struct VarianceFlipper +{ + Subtyping::Variance* variance; + Subtyping::Variance oldValue; + + VarianceFlipper(Subtyping::Variance* v) + : variance(v) + , oldValue(*v) + { + switch (oldValue) + { + case Subtyping::Variance::Covariant: + *variance = Subtyping::Variance::Contravariant; + break; + case Subtyping::Variance::Contravariant: + *variance = Subtyping::Variance::Covariant; + break; + } + } + + ~VarianceFlipper() + { + *variance = oldValue; + } +}; + +bool SubtypingReasoning::operator==(const SubtypingReasoning& other) const +{ + return subPath == other.subPath && superPath == other.superPath && variance == other.variance; +} + +size_t SubtypingReasoningHash::operator()(const SubtypingReasoning& r) const +{ + return TypePath::PathHash()(r.subPath) ^ (TypePath::PathHash()(r.superPath) << 1) ^ (static_cast(r.variance) << 1); +} + +template +static void assertReasoningValid(TID subTy, TID superTy, const SubtypingResult& result, NotNull builtinTypes) +{ + if (!FFlag::DebugLuauSubtypingCheckPathValidity) + return; + + for (const SubtypingReasoning& reasoning : result.reasoning) + { + LUAU_ASSERT(traverse(subTy, reasoning.subPath, builtinTypes)); + LUAU_ASSERT(traverse(superTy, reasoning.superPath, builtinTypes)); + } +} + +template<> +void assertReasoningValid(TableIndexer subIdx, TableIndexer superIdx, const SubtypingResult& result, NotNull builtinTypes) +{ + // Empty method to satisfy the compiler. +} + +static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const SubtypingReasonings& b) +{ + SubtypingReasonings result{kEmptyReasoning}; + + for (const SubtypingReasoning& r : a) + { + if (r.variance == SubtypingVariance::Invariant) + result.insert(r); + else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant) + { + SubtypingReasoning inverseReasoning = SubtypingReasoning{ + r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant + }; + if (b.contains(inverseReasoning)) + result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant}); + else + result.insert(r); + } + } + + for (const SubtypingReasoning& r : b) + { + if (r.variance == SubtypingVariance::Invariant) + result.insert(r); + else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant) + { + SubtypingReasoning inverseReasoning = SubtypingReasoning{ + r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant + }; + if (a.contains(inverseReasoning)) + result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant}); + else + result.insert(r); + } + } + + return result; +} + +SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other) +{ + // If the other result is not a subtype, we want to join all of its + // reasonings to this one. If this result already has reasonings of its own, + // those need to be attributed here whenever this _also_ failed. + if (!other.isSubtype) + reasoning = isSubtype ? std::move(other.reasoning) : mergeReasonings(reasoning, other.reasoning); + + isSubtype &= other.isSubtype; + normalizationTooComplex |= other.normalizationTooComplex; + isCacheable &= other.isCacheable; + errors.insert(errors.end(), other.errors.begin(), other.errors.end()); + + return *this; +} + +SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other) +{ + // If this result is a subtype, we do not join the reasoning lists. If this + // result is not a subtype, but the other is a subtype, we want to _clear_ + // our reasoning list. If both results are not subtypes, we join the + // reasoning lists. + if (!isSubtype) + { + if (other.isSubtype) + reasoning.clear(); + else + reasoning = mergeReasonings(reasoning, other.reasoning); + } + + isSubtype |= other.isSubtype; + normalizationTooComplex |= other.normalizationTooComplex; + isCacheable &= other.isCacheable; + errors.insert(errors.end(), other.errors.begin(), other.errors.end()); + + return *this; +} + +SubtypingResult& SubtypingResult::withBothComponent(TypePath::Component component) +{ + return withSubComponent(component).withSuperComponent(component); +} + +SubtypingResult& SubtypingResult::withSubComponent(TypePath::Component component) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{Path(component), TypePath::kEmpty}); + else + { + for (auto& r : reasoning) + r.subPath = r.subPath.push_front(component); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withSuperComponent(TypePath::Component component) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{TypePath::kEmpty, Path(component)}); + else + { + for (auto& r : reasoning) + r.superPath = r.superPath.push_front(component); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withBothPath(TypePath::Path path) +{ + return withSubPath(path).withSuperPath(path); +} + +SubtypingResult& SubtypingResult::withSubPath(TypePath::Path path) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{path, TypePath::kEmpty}); + else + { + for (auto& r : reasoning) + r.subPath = path.append(r.subPath); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withSuperPath(TypePath::Path path) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{TypePath::kEmpty, path}); + else + { + for (auto& r : reasoning) + r.superPath = path.append(r.superPath); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withErrors(ErrorVec& err) +{ + for (TypeError& e : err) + errors.emplace_back(e); + return *this; +} + +SubtypingResult& SubtypingResult::withError(TypeError err) +{ + errors.push_back(std::move(err)); + return *this; +} + +SubtypingResult SubtypingResult::negate(const SubtypingResult& result) +{ + return SubtypingResult{ + !result.isSubtype, + result.normalizationTooComplex, + }; +} + +SubtypingResult SubtypingResult::all(const std::vector& results) +{ + SubtypingResult acc{true}; + for (const SubtypingResult& current : results) + acc.andAlso(current); + return acc; +} + +SubtypingResult SubtypingResult::any(const std::vector& results) +{ + SubtypingResult acc{false}; + for (const SubtypingResult& current : results) + acc.orElse(current); + return acc; +} + +struct ApplyMappedGenerics : Substitution +{ + NotNull builtinTypes; + NotNull arena; + + SubtypingEnvironment& env; + + ApplyMappedGenerics(NotNull builtinTypes, NotNull arena, SubtypingEnvironment& env) + : Substitution(TxnLog::empty(), arena) + , builtinTypes(builtinTypes) + , arena(arena) + , env(env) + { + } + + bool isDirty(TypeId ty) override + { + return env.containsMappedType(ty); + } + + bool isDirty(TypePackId tp) override + { + return env.containsMappedPack(tp); + } + + TypeId clean(TypeId ty) override + { + const auto& bounds = env.getMappedTypeBounds(ty); + + if (bounds.upperBound.empty()) + return builtinTypes->unknownType; + + if (bounds.upperBound.size() == 1) + return *begin(bounds.upperBound); + + return arena->addType(IntersectionType{std::vector(begin(bounds.upperBound), end(bounds.upperBound))}); + } + + TypePackId clean(TypePackId tp) override + { + if (auto it = env.getMappedPackBounds(tp)) + return *it; + + // Clean is only called when isDirty found a pack bound + LUAU_ASSERT(!"Unreachable"); + return nullptr; + } + + bool ignoreChildren(TypeId ty) override + { + if (get(ty)) + return true; + + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } +}; + +std::optional SubtypingEnvironment::applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty) +{ + ApplyMappedGenerics amg{builtinTypes, arena, *this}; + return amg.substitute(ty); +} + +const TypeId* SubtypingEnvironment::tryFindSubstitution(TypeId ty) const +{ + if (auto it = substitutions.find(ty)) + return it; + + if (parent) + return parent->tryFindSubstitution(ty); + + return nullptr; +} + +const SubtypingResult* SubtypingEnvironment::tryFindSubtypingResult(std::pair subAndSuper) const +{ + if (auto it = ephemeralCache.find(subAndSuper)) + return it; + + if (parent) + return parent->tryFindSubtypingResult(subAndSuper); + + return nullptr; +} + +bool SubtypingEnvironment::containsMappedType(TypeId ty) const +{ + if (mappedGenerics.contains(ty)) + return true; + + if (parent) + return parent->containsMappedType(ty); + + return false; +} + +bool SubtypingEnvironment::containsMappedPack(TypePackId tp) const +{ + if (mappedGenericPacks.contains(tp)) + return true; + + if (parent) + return parent->containsMappedPack(tp); + + return false; +} + +SubtypingEnvironment::GenericBounds& SubtypingEnvironment::getMappedTypeBounds(TypeId ty) +{ + if (auto it = mappedGenerics.find(ty)) + return *it; + + if (parent) + return parent->getMappedTypeBounds(ty); + + LUAU_ASSERT(!"Use containsMappedType before asking for bounds!"); + return mappedGenerics[ty]; +} + +TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp) +{ + if (auto it = mappedGenericPacks.find(tp)) + return it; + + if (parent) + return parent->getMappedPackBounds(tp); + + // This fallback is reachable in valid cases, unlike the final part of getMappedTypeBounds + return nullptr; +} + +Subtyping::Subtyping( + NotNull builtinTypes, + NotNull typeArena, + NotNull normalizer, + NotNull typeFunctionRuntime, + NotNull iceReporter +) + : builtinTypes(builtinTypes) + , arena(typeArena) + , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) + , iceReporter(iceReporter) +{ +} + +SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy, NotNull scope) +{ + SubtypingEnvironment env; + + SubtypingResult result = isCovariantWith(env, subTy, superTy, scope); + + for (const auto& [subTy, bounds] : env.mappedGenerics) + { + const auto& lb = bounds.lowerBound; + const auto& ub = bounds.upperBound; + TypeId lowerBound = makeAggregateType(lb, builtinTypes->neverType); + TypeId upperBound = makeAggregateType(ub, builtinTypes->unknownType); + + std::shared_ptr nt = normalizer->normalize(upperBound); + // we say that the result is true if normalization failed because complex types are likely to be inhabited. + NormalizationResult res = nt ? normalizer->isInhabited(nt.get()) : NormalizationResult::True; + + if (!nt || res == NormalizationResult::HitLimits) + result.normalizationTooComplex = true; + else if (res == NormalizationResult::False) + { + /* If the normalized upper bound we're mapping to a generic is + * uninhabited, then we must consider the subtyping relation not to + * hold. + * + * This happens eg in () -> (T, T) <: () -> (string, number) + * + * T appears in covariant position and would have to be both string + * and number at once. + * + * No actual value is both a string and a number, so the test fails. + * + * TODO: We'll need to add explanitory context here. + */ + result.isSubtype = false; + } + + + SubtypingEnvironment boundsEnv; + boundsEnv.parent = &env; + SubtypingResult boundsResult = isCovariantWith(boundsEnv, lowerBound, upperBound, scope); + boundsResult.reasoning.clear(); + + result.andAlso(boundsResult); + } + + /* TODO: We presently don't store subtype test results in the persistent + * cache if the left-side type is a generic function. + * + * The implementation would be a bit tricky and we haven't seen any material + * impact on benchmarks. + * + * What we would want to do is to remember points within the type where + * mapped generics are introduced. When all the contingent generics are + * introduced at which we're doing the test, we can mark the result as + * cacheable. + */ + + if (result.isCacheable) + resultCache[{subTy, superTy}] = result; + + return result; +} + +SubtypingResult Subtyping::isSubtype(TypePackId subTp, TypePackId superTp, NotNull scope) +{ + SubtypingEnvironment env; + return isCovariantWith(env, subTp, superTp, scope); +} + +SubtypingResult Subtyping::cache(SubtypingEnvironment& env, SubtypingResult result, TypeId subTy, TypeId superTy) +{ + const std::pair p{subTy, superTy}; + if (result.isCacheable) + resultCache[p] = result; + else + env.ephemeralCache[p] = result; + + return result; +} + +namespace +{ +struct SeenSetPopper +{ + Subtyping::SeenSet* seenTypes; + std::pair pair; + + SeenSetPopper(Subtyping::SeenSet* seenTypes, std::pair pair) + : seenTypes(seenTypes) + , pair(pair) + { + } + + ~SeenSetPopper() + { + seenTypes->erase(pair); + } +}; +} // namespace + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull scope) +{ + UnifierCounters& counters = normalizer->sharedState->counters; + RecursionCounter rc(&counters.recursionCount); + + if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount) + { + SubtypingResult result; + result.normalizationTooComplex = true; + return result; + } + + subTy = follow(subTy); + superTy = follow(superTy); + + if (const TypeId* subIt = env.tryFindSubstitution(subTy); subIt && *subIt) + subTy = *subIt; + + if (const TypeId* superIt = env.tryFindSubstitution(superTy); superIt && *superIt) + superTy = *superIt; + + const SubtypingResult* cachedResult = resultCache.find({subTy, superTy}); + if (cachedResult) + return *cachedResult; + + cachedResult = env.tryFindSubtypingResult({subTy, superTy}); + if (cachedResult) + return *cachedResult; + + // TODO: Do we care about returning a proof that this is error-suppressing? + // e.g. given `a | error <: a | error` where both operands are pointer equal, + // then should it also carry the information that it's error-suppressing? + // If it should, then `error <: error` should also do the same. + if (subTy == superTy) + return {true}; + + std::pair typePair{subTy, superTy}; + if (!seenTypes.insert(typePair)) + { + /* TODO: Caching results for recursive types is really tricky to think + * about. + * + * We'd like to cache at the outermost level where we encounter the + * recursive type, but we do not want to cache interior results that + * involve the cycle. + * + * Presently, we stop at cycles and assume that the subtype check will + * succeed because we'll eventually get there if it won't. However, if + * that cyclic type turns out not to have the asked-for subtyping + * relation, then all the intermediate cached results that were + * contingent on that assumption need to be evicted from the cache, or + * not entered into the cache, or something. + * + * For now, we do the conservative thing and refuse to cache anything + * that touches a cycle. + */ + SubtypingResult res; + res.isSubtype = true; + res.isCacheable = false; + return res; + } + + SeenSetPopper ssp{&seenTypes, typePair}; + + // Within the scope to which a generic belongs, that generic should be + // tested as though it were its upper bounds. We do not yet support bounded + // generics, so the upper bound is always unknown. + if (auto subGeneric = get(subTy); subGeneric && subsumes(subGeneric->scope, scope)) + return isCovariantWith(env, builtinTypes->neverType, superTy, scope); + if (auto superGeneric = get(superTy); superGeneric && subsumes(superGeneric->scope, scope)) + return isCovariantWith(env, subTy, builtinTypes->unknownType, scope); + + SubtypingResult result; + + if (auto subUnion = get(subTy)) + result = isCovariantWith(env, subUnion, superTy, scope); + else if (auto superUnion = get(superTy)) + { + result = isCovariantWith(env, subTy, superUnion, scope); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope); + if (semantic.isSubtype) + { + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (auto superIntersection = get(superTy)) + result = isCovariantWith(env, subTy, superIntersection, scope); + else if (auto subIntersection = get(subTy)) + { + result = isCovariantWith(env, subIntersection, superTy, scope); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope); + if (semantic.isSubtype) + { + // Clear the semantic reasoning, as any reasonings within + // potentially contain invalid paths. + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (get(superTy)) + result = {true}; + + // We have added this as an exception - the set of inhabitants of any is exactly the set of inhabitants of unknown (since error has no + // inhabitants). any = err | unknown, so under semantic subtyping, {} U unknown = unknown + else if (get(subTy) && get(superTy)) + result = {true}; + else if (get(subTy)) + { + // any = unknown | error, so we rewrite this to match. + // As per TAPL: A | B <: T iff A <: T && B <: T + result = + isCovariantWith(env, builtinTypes->unknownType, superTy, scope).andAlso(isCovariantWith(env, builtinTypes->errorType, superTy, scope)); + } + else if (get(superTy)) + { + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + + bool errorSuppressing = get(subTy); + result = {!errorSuppressing}; + } + else if (get(subTy)) + result = {true}; + else if (get(superTy)) + result = {false}; + else if (get(subTy)) + result = {true}; + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p.first->ty, p.second->ty, scope).withBothComponent(TypePath::TypeField::Negated); + else if (auto subNegation = get(subTy)) + { + result = isCovariantWith(env, subNegation, superTy, scope); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope); + if (semantic.isSubtype) + { + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (auto superNegation = get(superTy)) + { + result = isCovariantWith(env, subTy, superNegation, scope); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope); + if (semantic.isSubtype) + { + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (auto subTypeFunctionInstance = get(subTy)) + { + if (auto substSubTy = env.applyMappedGenerics(builtinTypes, arena, subTy)) + subTypeFunctionInstance = get(*substSubTy); + + result = isCovariantWith(env, subTypeFunctionInstance, superTy, scope); + } + else if (auto superTypeFunctionInstance = get(superTy)) + { + if (auto substSuperTy = env.applyMappedGenerics(builtinTypes, arena, superTy)) + superTypeFunctionInstance = get(*substSuperTy); + + result = isCovariantWith(env, subTy, superTypeFunctionInstance, scope); + } + else if (auto subGeneric = get(subTy); subGeneric && variance == Variance::Covariant) + { + bool ok = bindGeneric(env, subTy, superTy); + result.isSubtype = ok; + result.isCacheable = false; + } + else if (auto superGeneric = get(superTy); superGeneric && variance == Variance::Contravariant) + { + bool ok = bindGeneric(env, subTy, superTy); + result.isSubtype = ok; + result.isCacheable = false; + } + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + { + auto [subFunction, superPrimitive] = p; + result.isSubtype = superPrimitive->type == PrimitiveType::Function; + } + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, subTy, p.first, superTy, p.second, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p, scope); + + assertReasoningValid(subTy, superTy, result, builtinTypes); + + return cache(env, result, subTy, superTy); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp, NotNull scope) +{ + subTp = follow(subTp); + superTp = follow(superTp); + + auto [subHead, subTail] = flatten(subTp); + auto [superHead, superTail] = flatten(superTp); + + const size_t headSize = std::min(subHead.size(), superHead.size()); + + std::vector results; + results.reserve(std::max(subHead.size(), superHead.size()) + 1); + + if (subTp == superTp) + return {true}; + + // Match head types pairwise + + for (size_t i = 0; i < headSize; ++i) + results.push_back(isCovariantWith(env, subHead[i], superHead[i], scope).withBothComponent(TypePath::Index{i})); + + // Handle mismatched head sizes + + if (subHead.size() < superHead.size()) + { + if (subTail) + { + if (auto vt = get(*subTail)) + { + for (size_t i = headSize; i < superHead.size(); ++i) + results.push_back(isCovariantWith(env, vt->ty, superHead[i], scope) + .withSubPath(TypePath::PathBuilder().tail().variadic().build()) + .withSuperComponent(TypePath::Index{i})); + } + else if (auto gt = get(*subTail)) + { + if (variance == Variance::Covariant) + { + // For any non-generic type T: + // + // (X) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(superHead), begin(superHead) + headSize); + TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); + + if (TypePackId* other = env.getMappedPackBounds(*subTail)) + // TODO: TypePath can't express "slice of a pack + its tail". + results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail)); + else + env.mappedGenericPacks.try_insert(*subTail, superTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // (T) -> () (X) -> () + // + return SubtypingResult{false}.withSubComponent(TypePath::PackField::Tail); + } + } + else if (get(*subTail)) + return SubtypingResult{true}.withSubComponent(TypePath::PackField::Tail); + else + return SubtypingResult{false} + .withSubComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); + } + else + { + results.push_back({false}); + return SubtypingResult::all(results); + } + } + else if (subHead.size() > superHead.size()) + { + if (superTail) + { + if (auto vt = get(*superTail)) + { + for (size_t i = headSize; i < subHead.size(); ++i) + results.push_back(isCovariantWith(env, subHead[i], vt->ty, scope) + .withSubComponent(TypePath::Index{i}) + .withSuperPath(TypePath::PathBuilder().tail().variadic().build())); + } + else if (auto gt = get(*superTail)) + { + if (variance == Variance::Contravariant) + { + // For any non-generic type T: + // + // (X...) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(subHead), begin(subHead) + headSize); + TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); + + if (TypePackId* other = env.getMappedPackBounds(*superTail)) + // TODO: TypePath can't express "slice of a pack + its tail". + results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail)); + else + env.mappedGenericPacks.try_insert(*superTail, subTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // () -> T () -> X... + return SubtypingResult{false}.withSuperComponent(TypePath::PackField::Tail); + } + } + else if (get(*superTail)) + return SubtypingResult{true}.withSuperComponent(TypePath::PackField::Tail); + else + return SubtypingResult{false} + .withSuperComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); + } + else + return {false}; + } + + // Handle tails + + if (subTail && superTail) + { + if (auto p = get2(*subTail, *superTail)) + { + // Variadic component is added by the isCovariantWith + // implementation; no need to add it here. + results.push_back(isCovariantWith(env, p, scope).withBothComponent(TypePath::PackField::Tail)); + } + else if (auto p = get2(*subTail, *superTail)) + { + bool ok = bindGeneric(env, *subTail, *superTail); + results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); + } + else if (auto p = get2(*subTail, *superTail)) + { + if (variance == Variance::Contravariant) + { + // (A...) -> number <: (...number) -> number + bool ok = bindGeneric(env, *subTail, *superTail); + results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); + } + else + { + // (number) -> ...number (number) -> A... + results.push_back(SubtypingResult{false}.withBothComponent(TypePath::PackField::Tail)); + } + } + else if (auto p = get2(*subTail, *superTail)) + { + if (TypeId t = follow(p.second->ty); get(t) || get(t)) + { + // Extra magic rule: + // T... <: ...any + // T... <: ...unknown + // + // See https://github.com/luau-lang/luau/issues/767 + } + else if (variance == Variance::Contravariant) + { + // (...number) -> number (A...) -> number + results.push_back(SubtypingResult{false}.withBothComponent(TypePath::PackField::Tail)); + } + else + { + // () -> A... <: () -> ...number + bool ok = bindGeneric(env, *subTail, *superTail); + results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); + } + } + else if (get(*subTail) || get(*superTail)) + // error type is fine on either side + results.push_back(SubtypingResult{true}.withBothComponent(TypePath::PackField::Tail)); + else + return SubtypingResult{false} + .withBothComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}) + .withError({scope->location, UnexpectedTypePackInSubtyping{*superTail}}); + } + else if (subTail) + { + if (get(*subTail)) + { + return SubtypingResult{false}.withSubComponent(TypePath::PackField::Tail); + } + else if (get(*subTail)) + { + bool ok = bindGeneric(env, *subTail, builtinTypes->emptyTypePack); + return SubtypingResult{ok}.withSubComponent(TypePath::PackField::Tail); + } + else + return SubtypingResult{false} + .withSubComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); + } + else if (superTail) + { + if (get(*superTail)) + { + /* + * A variadic type pack ...T can be thought of as an infinite union of finite type packs. + * () | (T) | (T, T) | (T, T, T) | ... + * + * And, per TAPL: + * T <: A | B iff T <: A or T <: B + * + * All variadic type packs are therefore supertypes of the empty type pack. + */ + } + else if (get(*superTail)) + { + if (variance == Variance::Contravariant) + { + bool ok = bindGeneric(env, builtinTypes->emptyTypePack, *superTail); + results.push_back(SubtypingResult{ok}.withSuperComponent(TypePath::PackField::Tail)); + } + else + results.push_back(SubtypingResult{false}.withSuperComponent(TypePath::PackField::Tail)); + } + else + return SubtypingResult{false} + .withSuperComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*superTail}}); + } + + SubtypingResult result = SubtypingResult::all(results); + assertReasoningValid(subTp, superTp, result, builtinTypes); + + return result; +} + +template +SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy, NotNull scope) +{ + VarianceFlipper vf{&variance}; + + SubtypingResult result = isCovariantWith(env, superTy, subTy, scope); + if (result.reasoning.empty()) + result.reasoning.insert(SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Contravariant}); + else + { + // If we don't swap the paths here, we will end up producing an invalid path + // whenever we involve contravariance. We'll end up appending path + // components that should belong to the supertype to the subtype, and vice + // versa. + for (auto& reasoning : result.reasoning) + { + std::swap(reasoning.subPath, reasoning.superPath); + + // Also swap covariant/contravariant, since those are also the other way + // around. + if (reasoning.variance == SubtypingVariance::Covariant) + reasoning.variance = SubtypingVariance::Contravariant; + else if (reasoning.variance == SubtypingVariance::Contravariant) + reasoning.variance = SubtypingVariance::Covariant; + } + } + + assertReasoningValid(subTy, superTy, result, builtinTypes); + + return result; +} + +template +SubtypingResult Subtyping::isInvariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy, NotNull scope) +{ + SubtypingResult result = isCovariantWith(env, subTy, superTy, scope).andAlso(isContravariantWith(env, subTy, superTy, scope)); + + if (result.reasoning.empty()) + result.reasoning.insert(SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Invariant}); + else + { + for (auto& reasoning : result.reasoning) + reasoning.variance = SubtypingVariance::Invariant; + } + + assertReasoningValid(subTy, superTy, result, builtinTypes); + return result; +} + +template +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TryPair& pair, NotNull scope) +{ + return isCovariantWith(env, pair.first, pair.second, scope); +} + +template +SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, const TryPair& pair, NotNull scope) +{ + return isContravariantWith(env, pair.first, pair.second, scope); +} + +template +SubtypingResult Subtyping::isInvariantWith(SubtypingEnvironment& env, const TryPair& pair, NotNull scope) +{ + return isInvariantWith(env, pair.first, pair.second); +} + +/* + * This is much simpler than the Unifier implementation because we don't + * actually care about potential "cross-talk" between union parts that match the + * left side. + * + * In fact, we're very limited in what we can do: If multiple choices match, but + * all of them have non-overlapping constraints, then we're stuck with an "or" + * conjunction of constraints. Solving this in the general case is quite + * difficult. + * + * For example, we cannot dispatch anything from this constraint: + * + * {x: number, y: string} <: {x: number, y: 'a} | {x: 'b, y: string} + * + * From this constraint, we can know that either string <: 'a or number <: 'b, + * but we don't know which! + * + * However: + * + * {x: number, y: string} <: {x: number, y: 'a} | {x: number, y: string} + * + * We can dispatch this constraint because there is no 'or' conjunction. One of + * the arms requires 0 matches. + * + * {x: number, y: string, z: boolean} | {x: number, y: 'a, z: 'b} | {x: number, + * y: string, z: 'b} + * + * Here, we have two matches. One asks for string ~ 'a and boolean ~ 'b. The + * other just asks for boolean ~ 'b. We can dispatch this and only commit + * boolean ~ 'b. This constraint does not teach us anything about 'a. + */ +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const UnionType* superUnion, NotNull scope) +{ + // As per TAPL: T <: A | B iff T <: A || T <: B + + for (TypeId ty : superUnion) + { + SubtypingResult next = isCovariantWith(env, subTy, ty, scope); + if (next.isSubtype) + return SubtypingResult{true}; + } + + /* + * TODO: Is it possible here to use the context produced by the above + * isCovariantWith() calls to produce a richer, more helpful result in the + * case that the subtyping relation does not hold? + */ + return SubtypingResult{false}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const UnionType* subUnion, TypeId superTy, NotNull scope) +{ + // As per TAPL: A | B <: T iff A <: T && B <: T + std::vector subtypings; + size_t i = 0; + for (TypeId ty : subUnion) + subtypings.push_back(isCovariantWith(env, ty, superTy, scope).withSubComponent(TypePath::Index{i++})); + return SubtypingResult::all(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const IntersectionType* superIntersection, NotNull scope) +{ + // As per TAPL: T <: A & B iff T <: A && T <: B + std::vector subtypings; + size_t i = 0; + for (TypeId ty : superIntersection) + subtypings.push_back(isCovariantWith(env, subTy, ty, scope).withSuperComponent(TypePath::Index{i++})); + return SubtypingResult::all(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const IntersectionType* subIntersection, TypeId superTy, NotNull scope) +{ + // As per TAPL: A & B <: T iff A <: T || B <: T + std::vector subtypings; + size_t i = 0; + for (TypeId ty : subIntersection) + subtypings.push_back(isCovariantWith(env, ty, superTy, scope).withSubComponent(TypePath::Index{i++})); + return SubtypingResult::any(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NegationType* subNegation, TypeId superTy, NotNull scope) +{ + TypeId negatedTy = follow(subNegation->ty); + + SubtypingResult result; + + // In order to follow a consistent codepath, rather than folding the + // isCovariantWith test down to its conclusion here, we test the subtyping test + // of the result of negating the type for never, unknown, any, and error. + if (is(negatedTy)) + { + // ¬never ~ unknown + result = isCovariantWith(env, builtinTypes->unknownType, superTy, scope).withSubComponent(TypePath::TypeField::Negated); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + result = isCovariantWith(env, builtinTypes->neverType, superTy, scope).withSubComponent(TypePath::TypeField::Negated); + } + else if (is(negatedTy)) + { + // ¬any ~ any + result = isCovariantWith(env, negatedTy, superTy, scope).withSubComponent(TypePath::TypeField::Negated); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, negatedPart->ty, superTy, scope).withSubComponent(TypePath::TypeField::Negated)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy, scope)); + } + } + + result = SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, negatedPart->ty, superTy, scope).withSubComponent(TypePath::TypeField::Negated)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy, scope)); + } + } + + result = SubtypingResult::any(subtypings); + } + else if (is(negatedTy)) + { + iceReporter->ice("attempting to negate a non-testable type"); + } + // negating a different subtype will get you a very wide type that's not a + // subtype of other stuff. + else + { + result = SubtypingResult{false}.withSubComponent(TypePath::TypeField::Negated); + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation, NotNull scope) +{ + TypeId negatedTy = follow(superNegation->ty); + + SubtypingResult result; + + if (is(negatedTy)) + { + // ¬never ~ unknown + result = isCovariantWith(env, subTy, builtinTypes->unknownType, scope); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + result = isCovariantWith(env, subTy, builtinTypes->neverType, scope); + } + else if (is(negatedTy)) + { + // ¬any ~ any + result = isSubtype(subTy, negatedTy, scope); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, subTy, negatedPart->ty, scope)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, subTy, &negatedTmp, scope)); + } + } + + return SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, subTy, negatedPart->ty, scope)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, subTy, &negatedTmp, scope)); + } + } + + return SubtypingResult::any(subtypings); + } + else if (auto p = get2(subTy, negatedTy)) + { + // number <: ¬boolean + // number type != p.second->type}; + } + else if (auto p = get2(subTy, negatedTy)) + { + // "foo" (p.first) && p.second->type == PrimitiveType::String) + result = {false}; + // false (p.first) && p.second->type == PrimitiveType::Boolean) + result = {false}; + // other cases are true + else + result = {true}; + } + else if (auto p = get2(subTy, negatedTy)) + { + if (p.first->type == PrimitiveType::String && get(p.second)) + result = {false}; + else if (p.first->type == PrimitiveType::Boolean && get(p.second)) + result = {false}; + else + result = {true}; + } + // the top class type is not actually a primitive type, so the negation of + // any one of them includes the top class type. + else if (auto p = get2(subTy, negatedTy)) + result = {true}; + else if (auto p = get(negatedTy); p && is(subTy)) + result = {p->type != PrimitiveType::Table}; + else if (auto p = get2(subTy, negatedTy)) + result = {p.second->type != PrimitiveType::Function}; + else if (auto p = get2(subTy, negatedTy)) + result = {*p.first != *p.second}; + else if (auto p = get2(subTy, negatedTy)) + result = SubtypingResult::negate(isCovariantWith(env, p.first, p.second, scope)); + else if (get2(subTy, negatedTy)) + result = {true}; + else if (is(negatedTy)) + iceReporter->ice("attempting to negate a non-testable type"); + else + result = {false}; + + return result.withSuperComponent(TypePath::TypeField::Negated); +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const PrimitiveType* subPrim, + const PrimitiveType* superPrim, + NotNull scope +) +{ + return {subPrim->type == superPrim->type}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const SingletonType* subSingleton, + const PrimitiveType* superPrim, + NotNull scope +) +{ + if (get(subSingleton) && superPrim->type == PrimitiveType::String) + return {true}; + else if (get(subSingleton) && superPrim->type == PrimitiveType::Boolean) + return {true}; + else + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const SingletonType* subSingleton, + const SingletonType* superSingleton, + NotNull scope +) +{ + return {*subSingleton == *superSingleton}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TableType* subTable, const TableType* superTable, NotNull scope) +{ + SubtypingResult result{true}; + + if (subTable->props.empty() && !subTable->indexer && superTable->indexer) + return {false}; + + for (const auto& [name, superProp] : superTable->props) + { + std::vector results; + if (auto subIter = subTable->props.find(name); subIter != subTable->props.end()) + results.push_back(isCovariantWith(env, subIter->second, superProp, name, scope)); + else if (subTable->indexer) + { + if (isCovariantWith(env, builtinTypes->stringType, subTable->indexer->indexType, scope).isSubtype) + { + if (superProp.isShared()) + results.push_back(isInvariantWith(env, subTable->indexer->indexResultType, superProp.type(), scope) + .withSubComponent(TypePath::TypeField::IndexResult) + .withSuperComponent(TypePath::Property::read(name))); + else + { + if (superProp.readTy) + results.push_back(isCovariantWith(env, subTable->indexer->indexResultType, *superProp.readTy, scope) + .withSubComponent(TypePath::TypeField::IndexResult) + .withSuperComponent(TypePath::Property::read(name))); + if (superProp.writeTy) + results.push_back(isContravariantWith(env, subTable->indexer->indexResultType, *superProp.writeTy, scope) + .withSubComponent(TypePath::TypeField::IndexResult) + .withSuperComponent(TypePath::Property::write(name))); + } + } + } + + if (results.empty()) + return SubtypingResult{false}; + + result.andAlso(SubtypingResult::all(results)); + } + + if (superTable->indexer) + { + if (subTable->indexer) + result.andAlso(isInvariantWith(env, *subTable->indexer, *superTable->indexer, scope)); + else + return {false}; + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt, NotNull scope) +{ + if (DFInt::LuauTypeSolverRelease >= 646) + { + return isCovariantWith(env, subMt->table, superMt->table, scope) + .withBothComponent(TypePath::TypeField::Table) + .andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable)); + } + else + { + return isCovariantWith(env, subMt->table, superMt->table, scope) + .andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable)); + } +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable, NotNull scope) +{ + if (auto subTable = get(follow(subMt->table))) + { + // Metatables cannot erase properties from the table they're attached to, so + // the subtyping rule for this is just if the table component is a subtype + // of the supertype table. + // + // There's a flaw here in that if the __index metamethod contributes a new + // field that would satisfy the subtyping relationship, we'll erronously say + // that the metatable isn't a subtype of the table, even though they have + // compatible properties/shapes. We'll revisit this later when we have a + // better understanding of how important this is. + return isCovariantWith(env, subTable, superTable, scope); + } + else + { + // TODO: This may be a case we actually hit? + return {false}; + } +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const ClassType* subClass, const ClassType* superClass, NotNull scope) +{ + return {isSubclass(subClass, superClass)}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + TypeId subTy, + const ClassType* subClass, + TypeId superTy, + const TableType* superTable, + NotNull scope +) +{ + SubtypingResult result{true}; + + env.substitutions[superTy] = subTy; + + for (const auto& [name, prop] : superTable->props) + { + if (auto classProp = lookupClassProp(subClass, name)) + { + result.andAlso(isCovariantWith(env, *classProp, prop, name, scope)); + } + else + { + result = {false}; + break; + } + } + + env.substitutions[superTy] = nullptr; + + return result; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const FunctionType* subFunction, + const FunctionType* superFunction, + NotNull scope +) +{ + SubtypingResult result; + { + result.orElse( + isContravariantWith(env, subFunction->argTypes, superFunction->argTypes, scope).withBothComponent(TypePath::PackField::Arguments) + ); + } + + result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes, scope).withBothComponent(TypePath::PackField::Returns)); + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TableType* subTable, const PrimitiveType* superPrim, NotNull scope) +{ + SubtypingResult result{false}; + if (superPrim->type == PrimitiveType::Table) + result.isSubtype = true; + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const TableType* superTable, NotNull scope) +{ + SubtypingResult result{false}; + if (subPrim->type == PrimitiveType::String) + { + if (auto metatable = getMetatable(builtinTypes->stringType, builtinTypes)) + { + if (auto mttv = get(follow(metatable))) + { + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) + { + if (auto stringTable = get(it->second.type())) + result.orElse( + isCovariantWith(env, stringTable, superTable, scope).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build()) + ); + } + } + } + } + else if (subPrim->type == PrimitiveType::Table) + { + const bool isSubtype = superTable->props.empty() && !superTable->indexer.has_value(); + return {isSubtype}; + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const SingletonType* subSingleton, + const TableType* superTable, + NotNull scope +) +{ + SubtypingResult result{false}; + if (auto stringleton = get(subSingleton)) + { + if (auto metatable = getMetatable(builtinTypes->stringType, builtinTypes)) + { + if (auto mttv = get(follow(metatable))) + { + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) + { + if (auto stringTable = get(it->second.type())) + result.orElse( + isCovariantWith(env, stringTable, superTable, scope).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build()) + ); + } + } + } + } + return result; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const TableIndexer& subIndexer, + const TableIndexer& superIndexer, + NotNull scope +) +{ + return isInvariantWith(env, subIndexer.indexType, superIndexer.indexType, scope) + .withBothComponent(TypePath::TypeField::IndexLookup) + .andAlso( + isInvariantWith(env, subIndexer.indexResultType, superIndexer.indexResultType, scope).withBothComponent(TypePath::TypeField::IndexResult) + ); +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const Property& subProp, + const Property& superProp, + const std::string& name, + NotNull scope +) +{ + SubtypingResult res{true}; + + if (superProp.isShared() && subProp.isShared()) + res.andAlso(isInvariantWith(env, subProp.type(), superProp.type(), scope).withBothComponent(TypePath::Property::read(name))); + else + { + if (superProp.readTy.has_value() && subProp.readTy.has_value()) + res.andAlso(isCovariantWith(env, *subProp.readTy, *superProp.readTy, scope).withBothComponent(TypePath::Property::read(name))); + if (superProp.writeTy.has_value() && subProp.writeTy.has_value()) + res.andAlso(isContravariantWith(env, *subProp.writeTy, *superProp.writeTy, scope).withBothComponent(TypePath::Property::write(name))); + + if (superProp.isReadWrite()) + { + if (subProp.isReadOnly()) + res.andAlso(SubtypingResult{false}.withBothComponent(TypePath::Property::read(name))); + else if (subProp.isWriteOnly()) + res.andAlso(SubtypingResult{false}.withBothComponent(TypePath::Property::write(name))); + } + } + + return res; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const std::shared_ptr& subNorm, + const std::shared_ptr& superNorm, + NotNull scope +) +{ + if (!subNorm || !superNorm) + return {false, true}; + + SubtypingResult result = isCovariantWith(env, subNorm->tops, superNorm->tops, scope); + result.andAlso(isCovariantWith(env, subNorm->booleans, superNorm->booleans, scope)); + result.andAlso( + isCovariantWith(env, subNorm->classes, superNorm->classes, scope).orElse(isCovariantWith(env, subNorm->classes, superNorm->tables, scope)) + ); + result.andAlso(isCovariantWith(env, subNorm->errors, superNorm->errors, scope)); + result.andAlso(isCovariantWith(env, subNorm->nils, superNorm->nils, scope)); + result.andAlso(isCovariantWith(env, subNorm->numbers, superNorm->numbers, scope)); + result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->strings, scope)); + result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->tables, scope)); + result.andAlso(isCovariantWith(env, subNorm->threads, superNorm->threads, scope)); + result.andAlso(isCovariantWith(env, subNorm->buffers, superNorm->buffers, scope)); + result.andAlso(isCovariantWith(env, subNorm->tables, superNorm->tables, scope)); + result.andAlso(isCovariantWith(env, subNorm->functions, superNorm->functions, scope)); + // isCovariantWith(subNorm->tyvars, superNorm->tyvars); + return result; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const NormalizedClassType& subClass, + const NormalizedClassType& superClass, + NotNull scope +) +{ + for (const auto& [subClassTy, _] : subClass.classes) + { + SubtypingResult result; + + for (const auto& [superClassTy, superNegations] : superClass.classes) + { + result.orElse(isCovariantWith(env, subClassTy, superClassTy, scope)); + if (!result.isSubtype) + continue; + + for (TypeId negation : superNegations) + { + result.andAlso(SubtypingResult::negate(isCovariantWith(env, subClassTy, negation, scope))); + if (result.isSubtype) + break; + } + } + + if (!result.isSubtype) + return result; + } + + return {true}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const NormalizedClassType& subClass, + const TypeIds& superTables, + NotNull scope +) +{ + for (const auto& [subClassTy, _] : subClass.classes) + { + SubtypingResult result; + + for (TypeId superTableTy : superTables) + result.orElse(isCovariantWith(env, subClassTy, superTableTy, scope)); + + if (!result.isSubtype) + return result; + } + + return {true}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const NormalizedStringType& subString, + const NormalizedStringType& superString, + NotNull scope +) +{ + bool isSubtype = Luau::isSubtype(subString, superString); + return {isSubtype}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const NormalizedStringType& subString, + const TypeIds& superTables, + NotNull scope +) +{ + if (subString.isNever()) + return {true}; + + if (subString.isCofinite) + { + SubtypingResult result; + for (const auto& superTable : superTables) + { + result.orElse(isCovariantWith(env, builtinTypes->stringType, superTable, scope)); + if (result.isSubtype) + return result; + } + return result; + } + + // Finite case + // S = s1 | s2 | s3 ... sn <: t1 | t2 | ... | tn + // iff for some ti, S <: ti + // iff for all sj, sj <: ti + for (const auto& superTable : superTables) + { + SubtypingResult result{true}; + for (const auto& [_, subString] : subString.singletons) + { + result.andAlso(isCovariantWith(env, subString, superTable, scope)); + if (!result.isSubtype) + break; + } + + if (!result.isSubtype) + continue; + else + return result; + } + + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const NormalizedFunctionType& subFunction, + const NormalizedFunctionType& superFunction, + NotNull scope +) +{ + if (subFunction.isNever()) + return {true}; + else if (superFunction.isTop) + return {true}; + else + return isCovariantWith(env, subFunction.parts, superFunction.parts, scope); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes, NotNull scope) +{ + std::vector results; + + for (TypeId subTy : subTypes) + { + results.emplace_back(); + for (TypeId superTy : superTypes) + results.back().orElse(isCovariantWith(env, subTy, superTy, scope)); + } + + return SubtypingResult::all(results); +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const VariadicTypePack* subVariadic, + const VariadicTypePack* superVariadic, + NotNull scope +) +{ + return isCovariantWith(env, subVariadic->ty, superVariadic->ty, scope).withBothComponent(TypePath::TypeField::Variadic); +} + +bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId superTy) +{ + if (variance == Variance::Covariant) + { + if (!get(subTy)) + return false; + + if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy)) + iceReporter->ice("attempting to modify bounds of a potentially visited generic"); + + env.mappedGenerics[subTy].upperBound.insert(superTy); + } + else + { + if (!get(superTy)) + return false; + + if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy)) + iceReporter->ice("attempting to modify bounds of a potentially visited generic"); + + env.mappedGenerics[superTy].lowerBound.insert(subTy); + } + + return true; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const TypeFunctionInstanceType* subFunctionInstance, + const TypeId superTy, + NotNull scope +) +{ + // Reduce the type function instance + auto [ty, errors] = handleTypeFunctionReductionResult(subFunctionInstance, scope); + + // If we return optional, that means the type function was irreducible - we can reduce that to never + return isCovariantWith(env, ty, superTy, scope).withErrors(errors).withSubComponent(TypePath::Reduction{ty}); +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const TypeId subTy, + const TypeFunctionInstanceType* superFunctionInstance, + NotNull scope +) +{ + // Reduce the type function instance + auto [ty, errors] = handleTypeFunctionReductionResult(superFunctionInstance, scope); + return isCovariantWith(env, subTy, ty, scope).withErrors(errors).withSuperComponent(TypePath::Reduction{ty}); +} + +/* + * If, when performing a subtyping test, we encounter a generic on the left + * side, it is permissible to tentatively bind that generic to the right side + * type. + */ +bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp) +{ + if (variance == Variance::Contravariant) + std::swap(superTp, subTp); + + if (!get(subTp)) + return false; + + if (TypePackId* m = env.getMappedPackBounds(subTp)) + return *m == superTp; + + env.mappedGenericPacks[subTp] = superTp; + + return true; +} + +template +TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) +{ + if (container.empty()) + return orElse; + else if (container.size() == 1) + return *begin(container); + else + return arena->addType(T{std::vector(begin(container), end(container))}); +} + +std::pair Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull scope) +{ + TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; + TypeId function = arena->addType(*functionInstance); + FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); + ErrorVec errors; + if (result.blockedTypes.size() != 0 || result.blockedPacks.size() != 0) + { + errors.push_back(TypeError{{}, UninhabitedTypeFunction{function}}); + return {builtinTypes->neverType, errors}; + } + if (result.reducedTypes.contains(function)) + return {function, errors}; + return {builtinTypes->neverType, errors}; +} + +} // namespace Luau diff --git a/Analysis/src/Symbol.cpp b/Analysis/src/Symbol.cpp index 5922bb50e..5e5b9d8cc 100644 --- a/Analysis/src/Symbol.cpp +++ b/Analysis/src/Symbol.cpp @@ -3,9 +3,23 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauSolverV2) + namespace Luau { +bool Symbol::operator==(const Symbol& rhs) const +{ + if (local) + return local == rhs.local; + else if (global.value) + return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. + else if (FFlag::LuauSolverV2) + return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. + else + return false; +} + std::string toString(const Symbol& name) { if (name.local) diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp new file mode 100644 index 000000000..bcd5c1d8c --- /dev/null +++ b/Analysis/src/TableLiteralInference.cpp @@ -0,0 +1,460 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Ast.h" +#include "Luau/Normalize.h" +#include "Luau/Simplify.h" +#include "Luau/Type.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" + +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + +namespace Luau +{ + +// A fast approximation of subTy <: superTy +static bool fastIsSubtype(TypeId subTy, TypeId superTy) +{ + Relation r = relate(superTy, subTy); + return r == Relation::Coincident || r == Relation::Superset; +} + +static bool isRecord(const AstExprTable::Item& item) +{ + if (item.kind == AstExprTable::Item::Record) + return true; + else if (item.kind == AstExprTable::Item::General && item.key->is()) + return true; + else + return false; +} + +static std::optional extractMatchingTableType(std::vector& tables, TypeId exprType, NotNull builtinTypes) +{ + if (tables.empty()) + return std::nullopt; + + const TableType* exprTable = get(follow(exprType)); + if (!exprTable) + return std::nullopt; + + size_t tableCount = 0; + std::optional firstTable; + + for (TypeId ty : tables) + { + ty = follow(ty); + if (auto tt = get(ty)) + { + // If the expected table has a key whose type is a string or boolean + // singleton and the corresponding exprType property does not match, + // then skip this table. + + if (!firstTable) + firstTable = ty; + ++tableCount; + + for (const auto& [name, expectedProp] : tt->props) + { + if (!expectedProp.readTy) + continue; + + const TypeId expectedType = follow(*expectedProp.readTy); + + auto st = get(expectedType); + if (!st) + continue; + + auto it = exprTable->props.find(name); + if (it == exprTable->props.end()) + continue; + + const auto& [_name, exprProp] = *it; + + if (!exprProp.readTy) + continue; + + const TypeId propType = follow(*exprProp.readTy); + + const FreeType* ft = get(propType); + + if (ft && get(ft->lowerBound)) + { + if (fastIsSubtype(builtinTypes->booleanType, ft->upperBound) && fastIsSubtype(expectedType, builtinTypes->booleanType)) + { + return ty; + } + + if (fastIsSubtype(builtinTypes->stringType, ft->upperBound) && fastIsSubtype(expectedType, ft->lowerBound)) + { + return ty; + } + } + } + } + } + + if (tableCount == 1) + { + LUAU_ASSERT(firstTable); + return firstTable; + } + + return std::nullopt; +} + +TypeId matchLiteralType( + NotNull> astTypes, + NotNull> astExpectedTypes, + NotNull builtinTypes, + NotNull arena, + NotNull unifier, + TypeId expectedType, + TypeId exprType, + const AstExpr* expr, + std::vector& toBlock +) +{ + /* + * Table types that arise from literal table expressions have some + * properties that make this algorithm much simpler. + * + * Most importantly, the parts of the type that arise directly from the + * table expression are guaranteed to be acyclic. This means we can do all + * kinds of naive depth first traversal shenanigans and not worry about + * nasty details like aliasing or reentrancy. + * + * We are therefore completely free to mutate these portions of the + * TableType however we choose! We'll take advantage of this property to do + * things like replace explicit named properties with indexers as required + * by the expected type. + */ + if (!isLiteral(expr)) + return exprType; + + expectedType = follow(expectedType); + exprType = follow(exprType); + + if (get(expectedType) || get(expectedType)) + { + // "Narrowing" to unknown or any is not going to do anything useful. + return exprType; + } + + if (expr->is()) + { + auto ft = get(exprType); + if (ft && get(ft->lowerBound) && fastIsSubtype(builtinTypes->stringType, ft->upperBound) && + fastIsSubtype(ft->lowerBound, builtinTypes->stringType)) + { + // if the upper bound is a subtype of the expected type, we can push the expected type in + Relation upperBoundRelation = relate(ft->upperBound, expectedType); + if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + + // likewise, if the lower bound is a subtype, we can force the expected type in + // if this is the case and the previous relation failed, it means that the primitive type + // constraint was going to have to select the lower bound for this type anyway. + Relation lowerBoundRelation = relate(ft->lowerBound, expectedType); + if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + } + } + else if (expr->is()) + { + auto ft = get(exprType); + if (ft && get(ft->lowerBound) && fastIsSubtype(builtinTypes->booleanType, ft->upperBound) && + fastIsSubtype(ft->lowerBound, builtinTypes->booleanType)) + { + // if the upper bound is a subtype of the expected type, we can push the expected type in + Relation upperBoundRelation = relate(ft->upperBound, expectedType); + if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + + // likewise, if the lower bound is a subtype, we can force the expected type in + // if this is the case and the previous relation failed, it means that the primitive type + // constraint was going to have to select the lower bound for this type anyway. + Relation lowerBoundRelation = relate(ft->lowerBound, expectedType); + if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + } + } + + if (expr->is() || expr->is() || expr->is() || expr->is()) + { + if (auto ft = get(exprType); ft && fastIsSubtype(ft->upperBound, expectedType)) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + + Relation r = relate(exprType, expectedType); + if (r == Relation::Coincident || r == Relation::Subset) + return expectedType; + + return exprType; + } + + // TODO: lambdas + + if (auto exprTable = expr->as()) + { + TableType* const tableTy = getMutable(exprType); + LUAU_ASSERT(tableTy); + + const TableType* expectedTableTy = get(expectedType); + + if (!expectedTableTy) + { + if (auto utv = get(expectedType)) + { + std::vector parts{begin(utv), end(utv)}; + + std::optional tt = extractMatchingTableType(parts, exprType, builtinTypes); + + if (tt) + { + TypeId res = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *tt, exprType, expr, toBlock); + + parts.push_back(res); + return arena->addType(UnionType{std::move(parts)}); + } + } + + return exprType; + } + + for (const AstExprTable::Item& item : exprTable->items) + { + if (isRecord(item)) + { + const AstArray& s = item.key->as()->value; + std::string keyStr{s.data, s.data + s.size}; + auto it = tableTy->props.find(keyStr); + LUAU_ASSERT(it != tableTy->props.end()); + + Property& prop = it->second; + + // Table literals always initially result in shared read-write types + LUAU_ASSERT(prop.isShared()); + TypeId propTy = *prop.readTy; + + auto it2 = expectedTableTy->props.find(keyStr); + + if (it2 == expectedTableTy->props.end()) + { + // expectedType may instead have an indexer. This is + // kind of interesting because it means we clip the prop + // from the exprType and fold it into the indexer. + if (expectedTableTy->indexer && isString(expectedTableTy->indexer->indexType)) + { + (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; + (*astExpectedTypes)[item.value] = expectedTableTy->indexer->indexResultType; + + TypeId matchedType = matchLiteralType( + astTypes, + astExpectedTypes, + builtinTypes, + arena, + unifier, + expectedTableTy->indexer->indexResultType, + propTy, + item.value, + toBlock + ); + + if (tableTy->indexer) + unifier->unify(matchedType, tableTy->indexer->indexResultType); + else + tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType}; + + tableTy->props.erase(keyStr); + } + + // If it's just an extra property and the expected type + // has no indexer, there's no work to do here. + + continue; + } + + LUAU_ASSERT(it2 != expectedTableTy->props.end()); + + const Property& expectedProp = it2->second; + + std::optional expectedReadTy = expectedProp.readTy; + std::optional expectedWriteTy = expectedProp.writeTy; + + TypeId matchedType = nullptr; + + // Important optimization: If we traverse into the read and + // write types separately even when they are shared, we go + // quadratic in a hurry. + if (expectedProp.isShared()) + { + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); + prop.readTy = matchedType; + prop.writeTy = matchedType; + } + else if (expectedReadTy) + { + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); + prop.readTy = matchedType; + prop.writeTy.reset(); + } + else if (expectedWriteTy) + { + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedWriteTy, propTy, item.value, toBlock); + prop.readTy.reset(); + prop.writeTy = matchedType; + } + else + { + // Also important: It is presently the case that all + // table properties are either read-only, or have the + // same read and write types. + LUAU_ASSERT(!"Should be unreachable"); + } + + LUAU_ASSERT(prop.readTy || prop.writeTy); + + LUAU_ASSERT(matchedType); + + (*astExpectedTypes)[item.value] = matchedType; + } + else if (item.kind == AstExprTable::Item::List) + { + LUAU_ASSERT(tableTy->indexer); + + if (expectedTableTy->indexer) + { + const TypeId* propTy = astTypes->find(item.value); + LUAU_ASSERT(propTy); + + unifier->unify(expectedTableTy->indexer->indexType, builtinTypes->numberType); + TypeId matchedType = matchLiteralType( + astTypes, + astExpectedTypes, + builtinTypes, + arena, + unifier, + expectedTableTy->indexer->indexResultType, + *propTy, + item.value, + toBlock + ); + + // if the index result type is the prop type, we can replace it with the matched type here. + if (tableTy->indexer->indexResultType == *propTy) + tableTy->indexer->indexResultType = matchedType; + } + } + else if (item.kind == AstExprTable::Item::General) + { + + // We have { ..., [blocked] : somePropExpr, ...} + // If blocked resolves to a string, we will then take care of this above + // If it resolves to some other kind of expression, we don't have a way of folding this information into indexer + // because there is no named prop to remove + // We should just block here + const TypeId* keyTy = astTypes->find(item.key); + LUAU_ASSERT(keyTy); + TypeId tKey = follow(*keyTy); + if (DFInt::LuauTypeSolverRelease >= 648) + { + LUAU_ASSERT(!is(tKey)); + } + else if (get(tKey)) + toBlock.push_back(tKey); + const TypeId* propTy = astTypes->find(item.value); + LUAU_ASSERT(propTy); + TypeId tProp = follow(*propTy); + if (DFInt::LuauTypeSolverRelease >= 648) + { + LUAU_ASSERT(!is(tKey)); + } + else if (get(tProp)) + toBlock.push_back(tProp); + // Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings) + if (!item.key->as() && expectedTableTy->indexer) + (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; + } + else + LUAU_ASSERT(!"Unexpected"); + } + + // Keys that the expectedType says we should have, but that aren't + // specified by the AST fragment. + // + // If any such keys are options, then we'll add them to the expression + // type. + // + // We use std::optional here because the empty string is a + // perfectly reasonable value to insert into the set. We'll use + // std::nullopt as our sentinel value. + Set> missingKeys{{}}; + for (const auto& [name, _] : expectedTableTy->props) + missingKeys.insert(name); + + for (const AstExprTable::Item& item : exprTable->items) + { + if (item.key) + { + if (const auto str = item.key->as()) + { + missingKeys.erase(std::string(str->value.data, str->value.size)); + } + } + } + + for (const auto& key : missingKeys) + { + LUAU_ASSERT(key.has_value()); + + auto it = expectedTableTy->props.find(*key); + LUAU_ASSERT(it != expectedTableTy->props.end()); + + const Property& expectedProp = it->second; + + Property exprProp; + + if (expectedProp.readTy && isOptional(*expectedProp.readTy)) + exprProp.readTy = *expectedProp.readTy; + if (expectedProp.writeTy && isOptional(*expectedProp.writeTy)) + exprProp.writeTy = *expectedProp.writeTy; + + // If the property isn't actually optional, do nothing. + if (exprProp.readTy || exprProp.writeTy) + tableTy->props[*key] = std::move(exprProp); + } + + // If the expected table has an indexer, then the provided table can + // have one too. + // TODO: If the expected table also has an indexer, we might want to + // push the expected indexer's types into it. + if (expectedTableTy->indexer && !tableTy->indexer) + { + tableTy->indexer = expectedTableTy->indexer; + } + } + + return exprType; +} + +} // namespace Luau diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 117d39d20..e3f4fd3b0 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -4,11 +4,14 @@ #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFunction.h" #include "Luau/StringUtils.h" #include #include +LUAU_FASTFLAG(LuauSolverV2); + namespace Luau { @@ -52,7 +55,7 @@ bool StateDot::canDuplicatePrimitive(TypeId ty) if (get(ty)) return false; - return get(ty) || get(ty); + return get(ty) || get(ty) || get(ty) || get(ty); } void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) @@ -76,6 +79,10 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"unknown\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"never\"];\n", index); } else { @@ -139,153 +146,227 @@ void StateDot::visitChildren(TypeId ty, int index) startNode(index); startNodeLabel(); - if (const BoundType* btv = get(ty)) - { - formatAppend(result, "BoundType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(btv->boundTo, index); - } - else if (const FunctionType* ftv = get(ty)) + auto go = [&](auto&& t) { - formatAppend(result, "FunctionType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retTypes, index, "ret"); - } - else if (const TableType* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableType %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableType %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableType %d", index); - finishNodeLabel(ty); - finishNode(); + using T = std::decay_t; - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); + if constexpr (std::is_same_v) + { + formatAppend(result, "BoundType %d", index); + finishNodeLabel(ty); + finishNode(); - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type, index, name.c_str()); - if (ttv->indexer) + visitChild(t.boundTo, index); + } + else if constexpr (std::is_same_v) { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); + formatAppend(result, "BlockedType %d", index); + finishNodeLabel(ty); + finishNode(); } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); + else if constexpr (std::is_same_v) + { + formatAppend(result, "FunctionType %d", index); + finishNodeLabel(ty); + finishNode(); - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp, index, "typePackParam"); - } - else if (const MetatableType* mtv = get(ty)) - { - formatAppend(result, "MetatableType %d", index); - finishNodeLabel(ty); - finishNode(); + visitChild(t.argTypes, index, "arg"); + visitChild(t.retTypes, index, "ret"); + } + else if constexpr (std::is_same_v) + { + if (t.name) + formatAppend(result, "TableType %s", t.name->c_str()); + else if (t.syntheticName) + formatAppend(result, "TableType %s", t.syntheticName->c_str()); + else + formatAppend(result, "TableType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (t.boundTo) + return visitChild(*t.boundTo, index, "boundTo"); + + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : t.instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : t.instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "MetatableType %d", index); + finishNodeLabel(ty); + finishNode(); - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionType* utv = get(ty)) - { - formatAppend(result, "UnionType %d", index); - finishNodeLabel(ty); - finishNode(); + visitChild(t.table, index, "table"); + visitChild(t.metatable, index, "metatable"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnionType %d", index); + finishNodeLabel(ty); + finishNode(); - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionType* itv = get(ty)) - { - formatAppend(result, "IntersectionType %d", index); - finishNodeLabel(ty); - finishNode(); + for (TypeId opt : t.options) + visitChild(opt, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "IntersectionType %d", index); + finishNodeLabel(ty); + finishNode(); - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericType* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericType %s", gtv->name.c_str()); - else - formatAppend(result, "GenericType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeType* ftv = get(ty)) - { - formatAppend(result, "FreeType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "AnyType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "ErrorType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassType* ctv = get(ty)) - { - formatAppend(result, "ClassType %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); + for (TypeId part : t.parts) + visitChild(part, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "LazyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PendingExpansionType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + if (t.explicitName) + formatAppend(result, "GenericType %s", t.name.c_str()); + else + formatAppend(result, "GenericType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "FreeType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (FFlag::LuauSolverV2) + { + if (!get(t.lowerBound)) + visitChild(t.lowerBound, index, "[lowerBound]"); + + if (!get(t.upperBound)) + visitChild(t.upperBound, index, "[upperBound]"); + } + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "AnyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NoRefineType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnknownType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NeverType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ErrorType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ClassType %s", t.name.c_str()); + finishNodeLabel(ty); + finishNode(); - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type, index, name.c_str()); + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); + if (t.parent) + visitChild(*t.parent, index, "[parent]"); - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - } - else if (const SingletonType* stv = get(ty)) - { - std::string res; + if (t.metatable) + visitChild(*t.metatable, index, "[metatable]"); - if (const StringSingleton* ss = get(stv)) + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + } + else if constexpr (std::is_same_v) { - // Don't put in quotes anywhere. If it's outside of the call to escape, - // then it's invalid syntax. If it's inside, then escaping is super noisy. - res = "string: " + escape(ss->value); + std::string res; + + if (const StringSingleton* ss = get(&t)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(&t)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonType %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); } - else if (const BooleanSingleton* bs = get(stv)) + else if constexpr (std::is_same_v) { - res = "boolean: "; - res += bs->value ? "true" : "false"; + formatAppend(result, "NegationType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.ty, index, "[negated]"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "TypeFunctionInstanceType %s %d", t.function->name.c_str(), index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId tyParam : t.typeArguments) + visitChild(tyParam, index); + + for (TypePackId tpParam : t.packArguments) + visitChild(tpParam, index); } else - LUAU_ASSERT(!"unknown singleton type"); + static_assert(always_false_v, "unknown type kind"); + }; - formatAppend(result, "SingletonType %s", res.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } + visit(go, ty->ty); } void StateDot::visitChildren(TypePackId tp, int index) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index d0c539845..5b191d30d 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1,27 +1,43 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/DenseHash.h" #include "Luau/Location.h" #include "Luau/Scope.h" +#include "Luau/Set.h" #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFunction.h" #include "Luau/VisitType.h" +#include "Luau/TypeOrPack.h" #include #include +#include -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) +LUAU_FASTFLAG(LuauSolverV2) /* - * Prefix generic typenames with gen- - * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 - * Fair warning: Setting this will break a lot of Luau unit tests. + * Enables increasing levels of verbosity for Luau type names when stringifying. + * After level 2, test cases will break unpredictably because a pointer to their + * scope will be included in the stringification of generic and free types. + * + * Supported values: + * + * 0: Disabled, no changes. + * + * 1: Prefix free/generic types with free- and gen-, respectively. Also reveal + * hidden variadic tails. Display block count for local types. + * + * 2: Suffix free/generic types with their scope depth. + * + * 3: Suffix free/generic types with their scope pointer, if present. */ -LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTINTVARIABLE(DebugLuauVerboseTypeNames, 0) LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort, false) namespace Luau @@ -37,8 +53,8 @@ struct FindCyclicTypes final : TypeVisitor FindCyclicTypes& operator=(const FindCyclicTypes&) = delete; bool exhaustive = false; - std::unordered_set visited; - std::unordered_set visitedPacks; + Luau::Set visited{{}}; + Luau::Set visitedPacks{{}}; std::set cycles; std::set cycleTPs; @@ -54,17 +70,39 @@ struct FindCyclicTypes final : TypeVisitor bool visit(TypeId ty) override { - return visited.insert(ty).second; + return visited.insert(ty); } bool visit(TypePackId tp) override { - return visitedPacks.insert(tp).second; + return visitedPacks.insert(tp); + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (!visited.insert(ty)) + return false; + + if (FFlag::LuauSolverV2) + { + // TODO: Replace these if statements with assert()s when we + // delete FFlag::LuauSolverV2. + // + // When the old solver is used, these pointers are always + // unused. When the new solver is used, they are never null. + + if (ft.lowerBound) + traverse(ft.lowerBound); + if (ft.upperBound) + traverse(ft.upperBound); + } + + return false; } bool visit(TypeId ty, const TableType& ttv) override { - if (!visited.insert(ty).second) + if (!visited.insert(ty)) return false; if (ttv.name || ttv.syntheticName) @@ -85,6 +123,11 @@ struct FindCyclicTypes final : TypeVisitor { return false; } + + bool visit(TypeId, const PendingExpansionType&) override + { + return false; + } }; template @@ -122,10 +165,12 @@ struct StringifierState ToStringOptions& opts; ToStringResult& result; - std::unordered_map cycleNames; - std::unordered_map cycleTpNames; - std::unordered_set seen; - std::unordered_set usedNames; + DenseHashMap cycleNames{{}}; + DenseHashMap cycleTpNames{{}}; + Set seen{{}}; + // `$$$` was chosen as the tombstone for `usedNames` since it is not a valid name syntactically and is relatively short for string comparison + // reasons. + DenseHashSet usedNames{"$$$"}; size_t indentation = 0; bool exhaustive; @@ -144,7 +189,7 @@ struct StringifierState bool hasSeen(const void* tv) { void* ttv = const_cast(tv); - if (seen.find(ttv) != seen.end()) + if (seen.contains(ttv)) return true; seen.insert(ttv); @@ -154,9 +199,9 @@ struct StringifierState void unsee(const void* tv) { void* ttv = const_cast(tv); - auto iter = seen.find(ttv); - if (iter != seen.end()) - seen.erase(iter); + + if (seen.contains(ttv)) + seen.erase(ttv); } std::string getName(TypeId ty) @@ -169,7 +214,7 @@ struct StringifierState for (int count = 0; count < 256; ++count) { std::string candidate = generateName(usedNames.size() + count); - if (!usedNames.count(candidate)) + if (!usedNames.contains(candidate)) { usedNames.insert(candidate); n = candidate; @@ -192,7 +237,7 @@ struct StringifierState for (int count = 0; count < 256; ++count) { std::string candidate = generateName(previousNameIndex + count); - if (!usedNames.count(candidate)) + if (!usedNames.contains(candidate)) { previousNameIndex += count; usedNames.insert(candidate); @@ -219,11 +264,15 @@ struct StringifierState ++count; emit(count); - emit("-"); - char buffer[16]; - uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); - snprintf(buffer, sizeof(buffer), "0x%x", s); - emit(buffer); + + if (FInt::DebugLuauVerboseTypeNames >= 3) + { + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } } void emit(TypeLevel level) @@ -301,18 +350,72 @@ struct TypeStringifier return; } - auto it = state.cycleNames.find(tv); - if (it != state.cycleNames.end()) + if (auto p = state.cycleNames.find(tv)) { - state.emit(it->second); + state.emit(*p); return; } Luau::visit( - [this, tv](auto&& t) { + [this, tv](auto&& t) + { return (*this)(tv, t); }, - tv->ty); + tv->ty + ); + } + + void emitKey(const std::string& name) + { + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } + state.emit(": "); + } + + void _newStringify(const std::string& name, const Property& prop) + { + bool comma = false; + if (prop.isShared()) + { + emitKey(name); + stringify(prop.type()); + return; + } + + if (prop.readTy) + { + state.emit("read "); + emitKey(name); + stringify(*prop.readTy); + comma = true; + } + if (prop.writeTy) + { + if (comma) + { + state.emit(","); + state.newline(); + } + + state.emit("write "); + emitKey(name); + stringify(*prop.writeTy); + } + } + + void stringify(const std::string& name, const Property& prop) + { + if (FFlag::LuauSolverV2) + return _newStringify(name, prop); + + emitKey(name); + stringify(prop.type()); } void stringify(TypePackId tp); @@ -364,17 +467,51 @@ struct TypeStringifier state.emit(">"); } - void operator()(TypeId ty, const Unifiable::Free& ftv) + void operator()(TypeId ty, const FreeType& ftv) { state.result.invalid = true; - if (FFlag::DebugLuauVerboseTypeNames) + + // TODO: ftv.lowerBound and ftv.upperBound should always be non-nil when + // the new solver is used. This can be replaced with an assert. + if (FFlag::LuauSolverV2 && ftv.lowerBound && ftv.upperBound) + { + const TypeId lowerBound = follow(ftv.lowerBound); + const TypeId upperBound = follow(ftv.upperBound); + if (get(lowerBound) && get(upperBound)) + { + state.emit("'"); + state.emit(state.getName(ty)); + } + else + { + state.emit("("); + if (!get(lowerBound)) + { + stringify(lowerBound); + state.emit(" <: "); + } + state.emit("'"); + state.emit(state.getName(ty)); + + if (!get(upperBound)) + { + state.emit(" <: "); + stringify(upperBound); + } + state.emit(")"); + } + return; + } + + if (FInt::DebugLuauVerboseTypeNames >= 1) state.emit("free-"); + state.emit(state.getName(ty)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) state.emitLevel(ftv.scope); else state.emit(ftv.level); @@ -388,6 +525,9 @@ struct TypeStringifier void operator()(TypeId ty, const GenericType& gtv) { + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit("gen-"); + if (gtv.explicitName) { state.usedNames.insert(gtv.name); @@ -397,10 +537,10 @@ struct TypeStringifier else state.emit(state.getName(ty)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) state.emitLevel(gtv.scope); else state.emit(gtv.level); @@ -440,6 +580,9 @@ struct TypeStringifier case PrimitiveType::Thread: state.emit("thread"); return; + case PrimitiveType::Buffer: + state.emit("buffer"); + return; case PrimitiveType::Function: state.emit("function"); return; @@ -500,6 +643,12 @@ struct TypeStringifier state.emit(">"); } + if (FFlag::LuauSolverV2) + { + if (ftv.isCheckedFunction) + state.emit("@checked "); + } + state.emit("("); if (state.opts.functionTypeArguments) @@ -577,16 +726,33 @@ struct TypeStringifier std::string openbrace = "@@@"; std::string closedbrace = "@@@?!"; - switch (state.opts.hideTableKind ? TableState::Unsealed : ttv.state) + switch (state.opts.hideTableKind ? (FFlag::LuauSolverV2 ? TableState::Sealed : TableState::Unsealed) : ttv.state) { case TableState::Sealed: - state.result.invalid = true; - openbrace = "{|"; - closedbrace = "|}"; + if (FFlag::LuauSolverV2) + { + openbrace = "{"; + closedbrace = "}"; + } + else + { + state.result.invalid = true; + openbrace = "{|"; + closedbrace = "|}"; + } break; case TableState::Unsealed: - openbrace = "{"; - closedbrace = "}"; + if (FFlag::LuauSolverV2) + { + state.result.invalid = true; + openbrace = "{|"; + closedbrace = "|}"; + } + else + { + openbrace = "{"; + closedbrace = "}"; + } break; case TableState::Free: state.result.invalid = true; @@ -647,16 +813,8 @@ struct TypeStringifier break; } - if (isIdentifier(name)) - state.emit(name); - else - { - state.emit("[\""); - state.emit(escape(name)); - state.emit("\"]"); - } - state.emit(": "); - stringify(prop.type); + stringify(name, prop); + comma = true; ++index; } @@ -698,6 +856,11 @@ struct TypeStringifier state.emit("any"); } + void operator()(TypeId, const NoRefineType&) + { + state.emit("*no-refine*"); + } + void operator()(TypeId, const UnionType& uv) { if (state.hasSeen(&uv)) @@ -727,7 +890,7 @@ struct TypeStringifier std::string saved = std::move(state.result.name); - bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); + bool needParens = !state.cycleNames.contains(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -750,11 +913,15 @@ struct TypeStringifier state.emit("("); bool first = true; + bool shouldPlaceOnNewlines = results.size() > state.opts.compositeTypesSingleLineLimit; for (std::string& ss : results) { if (!first) { - state.newline(); + if (shouldPlaceOnNewlines) + state.newline(); + else + state.emit(" "); state.emit("| "); } state.emit(ss); @@ -774,7 +941,7 @@ struct TypeStringifier } } - void operator()(TypeId, const IntersectionType& uv) + void operator()(TypeId ty, const IntersectionType& uv) { if (state.hasSeen(&uv)) { @@ -790,7 +957,7 @@ struct TypeStringifier std::string saved = std::move(state.result.name); - bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); + bool needParens = !state.cycleNames.contains(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -810,11 +977,15 @@ struct TypeStringifier std::sort(results.begin(), results.end()); bool first = true; + bool shouldPlaceOnNewlines = results.size() > state.opts.compositeTypesSingleLineLimit || isOverloadedFunction(ty); for (std::string& ss : results) { if (!first) { - state.newline(); + if (shouldPlaceOnNewlines) + state.newline(); + else + state.emit(" "); state.emit("& "); } state.emit(ss); @@ -830,8 +1001,15 @@ struct TypeStringifier void operator()(TypeId, const LazyType& ltv) { - state.result.invalid = true; - state.emit("lazy?"); + if (TypeId unwrapped = ltv.unwrapped.load()) + { + stringify(unwrapped); + } + else + { + state.result.invalid = true; + state.emit("lazy?"); + } } void operator()(TypeId, const UnknownType& ttv) @@ -860,6 +1038,37 @@ struct TypeStringifier if (parens) state.emit(")"); } + + void operator()(TypeId, const TypeFunctionInstanceType& tfitv) + { + if (tfitv.userFuncName) // Special stringification for user-defined type functions + state.emit(tfitv.userFuncName->value); + else + state.emit(tfitv.function->name); + + state.emit("<"); + + bool comma = false; + for (TypeId ty : tfitv.typeArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(ty); + } + + for (TypePackId tp : tfitv.packArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(tp); + } + + state.emit(">"); + } }; struct TypePackStringifier @@ -900,18 +1109,19 @@ struct TypePackStringifier return; } - auto it = state.cycleTpNames.find(tp); - if (it != state.cycleTpNames.end()) + if (auto p = state.cycleTpNames.find(tp)) { - state.emit(it->second); + state.emit(*p); return; } Luau::visit( - [this, tp](auto&& t) { + [this, tp](auto&& t) + { return (*this)(tp, t); }, - tp->ty); + tp->ty + ); } void operator()(TypePackId, const TypePack& tp) @@ -947,7 +1157,7 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { TypePackId tail = follow(*tp.tail); - if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + if (auto vtp = get(tail); !vtp || (FInt::DebugLuauVerboseTypeNames < 1 && !vtp->hidden)) { if (first) first = false; @@ -970,7 +1180,7 @@ struct TypePackStringifier void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); - if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + if (FInt::DebugLuauVerboseTypeNames >= 1 && pack.hidden) { state.emit("*hidden*"); } @@ -979,6 +1189,9 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit("gen-"); + if (pack.explicitName) { state.usedNames.insert(pack.name); @@ -990,28 +1203,29 @@ struct TypePackStringifier state.emit(state.getName(tp)); } - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) state.emitLevel(pack.scope); else state.emit(pack.level); } + state.emit("..."); } void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 1) state.emit("free-"); state.emit(state.getName(tp)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) state.emitLevel(pack.scope); else state.emit(pack.level); @@ -1031,6 +1245,33 @@ struct TypePackStringifier state.emit(btp.index); state.emit("*"); } + + void operator()(TypePackId, const TypeFunctionInstanceTypePack& tfitp) + { + state.emit(tfitp.function->name); + state.emit("<"); + + bool comma = false; + for (TypeId p : tfitp.typeArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(p); + } + + for (TypePackId p : tfitp.packArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(p); + } + + state.emit(">"); + } }; void TypeStringifier::stringify(TypePackId tp) @@ -1045,8 +1286,13 @@ void TypeStringifier::stringify(TypePackId tpid, const std::vector& cycles, const std::set& cycleTPs, - std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) +static void assignCycleNames( + const std::set& cycles, + const std::set& cycleTPs, + DenseHashMap& cycleNames, + DenseHashMap& cycleTpNames, + bool exhaustive +) { int nextIndex = 1; @@ -1058,9 +1304,14 @@ static void assignCycleNames(const std::set& cycles, const std::set(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) + if (std::find_if( + ttv->instantiatedTypeParams.begin(), + ttv->instantiatedTypeParams.end(), + [&](auto&& el) + { + return cycles.count(follow(el)); + } + ) != ttv->instantiatedTypeParams.end()) cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; continue; @@ -1140,9 +1391,8 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) * * t1 where t1 = the_whole_root_type */ - auto it = state.cycleNames.find(ty); - if (it != state.cycleNames.end()) - state.emit(it->second); + if (auto p = state.cycleNames.find(ty)) + state.emit(*p); else tvs.stringify(ty); @@ -1155,9 +1405,14 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort( + sortedCycleNames.begin(), + sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1168,18 +1423,25 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto&& t) + { return tvs(cycleTy, t); }, - cycleTy->ty); + cycleTy->ty + ); semi = true; } std::vector> sortedCycleTpNames(state.cycleTpNames.begin(), state.cycleTpNames.end()); - std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort( + sortedCycleTpNames.begin(), + sortedCycleTpNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); TypePackStringifier tps{state}; @@ -1191,10 +1453,12 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tps, cycleTy = cycleTp](auto&& t) { + [&tps, cycleTy = cycleTp](auto&& t) + { return tps(cycleTy, t); }, - cycleTp->ty); + cycleTp->ty + ); semi = true; } @@ -1234,13 +1498,12 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) * * t1 where t1 = the_whole_root_type */ - auto it = state.cycleTpNames.find(tp); - if (it != state.cycleTpNames.end()) - state.emit(it->second); + if (auto p = state.cycleTpNames.find(tp)) + state.emit(*p); else tvs.stringify(tp); - if (!cycles.empty()) + if (!cycles.empty() || !cycleTPs.empty()) { result.cycle = true; state.emit(" where "); @@ -1249,9 +1512,14 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort( + sortedCycleNames.begin(), + sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1262,10 +1530,42 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto t) { + [&tvs, cycleTy = cycleTy](auto t) + { return tvs(cycleTy, t); }, - cycleTy->ty); + cycleTy->ty + ); + + semi = true; + } + + std::vector> sortedCycleTpNames{state.cycleTpNames.begin(), state.cycleTpNames.end()}; + std::sort( + sortedCycleTpNames.begin(), + sortedCycleTpNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); + + TypePackStringifier tps{tvs.state}; + + for (const auto& [cycleTp, name] : sortedCycleTpNames) + { + if (semi) + state.emit(" ; "); + + state.emit(name); + state.emit(" = "); + Luau::visit( + [&tps, cycleTp = cycleTp](auto t) + { + return tps(cycleTp, t); + }, + cycleTp->ty + ); semi = true; } @@ -1451,12 +1751,26 @@ std::string generateName(size_t i) return n; } +std::string toStringVector(const std::vector& types, ToStringOptions& opts) +{ + std::string s; + for (TypeId ty : types) + { + if (!s.empty()) + s += ", "; + s += toString(ty, opts); + } + return s; +} + std::string toString(const Constraint& constraint, ToStringOptions& opts) { - auto go = [&opts](auto&& c) -> std::string { + auto go = [&opts](auto&& c) -> std::string + { using T = std::decay_t; - auto tos = [&opts](auto&& a) { + auto tos = [&opts](auto&& a) + { return toString(a, opts); }; @@ -1470,7 +1784,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { std::string subStr = tos(c.subPack); std::string superStr = tos(c.superPack); - return subStr + " <: " + superStr; + return subStr + " <...: " + superStr; } else if constexpr (std::is_same_v) { @@ -1478,33 +1792,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string superStr = tos(c.sourceType); return subStr + " ~ gen " + superStr; } - else if constexpr (std::is_same_v) - { - std::string subStr = tos(c.subType); - std::string superStr = tos(c.superType); - return subStr + " ~ inst " + superStr; - } - else if constexpr (std::is_same_v) - { - std::string resultStr = tos(c.resultType); - std::string operandStr = tos(c.operandType); - - return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">"; - } - else if constexpr (std::is_same_v) - { - std::string resultStr = tos(c.resultType); - std::string leftStr = tos(c.leftType); - std::string rightStr = tos(c.rightType); - - return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; - } else if constexpr (std::is_same_v) { std::string iteratorStr = tos(c.iterator); - std::string variableStr = tos(c.variables); + std::string variableStr = toStringVector(c.variables, opts); - return variableStr + " ~ Iterate<" + iteratorStr + ">"; + return variableStr + " ~ iterate " + iteratorStr; } else if constexpr (std::is_same_v) { @@ -1518,37 +1811,41 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) { - return "call " + tos(c.fn) + " with { result = " + tos(c.result) + " }"; + return "call " + tos(c.fn) + "( " + tos(c.argsPack) + " )" + " with { result = " + tos(c.result) + " }"; } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + tos(c.multitonType); + return "function_check " + tos(c.fn) + " " + tos(c.argsPack); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\""; + if (c.expectedType) + return "prim " + tos(c.freeType) + "[expected: " + tos(*c.expectedType) + "] as " + tos(c.primitiveType); + else + return "prim " + tos(c.freeType) + " as " + tos(c.primitiveType); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; - return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); + return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\" ctx=" + std::to_string(int(c.context)); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType); + return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) + return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); + else if constexpr (std::is_same_v) + return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); + else if constexpr (std::is_same_v) + return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack); + else if constexpr (std::is_same_v) + return "reduce " + tos(c.ty); + else if constexpr (std::is_same_v) { - std::string result = tos(c.resultType); - std::string discriminant = tos(c.discriminantType); - - if (c.negated) - return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; - else - return result + " ~ if isSingleton D then D else unknown where D = " + discriminant; + return "reduce " + tos(c.tp); } - else if constexpr (std::is_same_v) - return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); + else if constexpr (std::is_same_v) + return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType); else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; @@ -1556,6 +1853,11 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) return visit(go, constraint.c); } +std::string toString(const Constraint& constraint) +{ + return toString(constraint, ToStringOptions{}); +} + std::string dump(const Constraint& c) { ToStringOptions opts; @@ -1603,9 +1905,30 @@ std::string toString(const Position& position) return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; } -std::string toString(const Location& location) +std::string toString(const Location& location, int offset, bool useBegin) +{ + return "(" + std::to_string(location.begin.line + offset) + ", " + std::to_string(location.begin.column + offset) + ") - (" + + std::to_string(location.end.line + offset) + ", " + std::to_string(location.end.column + offset) + ")"; +} + +std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts) { - return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; + if (const TypeId* ty = get(tyOrTp)) + return toString(*ty, opts); + else if (const TypePackId* tp = get(tyOrTp)) + return toString(*tp, opts); + else + LUAU_UNREACHABLE(); +} + +std::string dump(const TypeOrPack& tyOrTp) +{ + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(tyOrTp, opts); + printf("%s\n", s.c_str()); + return s; } } // namespace Luau diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 41442f5bc..e29cc4682 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,7 @@ #include #include + namespace { bool isIdentifierStartChar(char c) @@ -27,8 +28,8 @@ bool isIdentifierChar(char c) return isIdentifierStartChar(c) || isDigit(c); } -const std::vector keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", - "not", "or", "repeat", "return", "then", "true", "until", "while"}; +const std::vector keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", + "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"}; } // namespace @@ -476,11 +477,11 @@ struct Printer case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: case AstExprBinary::CompareLt: case AstExprBinary::CompareGt: - case AstExprBinary::DivInt: case AstExprBinary::MaxOf: case AstExprBinary::MinOf: case AstExprBinary::BinAnd: @@ -504,6 +505,8 @@ struct Printer writer.maybeSpace(a->right->location.begin, 4); writer.keyword(toString(a->op)); break; + default: + LUAU_ASSERT(!"Unknown Op"); } visualize(*a->right); @@ -770,6 +773,10 @@ struct Printer writer.maybeSpace(a->value->location.begin, 2); writer.symbol("/="); break; + case AstExprBinary::FloorDiv: + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("//="); + break; case AstExprBinary::Mod: writer.maybeSpace(a->value->location.begin, 2); writer.symbol("%="); @@ -853,6 +860,15 @@ struct Printer visualizeTypeAnnotation(*a->type); } } + else if (const auto& t = program.as()) + { + if (writeTypes) + { + writer.keyword("type function"); + writer.identifier(t->name.value); + visualizeFunctionBody(*t->body); + } + } else if (const auto& a = program.as()) { writer.symbol("(error-stat"); @@ -1191,11 +1207,11 @@ std::string toString(AstNode* node) Printer printer(writer); printer.writeTypes = true; - if (auto statNode = dynamic_cast(node)) + if (auto statNode = node->asStat()) printer.visualize(*statNode); - else if (auto exprNode = dynamic_cast(node)) + else if (auto exprNode = node->asExpr()) printer.visualize(*exprNode); - else if (auto typeNode = dynamic_cast(node)) + else if (auto typeNode = node->asType()) printer.visualizeTypeAnnotation(*typeNode); return writer.str(); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 5040952e8..e272c6610 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TxnLog.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeArena.h" #include "Luau/TypePack.h" @@ -71,20 +72,29 @@ const TxnLog* TxnLog::empty() void TxnLog::concat(TxnLog rhs) { for (auto& [ty, rep] : rhs.typeVarChanges) + { + if (rep->dead) + continue; typeVarChanges[ty] = std::move(rep); + } for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) { for (auto& [ty, rightRep] : rhs.typeVarChanges) { - if (auto leftRep = typeVarChanges.find(ty)) + if (rightRep->dead) + continue; + + if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); + TypeId leftTy = arena->addType((*leftRep)->pending.clone()); + TypeId rightTy = arena->addType(rightRep->pending.clone()); typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}}; } else @@ -93,17 +103,80 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) { + /* + * Check for cycles. + * + * We must not combine a log entry that binds 'a to 'b with a log that + * binds 'b to 'a. + * + * Of the two, identify the one with the 'bigger' scope and eliminate the + * entry that rebinds it. + */ + for (const auto& [rightTy, rightRep] : rhs.typeVarChanges) + { + if (rightRep->dead) + continue; + + // We explicitly use get_if here because we do not wish to do anything + // if the uncommitted type is already bound to something else. + const FreeType* rf = get_if(&rightTy->ty); + if (!rf) + continue; + + const BoundType* rb = Luau::get(&rightRep->pending); + if (!rb) + continue; + + const TypeId leftTy = rb->boundTo; + const FreeType* lf = get_if(&leftTy->ty); + if (!lf) + continue; + + auto leftRep = typeVarChanges.find(leftTy); + if (!leftRep) + continue; + + if ((*leftRep)->dead) + continue; + + const BoundType* lb = Luau::get(&(*leftRep)->pending); + if (!lb) + continue; + + if (lb->boundTo == rightTy) + { + // leftTy has been bound to rightTy, but rightTy has also been bound + // to leftTy. We find the one that belongs to the more deeply nested + // scope and remove it from the log. + const bool discardLeft = useScopes ? subsumes(lf->scope, rf->scope) : lf->level.subsumes(rf->level); + + if (discardLeft) + (*leftRep)->dead = true; + else + rightRep->dead = true; + } + } + for (auto& [ty, rightRep] : rhs.typeVarChanges) { - if (auto leftRep = typeVarChanges.find(ty)) + if (rightRep->dead) + continue; + + if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); - typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; + TypeId leftTy = arena->addType((*leftRep)->pending.clone()); + TypeId rightTy = arena->addType(rightRep->pending.clone()); + + if (follow(leftTy) == follow(rightTy)) + typeVarChanges[ty] = std::move(rightRep); + else + typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; } else typeVarChanges[ty] = std::move(rightRep); @@ -111,12 +184,19 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::commit() { + LUAU_ASSERT(!radioactive); + for (auto& [ty, rep] : typeVarChanges) - asMutable(ty)->reassign(rep.get()->pending); + { + if (!rep->dead) + asMutable(ty)->reassign(rep.get()->pending); + } for (auto& [tp, rep] : typePackChanges) asMutable(tp)->reassign(rep.get()->pending); @@ -135,11 +215,16 @@ TxnLog TxnLog::inverse() TxnLog inversed(sharedSeen); for (auto& [ty, _rep] : typeVarChanges) - inversed.typeVarChanges[ty] = std::make_unique(*ty); + { + if (!_rep->dead) + inversed.typeVarChanges[ty] = std::make_unique(ty->clone()); + } for (auto& [tp, _rep] : typePackChanges) inversed.typePackChanges[tp] = std::make_unique(*tp); + inversed.radioactive = radioactive; + return inversed; } @@ -199,14 +284,15 @@ void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) PendingType* TxnLog::queue(TypeId ty) { - LUAU_ASSERT(!ty->persistent); + if (ty->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. auto& pending = typeVarChanges[ty]; - if (!pending) + if (!pending || (*pending).dead) { - pending = std::make_unique(*ty); + pending = std::make_unique(ty->clone()); pending->pending.owningArena = nullptr; } @@ -215,7 +301,8 @@ PendingType* TxnLog::queue(TypeId ty) PendingTypePack* TxnLog::queue(TypePackId tp) { - LUAU_ASSERT(!tp->persistent); + if (tp->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. @@ -237,7 +324,7 @@ PendingType* TxnLog::pending(TypeId ty) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typeVarChanges.find(ty)) + if (auto it = current->typeVarChanges.find(ty); it && !(*it)->dead) return it->get(); } @@ -276,6 +363,7 @@ PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) { LUAU_ASSERT(get(ty)); + LUAU_ASSERT(ty != newBoundTo); PendingType* newTy = queue(ty); if (TableType* ttv = Luau::getMutable(newTy)) @@ -381,32 +469,44 @@ std::optional TxnLog::getLevel(TypeId ty) const TypeId TxnLog::follow(TypeId ty) const { - return Luau::follow(ty, [this](TypeId ty) { - PendingType* state = this->pending(ty); + return Luau::follow( + ty, + this, + [](const void* ctx, TypeId ty) -> TypeId + { + const TxnLog* self = static_cast(ctx); + PendingType* state = self->pending(ty); - if (state == nullptr) - return ty; + if (state == nullptr) + return ty; - // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants - // that normally apply. This is safe because follow will only call get<> - // on the returned pointer. - return const_cast(&state->pending); - }); + // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants + // that normally apply. This is safe because follow will only call get<> + // on the returned pointer. + return const_cast(&state->pending); + } + ); } TypePackId TxnLog::follow(TypePackId tp) const { - return Luau::follow(tp, [this](TypePackId tp) { - PendingTypePack* state = this->pending(tp); + return Luau::follow( + tp, + this, + [](const void* ctx, TypePackId tp) -> TypePackId + { + const TxnLog* self = static_cast(ctx); + PendingTypePack* state = self->pending(tp); - if (state == nullptr) - return tp; + if (state == nullptr) + return tp; - // Ugly: Fabricate a TypePackId that doesn't adhere to most of the - // invariants that normally apply. This is safe because follow will - // only call get<> on the returned pointer. - return const_cast(&state->pending); - }); + // Ugly: Fabricate a TypePackId that doesn't adhere to most of the + // invariants that normally apply. This is safe because follow will + // only call get<> on the returned pointer. + return const_cast(&state->pending); + } + ); } std::pair, std::vector> TxnLog::getChanges() const diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 4e2e19bd7..b87de3713 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -9,8 +9,10 @@ #include "Luau/RecursionCounter.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" +#include "Luau/VecDeque.h" #include "Luau/VisitType.h" #include @@ -25,67 +27,82 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); +// LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable +static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) +{ + TypeId unwrapped = ltv->unwrapped.load(); + + if (unwrapped) + return unwrapped; + + ltv->unwrap(*ltv); + unwrapped = ltv->unwrapped.load(); -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); + if (!unwrapped) + throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); + if (get(unwrapped)) + throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFind(MagicFunctionCallContext context); + return unwrapped; +} TypeId follow(TypeId t) { - return follow(t, [](TypeId t) { - return t; - }); + return follow(t, FollowOption::Normal); +} + +TypeId follow(TypeId t, FollowOption followOption) +{ + return follow( + t, + followOption, + nullptr, + [](const void*, TypeId t) -> TypeId + { + return t; + } + ); } -TypeId follow(TypeId t, std::function mapper) +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) { - auto advance = [&mapper](TypeId ty) -> std::optional { - if (auto btv = get>(mapper(ty))) + return follow(t, FollowOption::Normal, context, mapper); +} + +TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId)) +{ + auto advance = [followOption, context, mapper](TypeId ty) -> std::optional + { + TypeId mapped = mapper(context, ty); + + if (auto btv = get>(mapped)) return btv->boundTo; - else if (auto ttv = get(mapper(ty))) + + if (auto ttv = get(mapped)) return ttv->boundTo; - else - return std::nullopt; - }; - auto force = [&mapper](TypeId ty) { - if (auto ltv = get_if(&mapper(ty)->ty)) - { - TypeId res = ltv->thunk(); - if (get(res)) - throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); + if (auto ltv = getMutable(mapped); ltv && followOption != FollowOption::DisableLazyTypeThunks) + return unwrapLazy(ltv); - *asMutable(ty) = BoundType(res); - } + return std::nullopt; }; - force(t); - TypeId cycleTester = t; // Null once we've determined that there is no cycle if (auto a = advance(cycleTester)) cycleTester = *a; else return t; + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + while (true) { - force(t); auto a1 = advance(t); if (a1) t = *a1; @@ -118,7 +135,7 @@ std::vector flattenIntersection(TypeId ty) return {ty}; std::unordered_set seen; - std::deque queue{ty}; + VecDeque queue{ty}; std::vector result; @@ -204,6 +221,11 @@ bool isThread(TypeId ty) return isPrim(ty, PrimitiveType::Thread); } +bool isBuffer(TypeId ty) +{ + return isPrim(ty, PrimitiveType::Buffer); +} + bool isOptional(TypeId ty) { if (isNil(ty)) @@ -230,12 +252,22 @@ bool isTableIntersection(TypeId ty) return std::all_of(parts.begin(), parts.end(), getTableType); } +bool isTableUnion(TypeId ty) +{ + const UnionType* ut = get(follow(ty)); + if (!ut) + return false; + + return std::all_of(begin(ut), end(ut), getTableType); +} + bool isOverloadedFunction(TypeId ty) { if (!get(follow(ty))) return false; - auto isFunction = [](TypeId part) -> bool { + auto isFunction = [](TypeId part) -> bool + { return get(part); }; @@ -337,7 +369,16 @@ bool isSubset(const UnionType& super, const UnionType& sub) return true; } +bool hasPrimitiveTypeInIntersection(TypeId ty, PrimitiveType::Type primTy) +{ + TypeId tf = follow(ty); + if (isPrim(tf, primTy)) + return true; + for (auto t : flattenIntersection(tf)) + return isPrim(follow(t), primTy); + return false; +} // When typechecking an assignment `x = e`, we typecheck `x:T` and `e:U`, // then instantiate U if `isGeneric(U)` is true, and `maybeGeneric(T)` is false. bool isGeneric(TypeId ty) @@ -386,6 +427,13 @@ bool maybeSingleton(TypeId ty) for (TypeId option : utv) if (get(follow(option))) return true; + if (const IntersectionType* itv = get(ty)) + for (TypeId part : itv) + if (maybeSingleton(part)) // will i regret this? + return true; + if (const TypeFunctionInstanceType* tfit = get(ty)) + if (tfit->function->name == "keyof" || tfit->function->name == "rawkeyof") + return true; return false; } @@ -430,15 +478,108 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } +FreeType::FreeType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeType::FreeType(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeType::FreeType(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound) + : index(Unifiable::freshIndex()) + , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) +{ +} + +GenericType::GenericType() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericType::GenericType(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedType::BlockedType() - : index(++nextIndex) + : index(Unifiable::freshIndex()) +{ +} + +Constraint* BlockedType::getOwner() const { + return owner; } -int BlockedType::nextIndex = 0; +void BlockedType::setOwner(Constraint* newOwner) +{ + LUAU_ASSERT(owner == nullptr); + + if (owner != nullptr) + return; + + owner = newOwner; +} + +void BlockedType::replaceOwner(Constraint* newOwner) +{ + owner = newOwner; +} PendingExpansionType::PendingExpansionType( - std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) + std::optional prefix, + AstName name, + std::vector typeArguments, + std::vector packArguments +) : prefix(prefix) , name(name) , typeArguments(typeArguments) @@ -467,7 +608,13 @@ FunctionType::FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retT } FunctionType::FunctionType( - TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + TypeLevel level, + Scope* scope, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , level(level) , scope(scope) @@ -477,8 +624,14 @@ FunctionType::FunctionType( { } -FunctionType::FunctionType(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, - std::optional defn, bool hasSelf) +FunctionType::FunctionType( + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) @@ -488,8 +641,15 @@ FunctionType::FunctionType(std::vector generics, std::vector { } -FunctionType::FunctionType(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retTypes, std::optional defn, bool hasSelf) +FunctionType::FunctionType( + TypeLevel level, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) @@ -500,8 +660,16 @@ FunctionType::FunctionType(TypeLevel level, std::vector generics, std::v { } -FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retTypes, std::optional defn, bool hasSelf) +FunctionType::FunctionType( + TypeLevel level, + Scope* scope, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) @@ -513,6 +681,107 @@ FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector ge { } +Property::Property() {} + +Property::Property( + TypeId readTy, + bool deprecated, + const std::string& deprecatedSuggestion, + std::optional location, + const Tags& tags, + const std::optional& documentationSymbol, + std::optional typeLocation +) + : deprecated(deprecated) + , deprecatedSuggestion(deprecatedSuggestion) + , location(location) + , typeLocation(typeLocation) + , tags(tags) + , documentationSymbol(documentationSymbol) + , readTy(readTy) + , writeTy(readTy) +{ +} + +Property Property::readonly(TypeId ty) +{ + Property p; + p.readTy = ty; + return p; +} + +Property Property::writeonly(TypeId ty) +{ + Property p; + p.writeTy = ty; + return p; +} + +Property Property::rw(TypeId ty) +{ + return Property::rw(ty, ty); +} + +Property Property::rw(TypeId read, TypeId write) +{ + Property p; + p.readTy = read; + p.writeTy = write; + return p; +} + +Property Property::create(std::optional read, std::optional write) +{ + if (read && !write) + return Property::readonly(*read); + else if (!read && write) + return Property::writeonly(*write); + else + { + LUAU_ASSERT(read && write); + return Property::rw(*read, *write); + } +} + +TypeId Property::type() const +{ + LUAU_ASSERT(readTy); + return *readTy; +} + +void Property::setType(TypeId ty) +{ + readTy = ty; + if (FFlag::LuauSolverV2) + writeTy = ty; +} + +void Property::makeShared() +{ + if (writeTy) + writeTy = readTy; +} + +bool Property::isShared() const +{ + return readTy && writeTy && readTy == writeTy; +} + +bool Property::isReadOnly() const +{ + return readTy && !writeTy; +} + +bool Property::isWriteOnly() const +{ + return !readTy && writeTy; +} + +bool Property::isReadWrite() const +{ + return readTy && writeTy; +} + TableType::TableType(TableState state, TypeLevel level, Scope* scope) : state(state) , level(level) @@ -600,7 +869,7 @@ bool areEqual(SeenSet& seen, const TableType& lhs, const TableType& rhs) if (l->first != r->first) return false; - if (!areEqual(seen, *l->second.type, *r->second.type)) + if (!areEqual(seen, *l->second.type(), *r->second.type())) return false; ++l; ++r; @@ -730,9 +999,22 @@ Type& Type::operator=(const Type& rhs) return *this; } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, - std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes); +Type Type::clone() const +{ + return *this; +} + +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes +); + +TypeId makeStringMetatable(NotNull builtinTypes); // BuiltinDefinitions.cpp BuiltinTypes::BuiltinTypes() : arena(new TypeArena) @@ -742,8 +1024,9 @@ BuiltinTypes::BuiltinTypes() , stringType(arena->addType(Type{PrimitiveType{PrimitiveType::String}, /*persistent*/ true})) , booleanType(arena->addType(Type{PrimitiveType{PrimitiveType::Boolean}, /*persistent*/ true})) , threadType(arena->addType(Type{PrimitiveType{PrimitiveType::Thread}, /*persistent*/ true})) + , bufferType(arena->addType(Type{PrimitiveType{PrimitiveType::Buffer}, /*persistent*/ true})) , functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true})) - , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true})) + , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}, {}}, /*persistent*/ true})) , tableType(arena->addType(Type{PrimitiveType{PrimitiveType::Table}, /*persistent*/ true})) , emptyTableType(arena->addType(Type{TableType{TableState::Sealed, TypeLevel{}, nullptr}, /*persistent*/ true})) , trueType(arena->addType(Type{SingletonType{BooleanSingleton{true}}, /*persistent*/ true})) @@ -752,19 +1035,18 @@ BuiltinTypes::BuiltinTypes() , unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true})) , neverType(arena->addType(Type{NeverType{}, /*persistent*/ true})) , errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true})) + , noRefineType(arena->addType(Type{NoRefineType{}, /*persistent*/ true})) , falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true})) , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) , optionalStringType(arena->addType(Type{UnionType{{stringType, nilType}}, /*persistent*/ true})) + , emptyTypePack(arena->addTypePack(TypePackVar{TypePack{{}}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) + , unknownTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{unknownType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { - TypeId stringMetatable = makeStringMetatable(); - asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; - persist(stringMetatable); - freeze(*arena); } @@ -780,105 +1062,29 @@ BuiltinTypes::~BuiltinTypes() FFlag::DebugLuauFreezeArena.value = prevFlag; } -TypeId BuiltinTypes::makeStringMetatable() -{ - const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); - const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); - - const TypePackId oneStringPack = arena->addTypePack({stringType}); - const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); - - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); - const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); - - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); - - const TypeId replArgType = - arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); - const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); - const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - - const TypeId matchFunc = arena->addType( - FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - - const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - - TableType::Props stringLib = { - {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; - - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); -} - -TypeId BuiltinTypes::errorRecoveryType() +TypeId BuiltinTypes::errorRecoveryType() const { return errorType; } -TypePackId BuiltinTypes::errorRecoveryTypePack() +TypePackId BuiltinTypes::errorRecoveryTypePack() const { return errorTypePack; } -TypeId BuiltinTypes::errorRecoveryType(TypeId guess) +TypeId BuiltinTypes::errorRecoveryType(TypeId guess) const { return guess; } -TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) +TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) const { return guess; } void persist(TypeId ty) { - std::deque queue{ty}; + VecDeque queue{ty}; while (!queue.empty()) { @@ -902,7 +1108,7 @@ void persist(TypeId ty) LUAU_ASSERT(ttv->state != TableState::Free);//nico@gideros && ttv->state != TableState::Unsealed); for (const auto& [_name, prop] : ttv->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); if (ttv->indexer) { @@ -913,7 +1119,7 @@ void persist(TypeId ty) else if (auto ctv = get(t)) { for (const auto& [_name, prop] : ctv->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); } else if (auto utv = get(t)) { @@ -933,6 +1139,14 @@ void persist(TypeId ty) else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) { } + else if (auto tfit = get(t)) + { + for (auto ty : tfit->typeArguments) + queue.push_back(ty); + + for (auto tp : tfit->packArguments) + persist(tp); + } else { LUAU_ASSERT(!"TypeId is not supported in a persist call"); @@ -961,6 +1175,14 @@ void persist(TypePackId tp) else if (get(tp)) { } + else if (auto tfitp = get(tp)) + { + for (auto ty : tfitp->typeArguments) + persist(ty); + + for (auto tp : tfitp->packArguments) + persist(tp); + } else { LUAU_ASSERT(!"TypePackId is not supported in a persist call"); @@ -971,7 +1193,7 @@ const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) return &ftv->level; else if (auto ttv = get(ty)) return &ttv->level; @@ -990,7 +1212,7 @@ std::optional getLevel(TypePackId tp) { tp = follow(tp); - if (auto ftv = get(tp)) + if (auto ftv = get(tp)) return ftv->level; else return std::nullopt; @@ -1061,434 +1283,9 @@ IntersectionTypeIterator end(const IntersectionType* itv) return IntersectionTypeIterator{}; } -static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope) { - const char* options = "cdiouxXeEfgGqs*"; - - std::vector result; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - i++; - - if (i < size && data[i] == '%') - continue; - - // we just ignore all characters (including flags/precision) up until first alphabetic character - while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) - i++; - - if (i == size) - break; - - if (data[i] == 'q' || data[i] == 's') - result.push_back(builtinTypes->stringType); - else if (data[i] == '*') - result.push_back(builtinTypes->unknownType); - else if (strchr(options, data[i])) - result.push_back(builtinTypes->numberType); - else - result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); - } - } - - return result; -} - -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* fmt = nullptr; - if (auto index = expr.func->as(); index && expr.self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!expr.self && expr.args.size > 0) - fmt = expr.args.data[0]->as(); - - if (!fmt) - return std::nullopt; - - std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(paramPack); - - size_t paramOffset = 1; - size_t dataOffset = expr.self ? 0 : 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - - typechecker.unify(params[i + paramOffset], expected[i], scope, location); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - return WithPredicate{arena.addTypePack({typechecker.stringType})}; -} - -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) -{ - TypeArena* arena = context.solver->arena; - - AstExprConstantString* fmt = nullptr; - if (auto index = context.callSite->func->as(); index && context.callSite->self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!context.callSite->self && context.callSite->args.size > 0) - fmt = context.callSite->args.data[0]->as(); - - if (!fmt) - return false; - - std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(context.arguments); - - size_t paramOffset = 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); - asMutable(context.result)->ty.emplace(resultPack); - - return true; -} - -static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) -{ - std::vector result; - int depth = 0; - bool parsingSet = false; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - ++i; - if (!parsingSet && i < size && data[i] == 'b') - i += 2; - } - else if (!parsingSet && data[i] == '[') - { - parsingSet = true; - if (i + 1 < size && data[i + 1] == ']') - i += 1; - } - else if (parsingSet && data[i] == ']') - { - parsingSet = false; - } - else if (data[i] == '(') - { - if (parsingSet) - continue; - - if (i + 1 < size && data[i + 1] == ')') - { - i++; - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalNumberType : builtinTypes->numberType); - continue; - } - - ++depth; - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); - } - else if (data[i] == ')') - { - if (parsingSet) - continue; - - --depth; - - if (depth < 0) - break; - } - } - - if (depth != 0 || parsingSet) - return std::vector(); - - if (result.empty()) - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); - - return result; -} - -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() != 2) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t index = expr.self ? 0 : 1; - if (expr.args.size > index) - pattern = expr.args.data[index]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypePackId emptyPack = arena.addTypePack({}); - const TypePackId returnList = arena.addTypePack(returnTypes); - const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); - return WithPredicate{arena.addTypePack({iteratorType})}; -} - -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() != 2) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t index = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > index) - pattern = context.callSite->args.data[index]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId returnList = arena->addTypePack(returnTypes); - const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); - const TypePackId resTypePack = arena->addTypePack({iteratorType}); - asMutable(context.result)->ty.emplace(resTypePack); - - return true; -} - -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 3) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() == 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 3) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); - - const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() == 3 && context.callSite->args.size > initIndex) - context.solver->unify(params[2], optionalNumber, context.solver->rootScope); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - - return true; -} - -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 4) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - bool plain = false; - size_t plainIndex = expr.self ? 2 : 3; - if (expr.args.size > plainIndex) - { - AstExprConstantBool* p = expr.args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - } - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() >= 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - if (params.size() == 4 && expr.args.size > plainIndex) - typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionFind(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 4) - return false; - - TypeArena* arena = context.solver->arena; - NotNull builtinTypes = context.solver->builtinTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - bool plain = false; - size_t plainIndex = context.callSite->self ? 2 : 3; - if (context.callSite->args.size > plainIndex) - { - AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - } - - context.solver->unify(params[0], builtinTypes->stringType, context.solver->rootScope); - - const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() >= 3 && context.callSite->args.size > initIndex) - context.solver->unify(params[2], optionalNumber, context.solver->rootScope); - - if (params.size() == 4 && context.callSite->args.size > plainIndex) - context.solver->unify(params[3], optionalBoolean, context.solver->rootScope); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - return true; + return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType}); } std::vector filterMap(TypeId type, TypeIdPredicate predicate) @@ -1590,4 +1387,11 @@ bool GenericTypePackDefinition::operator==(const GenericTypePackDefinition& rhs) return tp == rhs.tp && defaultValue == rhs.defaultValue; } +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceType(Type* ty, TypeId& tyArg) +{ + LUAU_ASSERT(ty != follow(tyArg)); + return &ty->ty.emplace(tyArg); +} + } // namespace Luau diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index ed51517ea..6cf81471d 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -94,6 +94,26 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) return allocated; } +TypeId TypeArena::addTypeFunction(const TypeFunction& function, std::initializer_list types) +{ + return addType(TypeFunctionInstanceType{function, std::move(types)}); +} + +TypeId TypeArena::addTypeFunction(const TypeFunction& function, std::vector typeArguments, std::vector packArguments) +{ + return addType(TypeFunctionInstanceType{function, std::move(typeArguments), std::move(packArguments)}); +} + +TypePackId TypeArena::addTypePackFunction(const TypePackFunction& function, std::initializer_list types) +{ + return addTypePack(TypeFunctionInstanceTypePack{NotNull{&function}, std::move(types)}); +} + +TypePackId TypeArena::addTypePackFunction(const TypePackFunction& function, std::vector typeArguments, std::vector packArguments) +{ + return addTypePack(TypeFunctionInstanceTypePack{NotNull{&function}, std::move(typeArguments), std::move(packArguments)}); +} + void freeze(TypeArena& arena) { if (!FFlag::DebugLuauFreezeArena) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f9a162056..a28ff9871 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -9,6 +9,7 @@ #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFunction.h" #include @@ -35,7 +36,21 @@ using SyntheticNames = std::unordered_map; namespace Luau { -static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericType& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericTypePack& gen) { size_t s = syntheticNames->size(); char*& n = (*syntheticNames)[&gen]; @@ -80,28 +95,35 @@ class TypeRehydrationVisitor switch (ptv.type) { case PrimitiveType::NilType: - return allocator->alloc(Location(), std::nullopt, AstName("nil")); + return allocator->alloc(Location(), std::nullopt, AstName("nil"), std::nullopt, Location()); case PrimitiveType::Boolean: - return allocator->alloc(Location(), std::nullopt, AstName("boolean")); + return allocator->alloc(Location(), std::nullopt, AstName("boolean"), std::nullopt, Location()); case PrimitiveType::Number: - return allocator->alloc(Location(), std::nullopt, AstName("number")); + return allocator->alloc(Location(), std::nullopt, AstName("number"), std::nullopt, Location()); case PrimitiveType::String: - return allocator->alloc(Location(), std::nullopt, AstName("string")); + return allocator->alloc(Location(), std::nullopt, AstName("string"), std::nullopt, Location()); case PrimitiveType::Thread: - return allocator->alloc(Location(), std::nullopt, AstName("thread")); + return allocator->alloc(Location(), std::nullopt, AstName("thread"), std::nullopt, Location()); + case PrimitiveType::Buffer: + return allocator->alloc(Location(), std::nullopt, AstName("buffer"), std::nullopt, Location()); + case PrimitiveType::Function: + return allocator->alloc(Location(), std::nullopt, AstName("function"), std::nullopt, Location()); + case PrimitiveType::Table: + return allocator->alloc(Location(), std::nullopt, AstName("table"), std::nullopt, Location()); default: + LUAU_ASSERT(false); // this should be unreachable. return nullptr; } } AstType* operator()(const BlockedType& btv) { - return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); + return allocator->alloc(Location(), std::nullopt, AstName("*blocked*"), std::nullopt, Location()); } AstType* operator()(const PendingExpansionType& petv) { - return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); + return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*"), std::nullopt, Location()); } AstType* operator()(const SingletonType& stv) @@ -121,8 +143,14 @@ class TypeRehydrationVisitor AstType* operator()(const AnyType&) { - return allocator->alloc(Location(), std::nullopt, AstName("any")); + return allocator->alloc(Location(), std::nullopt, AstName("any"), std::nullopt, Location()); + } + + AstType* operator()(const NoRefineType&) + { + return allocator->alloc(Location(), std::nullopt, AstName("*no-refine*"), std::nullopt, Location()); } + AstType* operator()(const TableType& ttv) { RecursionCounter counter(&count); @@ -143,15 +171,17 @@ class TypeRehydrationVisitor parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; } - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); + return allocator->alloc( + Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters + ); } if (hasSeen(&ttv)) { if (ttv.name) - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str())); + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location()); else - return allocator->alloc(Location(), std::nullopt, AstName("")); + return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); } AstArray props; @@ -165,7 +195,7 @@ class TypeRehydrationVisitor char* name = allocateString(*allocator, propName); props.data[idx].name = AstName(name); - props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].type = Luau::visit(*this, prop.type()->ty); props.data[idx].location = Location(); idx++; } @@ -194,7 +224,7 @@ class TypeRehydrationVisitor char* name = allocateString(*allocator, ctv.name); if (!options.expandClassProps || hasSeen(&ctv) || count > 1) - return allocator->alloc(Location(), std::nullopt, AstName{name}); + return allocator->alloc(Location(), std::nullopt, AstName{name}, std::nullopt, Location()); AstArray props; props.size = ctv.props.size(); @@ -206,12 +236,22 @@ class TypeRehydrationVisitor char* name = allocateString(*allocator, propName); props.data[idx].name = AstName{name}; - props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].type = Luau::visit(*this, prop.type()->ty); props.data[idx].location = Location(); idx++; } - return allocator->alloc(Location(), props); + AstTableIndexer* indexer = nullptr; + if (ctv.indexer) + { + RecursionCounter counter(&count); + + indexer = allocator->alloc(); + indexer->indexType = Luau::visit(*this, ctv.indexer->indexType->ty); + indexer->resultType = Luau::visit(*this, ctv.indexer->indexResultType->ty); + } + + return allocator->alloc(Location(), props, indexer); } AstType* operator()(const FunctionType& ftv) @@ -219,7 +259,7 @@ class TypeRehydrationVisitor RecursionCounter counter(&count); if (hasSeen(&ftv)) - return allocator->alloc(Location(), std::nullopt, AstName("")); + return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); AstArray generics; generics.size = ftv.generics.size(); @@ -237,7 +277,7 @@ class TypeRehydrationVisitor size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - if (auto gtv = get(*it)) + if (auto gtv = get(*it)) genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } @@ -286,23 +326,26 @@ class TypeRehydrationVisitor retTailAnnotation = rehydrate(*retTail); return allocator->alloc( - Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); + Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation} + ); } AstType* operator()(const Unifiable::Error&) { - return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); + return allocator->alloc(Location(), std::nullopt, AstName("Unifiable"), std::nullopt, Location()); } AstType* operator()(const GenericType& gtv) { - return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); + return allocator->alloc( + Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location() + ); } AstType* operator()(const Unifiable::Bound& bound) { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(const FreeType& ftv) + AstType* operator()(const FreeType& ft) { - return allocator->alloc(Location(), std::nullopt, AstName("free")); + return allocator->alloc(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); } AstType* operator()(const UnionType& uv) { @@ -328,21 +371,28 @@ class TypeRehydrationVisitor } AstType* operator()(const LazyType& ltv) { - return allocator->alloc(Location(), std::nullopt, AstName("")); + if (TypeId unwrapped = ltv.unwrapped.load()) + return Luau::visit(*this, unwrapped->ty); + + return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); } AstType* operator()(const UnknownType& ttv) { - return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}); + return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}, std::nullopt, Location()); } AstType* operator()(const NeverType& ttv) { - return allocator->alloc(Location(), std::nullopt, AstName{"never"}); + return allocator->alloc(Location(), std::nullopt, AstName{"never"}, std::nullopt, Location()); } AstType* operator()(const NegationType& ntv) { // FIXME: do the same thing we do with ErrorType throw InternalCompilerError("Cannot convert NegationType into AstNode"); } + AstType* operator()(const TypeFunctionInstanceType& tfit) + { + return allocator->alloc(Location(), std::nullopt, AstName{tfit.function->name.c_str()}, std::nullopt, Location()); + } private: Allocator* allocator; @@ -413,6 +463,11 @@ class TypePackRehydrationVisitor return allocator->alloc(Location(), AstName("Unifiable")); } + AstTypePack* operator()(const TypeFunctionInstanceTypePack& tfitp) const + { + return allocator->alloc(Location(), AstName(tfitp.function->name.c_str())); + } + private: Allocator* allocator; SyntheticNames* syntheticNames; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index aacfd7295..2634b89ed 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -3,26 +3,36 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" -#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/DcrLogger.h" +#include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/InsertionOrderedMap.h" #include "Luau/Instantiation.h" #include "Luau/Metamethods.h" #include "Luau/Normalize.h" +#include "Luau/OverloadResolution.h" +#include "Luau/Subtyping.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFunctionReductionGuesser.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePack.h" +#include "Luau/TypePath.h" #include "Luau/TypeUtils.h" -#include "Luau/Unifier.h" +#include "Luau/TypeOrPack.h" +#include "Luau/VisitType.h" #include +#include +#include LUAU_FASTFLAG(DebugLuauMagicTypes) -LUAU_FASTFLAG(DebugLuauDontReduceTypes) - -LUAU_FASTFLAG(LuauNegatedClassTypes) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau { @@ -32,6 +42,7 @@ namespace Luau using PrintLineProc = void (*)(const std::string&); extern PrintLineProc luauPrintLine; + /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. * TypeChecker2 uses this to maintain knowledge about which scope encloses every * given AstNode. @@ -67,6 +78,37 @@ struct StackPusher } }; +struct PropertyTypes +{ + // a vector of all the types assigned to the given property. + std::vector typesOfProp; + + // a vector of all the types that are missing the given property. + std::vector missingProp; + + bool foundOneProp() const + { + return !typesOfProp.empty(); + } + + bool noneMissingProp() const + { + return missingProp.empty(); + } + + bool foundMissingProp() const + { + return !missingProp.empty(); + } +}; + +struct PropertyType +{ + NormalizationResult present; + std::optional result; +}; + + static std::optional getIdentifierOfBaseVar(AstExpr* node) { if (AstExprGlobal* expr = node->as()) @@ -84,2006 +126,3053 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) return std::nullopt; } -struct TypeChecker2 +template +bool areEquivalent(const T& a, const T& b) { - NotNull builtinTypes; - DcrLogger* logger; - InternalErrorReporter ice; // FIXME accept a pointer from Frontend - const SourceModule* sourceModule; - Module* module; - TypeArena testArena; - - std::vector> stack; + if (a.function != b.function) + return false; - UnifierSharedState sharedState{&ice}; - Normalizer normalizer{&testArena, builtinTypes, NotNull{&sharedState}}; + if (a.typeArguments.size() != b.typeArguments.size() || a.packArguments.size() != b.packArguments.size()) + return false; - TypeChecker2(NotNull builtinTypes, DcrLogger* logger, const SourceModule* sourceModule, Module* module) - : builtinTypes(builtinTypes) - , logger(logger) - , sourceModule(sourceModule) - , module(module) + for (size_t i = 0; i < a.typeArguments.size(); ++i) { + if (follow(a.typeArguments[i]) != follow(b.typeArguments[i])) + return false; } - std::optional pushStack(AstNode* node) + for (size_t i = 0; i < a.packArguments.size(); ++i) { - if (Scope** scope = module->astScopes.find(node)) - return StackPusher{stack, *scope}; - else - return std::nullopt; + if (follow(a.packArguments[i]) != follow(b.packArguments[i])) + return false; } - TypePackId lookupPack(AstExpr* expr) + return true; +} + +struct TypeFunctionFinder : TypeOnceVisitor +{ + DenseHashSet mentionedFunctions{nullptr}; + DenseHashSet mentionedFunctionPacks{nullptr}; + + bool visit(TypeId ty, const TypeFunctionInstanceType&) override { - // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. - // We'll just return anyType in these cases. Typechecking against any is very fast and this - // allows us not to think about this very much in the actual typechecking logic. - TypePackId* tp = module->astTypePacks.find(expr); - if (tp) - return follow(*tp); - else - return builtinTypes->anyTypePack; + mentionedFunctions.insert(ty); + return true; } - TypeId lookupType(AstExpr* expr) + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override { - // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. - // We'll just return anyType in these cases. Typechecking against any is very fast and this - // allows us not to think about this very much in the actual typechecking logic. - TypeId* ty = module->astTypes.find(expr); - if (ty) - return follow(*ty); + mentionedFunctionPacks.insert(tp); + return true; + } +}; + +struct InternalTypeFunctionFinder : TypeOnceVisitor +{ + DenseHashSet internalFunctions{nullptr}; + DenseHashSet internalPackFunctions{nullptr}; + DenseHashSet mentionedFunctions{nullptr}; + DenseHashSet mentionedFunctionPacks{nullptr}; - TypePackId* tp = module->astTypePacks.find(expr); - if (tp) - return flattenPack(*tp); + InternalTypeFunctionFinder(std::vector& declStack) + { + TypeFunctionFinder f; + for (TypeId fn : declStack) + f.traverse(fn); - return builtinTypes->anyType; + mentionedFunctions = std::move(f.mentionedFunctions); + mentionedFunctionPacks = std::move(f.mentionedFunctionPacks); } - TypeId lookupAnnotation(AstType* annotation) + bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override { - if (FFlag::DebugLuauMagicTypes) + bool hasGeneric = false; + + for (TypeId p : tfit.typeArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + for (TypePackId p : tfit.packArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + if (hasGeneric) { - if (auto ref = annotation->as(); ref && ref->name == "_luau_print" && ref->parameters.size > 0) + for (TypeId mentioned : mentionedFunctions) { - if (auto ann = ref->parameters.data[0].type) + const TypeFunctionInstanceType* mentionedTfit = get(mentioned); + LUAU_ASSERT(mentionedTfit); + if (areEquivalent(tfit, *mentionedTfit)) { - TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); - luauPrintLine(format( - "_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str())); - return follow(argTy); + return true; } } - } - TypeId* ty = module->astResolvedTypes.find(annotation); - LUAU_ASSERT(ty); - return follow(*ty); - } + internalFunctions.insert(ty); + } - TypePackId lookupPackAnnotation(AstTypePack* annotation) - { - TypePackId* tp = module->astResolvedTypePacks.find(annotation); - LUAU_ASSERT(tp); - return follow(*tp); + return true; } - TypePackId reconstructPack(AstArray exprs, TypeArena& arena) + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack& tfitp) override { - if (exprs.size == 0) - return arena.addTypePack(TypePack{{}, std::nullopt}); - - std::vector head; + bool hasGeneric = false; - for (size_t i = 0; i < exprs.size - 1; ++i) + for (TypeId p : tfitp.typeArguments) { - head.push_back(lookupType(exprs.data[i])); + if (get(follow(p))) + { + hasGeneric = true; + break; + } } - TypePackId tail = lookupPack(exprs.data[exprs.size - 1]); - return arena.addTypePack(TypePack{head, tail}); - } - - Scope* findInnermostScope(Location location) - { - Scope* bestScope = module->getModuleScope().get(); - Location bestLocation = module->scopes[0].first; + for (TypePackId p : tfitp.packArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } - for (size_t i = 0; i < module->scopes.size(); ++i) + if (hasGeneric) { - auto& [scopeBounds, scope] = module->scopes[i]; - if (scopeBounds.encloses(location)) + for (TypePackId mentioned : mentionedFunctionPacks) { - if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) + const TypeFunctionInstanceTypePack* mentionedTfitp = get(mentioned); + LUAU_ASSERT(mentionedTfitp); + if (areEquivalent(tfitp, *mentionedTfitp)) { - bestScope = scope.get(); - bestLocation = scopeBounds; + return true; } } + + internalPackFunctions.insert(tp); } - return bestScope; + return true; } +}; - enum ValueContext - { - LValue, - RValue - }; +void check( + NotNull builtinTypes, + NotNull typeFunctionRuntime, + NotNull unifierState, + NotNull limits, + DcrLogger* logger, + const SourceModule& sourceModule, + Module* module +) +{ + LUAU_TIMETRACE_SCOPE("check", "Typechecking"); - void visit(AstStat* stat) - { - auto pusher = pushStack(stat); - - if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else if (auto s = stat->as()) - return visit(s); - else - LUAU_ASSERT(!"TypeChecker2 encountered an unknown node type"); - } + TypeChecker2 typeChecker{builtinTypes, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module}; - void visit(AstStatBlock* block) - { - auto StackPusher = pushStack(block); + typeChecker.visit(sourceModule.root); - for (AstStat* statement : block->body) - visit(statement); - } + // if the only error we're producing is one about constraint solving being incomplete, we can silence it. + // this means we won't give this warning if types seem totally nonsensical, but there are no other errors. + // this is probably, on the whole, a good decision to not annoy users though. + if (module->errors.size() == 1 && get(module->errors[0])) + module->errors.clear(); - void visit(AstStatIf* ifStatement) - { - visit(ifStatement->condition, RValue); - visit(ifStatement->thenbody); - if (ifStatement->elsebody) - visit(ifStatement->elsebody); - } + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); + freeze(module->interfaceTypes); +} + +TypeChecker2::TypeChecker2( + NotNull builtinTypes, + NotNull typeFunctionRuntime, + NotNull unifierState, + NotNull limits, + DcrLogger* logger, + const SourceModule* sourceModule, + Module* module +) + : builtinTypes(builtinTypes) + , typeFunctionRuntime(typeFunctionRuntime) + , logger(logger) + , limits(limits) + , ice(unifierState->iceHandler) + , sourceModule(sourceModule) + , module(module) + , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} + , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}} + , subtyping(&_subtyping) +{ +} - void visit(AstStatWhile* whileStatement) +bool TypeChecker2::allowsNoReturnValues(const TypePackId tp) +{ + for (TypeId ty : tp) { - visit(whileStatement->condition, RValue); - visit(whileStatement->body); + if (!get(follow(ty))) + return false; } - void visit(AstStatRepeat* repeatStatement) + return true; +} + +Location TypeChecker2::getEndLocation(const AstExprFunction* function) +{ + Location loc = function->location; + if (loc.begin.line != loc.end.line) { - visit(repeatStatement->body); - visit(repeatStatement->condition, RValue); + Position begin = loc.end; + begin.column = std::max(0u, begin.column - 3); + loc = Location(begin, 3); } - void visit(AstStatBreak*) {} + return loc; +} - void visit(AstStatContinue*) {} +bool TypeChecker2::isErrorCall(const AstExprCall* call) +{ + const AstExprGlobal* global = call->func->as(); + if (!global) + return false; - void visit(AstStatReturn* ret) + if (global->name == "error") + return true; + else if (global->name == "assert") { - Scope* scope = findInnermostScope(ret->location); - TypePackId expectedRetType = scope->returnType; - - TypeArena* arena = &testArena; - TypePackId actualRetType = reconstructPack(ret->list, *arena); + // assert() will error because it is missing the first argument + if (call->args.size == 0) + return true; - Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; + if (AstExprConstantBool* expr = call->args.data[0]->as()) + if (!expr->value) + return true; + } - u.tryUnify(actualRetType, expectedRetType); - const bool ok = u.errors.empty() && u.log.empty(); + return false; +} - if (!ok) +bool TypeChecker2::hasBreak(AstStat* node) +{ + if (AstStatBlock* stat = node->as()) + { + for (size_t i = 0; i < stat->body.size; ++i) { - for (const TypeError& e : u.errors) - reportError(e); + if (hasBreak(stat->body.data[i])) + return true; } - for (AstExpr* expr : ret->list) - visit(expr, RValue); + return false; } - void visit(AstStatExpr* expr) - { - visit(expr->expr, RValue); - } + if (node->is()) + return true; - void visit(AstStatLocal* local) + if (AstStatIf* stat = node->as()) { - size_t count = std::max(local->values.size, local->vars.size); - for (size_t i = 0; i < count; ++i) - { - AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; - const bool isPack = value && (value->is() || value->is()); - - if (value) - visit(value, RValue); - - if (i != local->values.size - 1 || !isPack) - { - AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; - - if (var && var->annotation) - { - TypeId annotationType = lookupAnnotation(var->annotation); - TypeId valueType = value ? lookupType(value) : nullptr; - if (valueType) - { - ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); - if (!errors.empty()) - reportErrors(std::move(errors)); - } - - visit(var->annotation); - } - } - else if (value) - { - TypePackId valuePack = lookupPack(value); - TypePack valueTypes; - if (i < local->vars.size) - valueTypes = extendTypePack(module->internalTypes, builtinTypes, valuePack, local->vars.size - i); - - Location errorLocation; - for (size_t j = i; j < local->vars.size; ++j) - { - if (j - i >= valueTypes.head.size()) - { - errorLocation = local->vars.data[j]->location; - break; - } - - AstLocal* var = local->vars.data[j]; - if (var->annotation) - { - TypeId varType = lookupAnnotation(var->annotation); - ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType); - if (!errors.empty()) - reportErrors(std::move(errors)); + if (hasBreak(stat->thenbody)) + return true; - visit(var->annotation); - } - } + if (stat->elsebody && hasBreak(stat->elsebody)) + return true; - if (valueTypes.head.size() < local->vars.size - i) - { - reportError( - CountMismatch{ - // We subtract 1 here because the final AST - // expression is not worth one value. It is worth 0 - // or more depending on valueTypes.head - local->values.size - 1 + valueTypes.head.size(), - std::nullopt, - local->vars.size, - local->values.data[local->values.size - 1]->is() ? CountMismatch::FunctionResult - : CountMismatch::ExprListResult, - }, - errorLocation); - } - } - } + return false; } - void visit(AstStatFor* forStatement) + return false; +} + +const AstStat* TypeChecker2::getFallthrough(const AstStat* node) +{ + if (const AstStatBlock* stat = node->as()) { - NotNull scope = stack.back(); + if (stat->body.size == 0) + return stat; - if (forStatement->var->annotation) + for (size_t i = 0; i < stat->body.size - 1; ++i) { - visit(forStatement->var->annotation); - reportErrors(tryUnify(scope, forStatement->var->location, builtinTypes->numberType, lookupAnnotation(forStatement->var->annotation))); + if (getFallthrough(stat->body.data[i]) == nullptr) + return nullptr; } - auto checkNumber = [this, scope](AstExpr* expr) { - if (!expr) - return; - - visit(expr, RValue); - reportErrors(tryUnify(scope, expr->location, lookupType(expr), builtinTypes->numberType)); - }; - - checkNumber(forStatement->from); - checkNumber(forStatement->to); - checkNumber(forStatement->step); - - visit(forStatement->body); + return getFallthrough(stat->body.data[stat->body.size - 1]); } - void visit(AstStatForIn* forInStatement) + if (const AstStatIf* stat = node->as()) { - for (AstLocal* local : forInStatement->vars) + if (const AstStat* thenf = getFallthrough(stat->thenbody)) + return thenf; + + if (stat->elsebody) { - if (local->annotation) - visit(local->annotation); - } + if (const AstStat* elsef = getFallthrough(stat->elsebody)) + return elsef; - for (AstExpr* expr : forInStatement->values) - visit(expr, RValue); + return nullptr; + } + else + return stat; + } - visit(forInStatement->body); + if (node->is()) + return nullptr; - // Rule out crazy stuff. Maybe possible if the file is not syntactically valid. - if (!forInStatement->vars.size || !forInStatement->values.size) - return; + if (const AstStatExpr* stat = node->as()) + { + if (AstExprCall* call = stat->expr->as(); call && isErrorCall(call)) + return nullptr; - NotNull scope = stack.back(); - TypeArena& arena = testArena; + return stat; + } - std::vector variableTypes; - for (AstLocal* var : forInStatement->vars) + if (const AstStatWhile* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) { - std::optional ty = scope->lookup(var); - LUAU_ASSERT(ty); - variableTypes.emplace_back(*ty); + if (expr->value && !hasBreak(stat->body)) + return nullptr; } - // ugh. There's nothing in the AST to hang a whole type pack on for the - // set of iteratees, so we have to piece it back together by hand. - std::vector valueTypes; - for (size_t i = 0; i < forInStatement->values.size - 1; ++i) - valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); - TypePackId iteratorTail = lookupPack(forInStatement->values.data[forInStatement->values.size - 1]); - TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); + return node; + } - // ... and then expand it out to 3 values (if possible) - TypePack iteratorTypes = extendTypePack(arena, builtinTypes, iteratorPack, 3); - if (iteratorTypes.head.empty()) + if (const AstStatRepeat* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) { - reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); - return; + if (!expr->value && !hasBreak(stat->body)) + return nullptr; } - TypeId iteratorTy = follow(iteratorTypes.head[0]); - - auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes]( - const FunctionType* iterFtv, std::vector iterTys, bool isMm) { - if (iterTys.size() < 1 || iterTys.size() > 3) - { - if (isMm) - reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); - else - reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); - return; - } - - // It is okay if there aren't enough iterators, but the iteratee must provide enough. - TypePack expectedVariableTypes = extendTypePack(arena, builtinTypes, iterFtv->retTypes, variableTypes.size()); - if (expectedVariableTypes.head.size() < variableTypes.size()) - { - if (isMm) - reportError( - GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); - else - reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); - } + if (getFallthrough(stat->body) == nullptr) + return nullptr; - for (size_t i = 0; i < std::min(expectedVariableTypes.head.size(), variableTypes.size()); ++i) - reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes.head[i])); + return node; + } - // nextFn is going to be invoked with (arrayTy, startIndexTy) + return node; +} - // It will be passed two arguments on every iteration save the - // first. +std::optional TypeChecker2::pushStack(AstNode* node) +{ + if (Scope** scope = module->astScopes.find(node)) + return StackPusher{stack, *scope}; + else + return std::nullopt; +} - // It may be invoked with 0 or 1 argument on the first iteration. - // This depends on the types in iterateePack and therefore - // iteratorTypes. +void TypeChecker2::checkForInternalTypeFunction(TypeId ty, Location location) +{ + InternalTypeFunctionFinder finder(functionDeclStack); + finder.traverse(ty); - // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. - // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. - // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. - auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); + for (TypeId internal : finder.internalFunctions) + reportError(WhereClauseNeeded{internal}, location); - if (minCount > 2) - reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - if (maxCount && *maxCount < 2) - reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + for (TypePackId internal : finder.internalPackFunctions) + reportError(PackWhereClauseNeeded{internal}, location); +} - TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); - size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; - size_t actualArgCount = expectedVariableTypes.head.size(); +TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location location) +{ + if (seenTypeFunctionInstances.find(instance)) + return instance; + seenTypeFunctionInstances.insert(instance); + + ErrorVec errors = + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, + true + ) + .errors; + if (!isErrorSuppressing(location, instance)) + reportErrors(std::move(errors)); + return instance; +} - if (firstIterationArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - else if (actualArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); +TypePackId TypeChecker2::lookupPack(AstExpr* expr) +{ + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return follow(*tp); + else + return builtinTypes->anyTypePack; +} - if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) - { - size_t valueIndex = forInStatement->values.size > 1 ? 1 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[1], flattenedArgTypes.head[0])); - } +TypeId TypeChecker2::lookupType(AstExpr* expr) +{ + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. + TypeId* ty = module->astTypes.find(expr); + if (ty) + return checkForTypeFunctionInhabitance(follow(*ty), expr->location); + + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return checkForTypeFunctionInhabitance(flattenPack(*tp), expr->location); + + return builtinTypes->anyType; +} - if (iterTys.size() == 3 && flattenedArgTypes.head.size() > 1) - { - size_t valueIndex = forInStatement->values.size > 2 ? 2 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[2], flattenedArgTypes.head[1])); - } - }; - - /* - * If the first iterator argument is a function - * * There must be 1 to 3 iterator arguments. Name them (nextTy, - * arrayTy, startIndexTy) - * * The return type of nextTy() must correspond to the variables' - * types and counts. HOWEVER the first iterator will never be nil. - * * The first return value of nextTy must be compatible with - * startIndexTy. - * * The first argument to nextTy() must be compatible with arrayTy if - * present. nil if not. - * * The second argument to nextTy() must be compatible with - * startIndexTy if it is present. Else, it must be compatible with - * nil. - * * nextTy() must be callable with only 2 arguments. - */ - if (const FunctionType* nextFn = get(iteratorTy)) - { - checkFunction(nextFn, iteratorTypes.head, false); - } - else if (const TableType* ttv = get(iteratorTy)) - { - if ((forInStatement->vars.size == 1 || forInStatement->vars.size == 2) && ttv->indexer) +TypeId TypeChecker2::lookupAnnotation(AstType* annotation) +{ + if (FFlag::DebugLuauMagicTypes) + { + if (auto ref = annotation->as(); ref && ref->name == "_luau_print" && ref->parameters.size > 0) + { + if (auto ann = ref->parameters.data[0].type) { - reportErrors(tryUnify(scope, forInStatement->vars.data[0]->location, variableTypes[0], ttv->indexer->indexType)); - if (variableTypes.size() == 2) - reportErrors(tryUnify(scope, forInStatement->vars.data[1]->location, variableTypes[1], ttv->indexer->indexResultType)); + TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); + luauPrintLine( + format("_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str()) + ); + return follow(argTy); } - else - reportError(GenericError{"Cannot iterate over a table without indexer"}, forInStatement->values.data[0]->location); - } - else if (get(iteratorTy) || get(iteratorTy) || get(iteratorTy)) - { - // nothing } - else if (std::optional iterMmTy = - findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) - { - Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, scope}; + } - if (std::optional instantiatedIterMmTy = instantiation.substitute(*iterMmTy)) - { - if (const FunctionType* iterMmFtv = get(*instantiatedIterMmTy)) - { - TypePackId argPack = arena.addTypePack({iteratorTy}); - reportErrors(tryUnify(scope, forInStatement->values.data[0]->location, argPack, iterMmFtv->argTypes)); + TypeId* ty = module->astResolvedTypes.find(annotation); + LUAU_ASSERT(ty); + return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); +} - TypePack mmIteratorTypes = extendTypePack(arena, builtinTypes, iterMmFtv->retTypes, 3); +std::optional TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) +{ + TypePackId* tp = module->astResolvedTypePacks.find(annotation); + if (tp != nullptr) + return {follow(*tp)}; + return {}; +} - if (mmIteratorTypes.head.size() == 0) - { - reportError(GenericError{"__iter must return at least one value"}, forInStatement->values.data[0]->location); - return; - } +TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) +{ + if (TypeId* ty = module->astExpectedTypes.find(expr)) + return follow(*ty); - TypeId nextFn = follow(mmIteratorTypes.head[0]); + return builtinTypes->anyType; +} - if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) - { - std::vector instantiatedIteratorTypes = mmIteratorTypes.head; - instantiatedIteratorTypes[0] = *instantiatedNextFn; +TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) +{ + if (TypeId* ty = module->astExpectedTypes.find(expr)) + return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt}); - if (const FunctionType* nextFtv = get(*instantiatedNextFn)) - { - checkFunction(nextFtv, instantiatedIteratorTypes, true); - } - else - { - reportError(CannotCallNonFunction{*instantiatedNextFn}, forInStatement->values.data[0]->location); - } - } - else - { - reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); - } - } - else - { - // TODO: This will not tell the user that this is because the - // metamethod isn't callable. This is not ideal, and we should - // improve this error message. + return builtinTypes->anyTypePack; +} - // TODO: This will also not handle intersections of functions or - // callable tables (which are supported by the runtime). - reportError(CannotCallNonFunction{*iterMmTy}, forInStatement->values.data[0]->location); - } - } - else - { - reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); - } - } - else - { - reportError(CannotCallNonFunction{iteratorTy}, forInStatement->values.data[0]->location); - } - } +TypePackId TypeChecker2::reconstructPack(AstArray exprs, TypeArena& arena) +{ + if (exprs.size == 0) + return arena.addTypePack(TypePack{{}, std::nullopt}); - void visit(AstStatAssign* assign) - { - size_t count = std::min(assign->vars.size, assign->values.size); + std::vector head; - for (size_t i = 0; i < count; ++i) - { - AstExpr* lhs = assign->vars.data[i]; - visit(lhs, LValue); - TypeId lhsType = lookupType(lhs); + for (size_t i = 0; i < exprs.size - 1; ++i) + { + head.push_back(lookupType(exprs.data[i])); + } - AstExpr* rhs = assign->values.data[i]; - visit(rhs, RValue); - TypeId rhsType = lookupType(rhs); + TypePackId tail = lookupPack(exprs.data[exprs.size - 1]); + return arena.addTypePack(TypePack{head, tail}); +} - if (get(lhsType)) - continue; +Scope* TypeChecker2::findInnermostScope(Location location) +{ + Scope* bestScope = module->getModuleScope().get(); - if (!isSubtype(rhsType, lhsType, stack.back())) + bool didNarrow; + do + { + didNarrow = false; + for (auto scope : bestScope->children) + { + if (scope->location.encloses(location)) { - reportError(TypeMismatch{lhsType, rhsType}, rhs->location); + bestScope = scope.get(); + didNarrow = true; + break; } } - } + } while (didNarrow && bestScope->children.size() > 0); - void visit(AstStatCompoundAssign* stat) - { - AstExprBinary fake{stat->location, stat->op, stat->var, stat->value}; - TypeId resultTy = visit(&fake, stat); - TypeId varTy = lookupType(stat->var); + return bestScope; +} - reportErrors(tryUnify(stack.back(), stat->location, resultTy, varTy)); - } +void TypeChecker2::visit(AstStat* stat) +{ + auto pusher = pushStack(stat); + + if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto f = stat->as()) + return visit(f); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else + LUAU_ASSERT(!"TypeChecker2 encountered an unknown node type"); +} - void visit(AstStatFunction* stat) - { - visit(stat->name, LValue); - visit(stat->func); - } +void TypeChecker2::visit(AstStatBlock* block) +{ + auto StackPusher = pushStack(block); + + for (AstStat* statement : block->body) + visit(statement); +} - void visit(AstStatLocalFunction* stat) +void TypeChecker2::visit(AstStatIf* ifStatement) +{ { - visit(stat->func); + InConditionalContext flipper{&typeContext}; + visit(ifStatement->condition, ValueContext::RValue); } - void visit(const AstTypeList* typeList) - { - for (AstType* ty : typeList->types) - visit(ty); + visit(ifStatement->thenbody); + if (ifStatement->elsebody) + visit(ifStatement->elsebody); +} - if (typeList->tailType) - visit(typeList->tailType); - } +void TypeChecker2::visit(AstStatWhile* whileStatement) +{ + visit(whileStatement->condition, ValueContext::RValue); + visit(whileStatement->body); +} - void visit(AstStatTypeAlias* stat) - { - visitGenerics(stat->generics, stat->genericPacks); - visit(stat->type); - } +void TypeChecker2::visit(AstStatRepeat* repeatStatement) +{ + visit(repeatStatement->body); + visit(repeatStatement->condition, ValueContext::RValue); +} - void visit(AstTypeList types) - { - for (AstType* type : types.types) - visit(type); - if (types.tailType) - visit(types.tailType); - } +void TypeChecker2::visit(AstStatBreak*) {} - void visit(AstStatDeclareFunction* stat) - { - visitGenerics(stat->generics, stat->genericPacks); - visit(stat->params); - visit(stat->retTypes); - } +void TypeChecker2::visit(AstStatContinue*) {} - void visit(AstStatDeclareGlobal* stat) - { - visit(stat->type); - } +void TypeChecker2::visit(AstStatReturn* ret) +{ + Scope* scope = findInnermostScope(ret->location); + TypePackId expectedRetType = scope->returnType; - void visit(AstStatDeclareClass* stat) - { - for (const AstDeclaredClassProp& prop : stat->props) - visit(prop.ty); - } + TypeArena* arena = &module->internalTypes; + TypePackId actualRetType = reconstructPack(ret->list, *arena); - void visit(AstStatError* stat) - { - for (AstExpr* expr : stat->expressions) - visit(expr, RValue); + testIsSubtype(actualRetType, expectedRetType, ret->location); - for (AstStat* s : stat->statements) - visit(s); - } + for (AstExpr* expr : ret->list) + visit(expr, ValueContext::RValue); +} + +void TypeChecker2::visit(AstStatExpr* expr) +{ + visit(expr->expr, ValueContext::RValue); +} - void visit(AstExpr* expr, ValueContext context) +void TypeChecker2::visit(AstStatLocal* local) +{ + size_t count = std::max(local->values.size, local->vars.size); + for (size_t i = 0; i < count; ++i) { - auto StackPusher = pushStack(expr); + AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; + const bool isPack = value && (value->is() || value->is()); + + if (value) + visit(value, ValueContext::RValue); - if (auto e = expr->as()) - return visit(e, context); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e, context); - else if (auto e = expr->as()) - return visit(e, context); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) + if (i != local->values.size - 1 || !isPack) { - visit(e); - return; + AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; + + if (var && var->annotation) + { + TypeId annotationType = lookupAnnotation(var->annotation); + TypeId valueType = value ? lookupType(value) : nullptr; + if (valueType) + testIsSubtype(valueType, annotationType, value->location); + + visit(var->annotation); + } } - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else if (auto e = expr->as()) - return visit(e); - else - LUAU_ASSERT(!"TypeChecker2 encountered an unknown expression type"); - } + else if (value) + { + TypePackId valuePack = lookupPack(value); + TypePack valueTypes; + if (i < local->vars.size) + valueTypes = extendTypePack(module->internalTypes, builtinTypes, valuePack, local->vars.size - i); - void visit(AstExprGroup* expr, ValueContext context) - { - visit(expr->expr, context); + Location errorLocation; + for (size_t j = i; j < local->vars.size; ++j) + { + if (j - i >= valueTypes.head.size()) + { + errorLocation = local->vars.data[j]->location; + break; + } + + AstLocal* var = local->vars.data[j]; + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + testIsSubtype(valueTypes.head[j - i], varType, value->location); + + visit(var->annotation); + } + } + + if (valueTypes.head.size() < local->vars.size - i) + { + reportError( + CountMismatch{ + // We subtract 1 here because the final AST + // expression is not worth one value. It is worth 0 + // or more depending on valueTypes.head + local->values.size - 1 + valueTypes.head.size(), + std::nullopt, + local->vars.size, + local->values.data[local->values.size - 1]->is() ? CountMismatch::FunctionResult : CountMismatch::ExprListResult, + }, + errorLocation + ); + } + } } +} - void visit(AstExprConstantNil* expr) +void TypeChecker2::visit(AstStatFor* forStatement) +{ + if (forStatement->var->annotation) { - NotNull scope = stack.back(); - TypeId actualType = lookupType(expr); - TypeId expectedType = builtinTypes->nilType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); + visit(forStatement->var->annotation); + + TypeId annotatedType = lookupAnnotation(forStatement->var->annotation); + testIsSubtype(builtinTypes->numberType, annotatedType, forStatement->var->location); } - void visit(AstExprConstantBool* expr) + auto checkNumber = [this](AstExpr* expr) { - NotNull scope = stack.back(); - TypeId actualType = lookupType(expr); - TypeId expectedType = builtinTypes->booleanType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); - } + if (!expr) + return; - void visit(AstExprConstantNumber* expr) + visit(expr, ValueContext::RValue); + testIsSubtype(lookupType(expr), builtinTypes->numberType, expr->location); + }; + + checkNumber(forStatement->from); + checkNumber(forStatement->to); + checkNumber(forStatement->step); + + visit(forStatement->body); +} + +void TypeChecker2::visit(AstStatForIn* forInStatement) +{ + for (AstLocal* local : forInStatement->vars) { - NotNull scope = stack.back(); - TypeId actualType = lookupType(expr); - TypeId expectedType = builtinTypes->numberType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); + if (local->annotation) + visit(local->annotation); } - void visit(AstExprConstantString* expr) + for (AstExpr* expr : forInStatement->values) + visit(expr, ValueContext::RValue); + + visit(forInStatement->body); + + // Rule out crazy stuff. Maybe possible if the file is not syntactically valid. + if (!forInStatement->vars.size || !forInStatement->values.size) + return; + + NotNull scope = stack.back(); + TypeArena& arena = module->internalTypes; + + std::vector variableTypes; + for (AstLocal* var : forInStatement->vars) { - NotNull scope = stack.back(); - TypeId actualType = lookupType(expr); - TypeId expectedType = builtinTypes->stringType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); + std::optional ty = scope->lookup(var); + LUAU_ASSERT(ty); + variableTypes.emplace_back(*ty); } - void visit(AstExprLocal* expr) + AstExpr* firstValue = forInStatement->values.data[0]; + + // we need to build up a typepack for the iterators/values portion of the for-in statement. + std::vector valueTypes; + std::optional iteratorTail; + + // since the first value may be the only iterator (e.g. if it is a call), we want to + // look to see if it has a resulting typepack as our iterators. + TypePackId* retPack = module->astTypePacks.find(firstValue); + if (retPack) { - // TODO! + auto [head, tail] = flatten(*retPack); + valueTypes = head; + iteratorTail = tail; } - - void visit(AstExprGlobal* expr) + else { - // TODO! + valueTypes.emplace_back(lookupType(firstValue)); } - void visit(AstExprVarargs* expr) + // if the initial and expected types from the iterator unified during constraint solving, + // we'll have a resolved type to use here, but we'll only use it if either the iterator is + // directly present in the for-in statement or if we have an iterator state constraining us + TypeId* resolvedTy = module->astForInNextTypes.find(firstValue); + if (resolvedTy && (!retPack || valueTypes.size() > 1)) + valueTypes[0] = *resolvedTy; + + for (size_t i = 1; i < forInStatement->values.size - 1; ++i) { - // TODO! + valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); } - ErrorVec visitOverload(AstExprCall* call, NotNull overloadFunctionType, const std::vector& argLocs, - TypePackId expectedArgTypes, TypePackId expectedRetType) + // if we had more than one value, the tail from the first value is no longer appropriate to use. + if (forInStatement->values.size > 1) { - ErrorVec overloadErrors = - tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult); - - size_t argIndex = 0; - auto inferredArgIt = begin(overloadFunctionType->argTypes); - auto expectedArgIt = begin(expectedArgTypes); - while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) - { - Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; - ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt); - for (TypeError e : argErrors) - overloadErrors.emplace_back(e); - - ++argIndex; - ++inferredArgIt; - ++expectedArgIt; - } + auto [head, tail] = flatten(lookupPack(forInStatement->values.data[forInStatement->values.size - 1])); + valueTypes.insert(valueTypes.end(), head.begin(), head.end()); + iteratorTail = tail; + } - // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad - ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes); - for (TypeError e : argumentErrors) - if (get(e) != nullptr) - overloadErrors.emplace_back(std::move(e)); + // and now we can put everything together to get the actual typepack of the iterators. + TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); - return overloadErrors; + // ... and then expand it out to 3 values (if possible) + TypePack iteratorTypes = extendTypePack(arena, builtinTypes, iteratorPack, 3); + if (iteratorTypes.head.empty()) + { + reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); + return; } + TypeId iteratorTy = follow(iteratorTypes.head[0]); - void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, - const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) + auto checkFunction = [this, &arena, &forInStatement, &variableTypes](const FunctionType* iterFtv, std::vector iterTys, bool isMm) { - if (overloads.size() == 1) + if (iterTys.size() < 1 || iterTys.size() > 3) { - reportErrors(std::get<0>(overloadsErrors.front())); + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); + return; } - std::vector overloadTypes = overloadsThatMatchArgCount; - if (overloadsThatMatchArgCount.size() == 0) - { - reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); - // If no overloads match argument count, just list all overloads. - overloadTypes = overloads; - } - else + // It is okay if there aren't enough iterators, but the iteratee must provide enough. + TypePack expectedVariableTypes = extendTypePack(arena, builtinTypes, iterFtv->retTypes, variableTypes.size()); + if (expectedVariableTypes.head.size() < variableTypes.size()) { - // Report errors of the first argument-count-matching, but failing overload - TypeId overload = overloadsThatMatchArgCount[0]; - - // Remove the overload we are reporting errors about from the list of alternatives - overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); - - const FunctionType* ftv = get(overload); - LUAU_ASSERT(ftv); // overload must be a function type here - - auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair& e) { - return ftv == std::get<1>(e); - }); - - LUAU_ASSERT(error != overloadsErrors.end()); - reportErrors(std::get<0>(*error)); - - // If only one overload matched, we don't need this error because we provided the previous errors. - if (overloadsThatMatchArgCount.size() == 1) - return; + if (isMm) + reportError(GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); } - std::string s; - for (size_t i = 0; i < overloadTypes.size(); ++i) - { - TypeId overload = follow(overloadTypes[i]); - - if (i > 0) - s += "; "; + for (size_t i = 0; i < std::min(expectedVariableTypes.head.size(), variableTypes.size()); ++i) + testIsSubtype(variableTypes[i], expectedVariableTypes.head[i], forInStatement->vars.data[i]->location); - if (i > 0 && i == overloadTypes.size() - 1) - s += "and "; + // nextFn is going to be invoked with (arrayTy, startIndexTy) - s += toString(overload); - } + // It will be passed two arguments on every iteration save the + // first. - if (overloadsThatMatchArgCount.size() == 0) - reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); - else - reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); - } + // It may be invoked with 0 or 1 argument on the first iteration. + // This depends on the types in iterateePack and therefore + // iteratorTypes. - // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. - void visitCall(AstExprCall* call) - { - TypeArena* arena = &testArena; - Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; + // If the iteratee is an error type, then we can't really say anything else about iteration over it. + // After all, it _could've_ been a table. + if (get(follow(flattenPack(iterFtv->argTypes)))) + return; - TypePackId expectedRetType = lookupPack(call); - TypeId functionType = lookupType(call->func); - TypeId testFunctionType = functionType; - TypePack args; - std::vector argLocs; - argLocs.reserve(call->args.size + 1); + // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. + // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. + // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. + auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); - if (get(functionType) || get(functionType) || get(functionType)) - return; - else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, functionType, "__call", call->func->location)) + TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); + size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; + size_t actualArgCount = expectedVariableTypes.head.size(); + if (firstIterationArgCount < minCount) { - if (get(follow(*callMm))) - { - if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) - { - args.head.push_back(functionType); - argLocs.push_back(call->func->location); - testFunctionType = follow(*instantiatedCallMm); - } - else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } - } + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); else - { - // TODO: This doesn't flag the __call metamethod as the problem - // very clearly. - reportError(CannotCallNonFunction{*callMm}, call->func->location); - return; - } + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->values.data[0]->location); } - else if (get(functionType)) + + else if (actualArgCount < minCount) { - if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) - { - testFunctionType = *instantiatedFunctionType; - } + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } - } - else if (auto itv = get(functionType)) - { - // We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function. + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->values.data[0]->location); } - else if (auto utv = get(functionType)) - { - // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. - std::optional fst; - for (TypeId ty : utv) - { - if (!fst) - fst = follow(ty); - else if (fst != follow(ty)) - { - reportError(CannotCallNonFunction{functionType}, call->func->location); - return; - } - } - if (!fst) - ice.ice("UnionType had no elements, so fst is nullopt?"); - if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) - { - testFunctionType = *instantiatedFunctionType; - } - else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } - } - else + if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) { - reportError(CannotCallNonFunction{functionType}, call->func->location); - return; + size_t valueIndex = forInStatement->values.size > 1 ? 1 : 0; + testIsSubtype(iterTys[1], flattenedArgTypes.head[0], forInStatement->values.data[valueIndex]->location); } - if (call->self) + if (iterTys.size() == 3 && flattenedArgTypes.head.size() > 1) { - AstExprIndexName* indexExpr = call->func->as(); - if (!indexExpr) - ice.ice("method call expression has no 'self'"); - - args.head.push_back(lookupType(indexExpr->expr)); - argLocs.push_back(indexExpr->expr->location); + size_t valueIndex = forInStatement->values.size > 2 ? 2 : 0; + testIsSubtype(iterTys[2], flattenedArgTypes.head[1], forInStatement->values.data[valueIndex]->location); } + }; - for (size_t i = 0; i < call->args.size; ++i) + std::shared_ptr iteratorNorm = normalizer.normalize(iteratorTy); + + if (!iteratorNorm) + reportError(NormalizationTooComplex{}, firstValue->location); + + /* + * If the first iterator argument is a function + * * There must be 1 to 3 iterator arguments. Name them (nextTy, + * arrayTy, startIndexTy) + * * The return type of nextTy() must correspond to the variables' + * types and counts. HOWEVER the first iterator will never be nil. + * * The first return value of nextTy must be compatible with + * startIndexTy. + * * The first argument to nextTy() must be compatible with arrayTy if + * present. nil if not. + * * The second argument to nextTy() must be compatible with + * startIndexTy if it is present. Else, it must be compatible with + * nil. + * * nextTy() must be callable with only 2 arguments. + */ + if (const FunctionType* nextFn = get(iteratorTy)) + { + checkFunction(nextFn, iteratorTypes.head, false); + } + else if (const TableType* ttv = get(iteratorTy)) + { + if ((forInStatement->vars.size == 1 || forInStatement->vars.size == 2) && ttv->indexer) { - AstExpr* arg = call->args.data[i]; - argLocs.push_back(arg->location); - TypeId* argTy = module->astTypes.find(arg); - if (argTy) - args.head.push_back(*argTy); - else if (i == call->args.size - 1) - { - TypePackId* argTail = module->astTypePacks.find(arg); - if (argTail) - args.tail = *argTail; - else - args.tail = builtinTypes->anyTypePack; - } - else - args.head.push_back(builtinTypes->anyType); + testIsSubtype(variableTypes[0], ttv->indexer->indexType, forInStatement->vars.data[0]->location); + if (variableTypes.size() == 2) + testIsSubtype(variableTypes[1], ttv->indexer->indexResultType, forInStatement->vars.data[1]->location); } + else + reportError(GenericError{"Cannot iterate over a table without indexer"}, forInStatement->values.data[0]->location); + } + else if (get(iteratorTy) || get(iteratorTy) || get(iteratorTy)) + { + // nothing + } + else if (isOptional(iteratorTy) && !(iteratorNorm && iteratorNorm->shouldSuppressErrors())) + { + reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); + } + else if (std::optional iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) + { + Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope}; - TypePackId expectedArgTypes = arena->addTypePack(args); + if (std::optional instantiatedIterMmTy = instantiate(builtinTypes, NotNull{&arena}, limits, scope, *iterMmTy)) + { + if (const FunctionType* iterMmFtv = get(*instantiatedIterMmTy)) + { + TypePackId argPack = arena.addTypePack({iteratorTy}); + testIsSubtype(argPack, iterMmFtv->argTypes, forInStatement->values.data[0]->location); - std::vector overloads = flattenIntersection(testFunctionType); - std::vector> overloadsErrors; - overloadsErrors.reserve(overloads.size()); + TypePack mmIteratorTypes = extendTypePack(arena, builtinTypes, iterMmFtv->retTypes, 3); - std::vector overloadsThatMatchArgCount; + if (mmIteratorTypes.head.size() == 0) + { + reportError(GenericError{"__iter must return at least one value"}, forInStatement->values.data[0]->location); + return; + } - for (TypeId overload : overloads) - { - overload = follow(overload); + TypeId nextFn = follow(mmIteratorTypes.head[0]); - const FunctionType* overloadFn = get(overload); - if (!overloadFn) - { - reportError(CannotCallNonFunction{overload}, call->func->location); - return; - } - else - { - // We may have to instantiate the overload in order for it to typecheck. - if (std::optional instantiatedFunctionType = instantiation.substitute(overload)) + if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) { - overloadFn = get(*instantiatedFunctionType); + std::vector instantiatedIteratorTypes = mmIteratorTypes.head; + instantiatedIteratorTypes[0] = *instantiatedNextFn; + + if (const FunctionType* nextFtv = get(*instantiatedNextFn)) + { + checkFunction(nextFtv, instantiatedIteratorTypes, true); + } + else if (!isErrorSuppressing(forInStatement->values.data[0]->location, *instantiatedNextFn)) + { + reportError(CannotCallNonFunction{*instantiatedNextFn}, forInStatement->values.data[0]->location); + } } else { - overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn); - return; + reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); } } - - ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); - if (overloadErrors.empty()) - return; - - bool argMismatch = false; - for (auto error : overloadErrors) + else if (!isErrorSuppressing(forInStatement->values.data[0]->location, *iterMmTy)) { - CountMismatch* cm = get(error); - if (!cm) - continue; + // TODO: This will not tell the user that this is because the + // metamethod isn't callable. This is not ideal, and we should + // improve this error message. - if (cm->context == CountMismatch::Arg) - { - argMismatch = true; - break; - } + // TODO: This will also not handle intersections of functions or + // callable tables (which are supported by the runtime). + reportError(CannotCallNonFunction{*iterMmTy}, forInStatement->values.data[0]->location); } - - if (!argMismatch) - overloadsThatMatchArgCount.push_back(overload); - - overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn); } - - reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); + else + { + reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); + } } - - void visit(AstExprCall* call) + else if (iteratorNorm && iteratorNorm->hasTables()) { - visit(call->func, RValue); - - for (AstExpr* arg : call->args) - visit(arg, RValue); - - visitCall(call); + // Ok. All tables can be iterated. } - - void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) + else if (!iteratorNorm || !iteratorNorm->shouldSuppressErrors()) { - visit(expr, RValue); - - TypeId leftType = lookupType(expr); - const NormalizedType* norm = normalizer.normalize(leftType); - if (!norm) - reportError(NormalizationTooComplex{}, location); - - checkIndexTypeFromType(leftType, *norm, propName, location, context); + reportError(CannotCallNonFunction{iteratorTy}, forInStatement->values.data[0]->location); } +} - void visit(AstExprIndexName* indexName, ValueContext context) +std::optional TypeChecker2::getBindingType(AstExpr* expr) +{ + if (auto localExpr = expr->as()) { - visitExprName(indexName->expr, indexName->location, indexName->index.value, context); + Scope* s = stack.back(); + return s->lookup(localExpr->local); } - - void visit(AstExprIndexExpr* indexExpr, ValueContext context) + else if (auto globalExpr = expr->as()) { - if (auto str = indexExpr->index->as()) - { - const std::string stringValue(str->value.data, str->value.size); - visitExprName(indexExpr->expr, indexExpr->location, stringValue, context); - return; - } - - // TODO! - visit(indexExpr->expr, LValue); - visit(indexExpr->index, RValue); - - NotNull scope = stack.back(); - - TypeId exprType = lookupType(indexExpr->expr); - TypeId indexType = lookupType(indexExpr->index); - - if (auto tt = get(exprType)) - { - if (tt->indexer) - reportErrors(tryUnify(scope, indexExpr->index->location, indexType, tt->indexer->indexType)); - else - reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); - } + Scope* s = stack.back(); + return s->lookup(globalExpr->name); } + else + return std::nullopt; +} - void visit(AstExprFunction* fn) - { - auto StackPusher = pushStack(fn); +void TypeChecker2::reportErrorsFromAssigningToNever(AstExpr* lhs, TypeId rhsType) +{ - visitGenerics(fn->generics, fn->genericPacks); + if (auto indexName = lhs->as()) + { + TypeId indexedType = lookupType(indexName->expr); - TypeId inferredFnTy = lookupType(fn); - const FunctionType* inferredFtv = get(inferredFnTy); - LUAU_ASSERT(inferredFtv); + // if it's already never, I don't think we have anything to do here. + if (get(indexedType)) + return; - // There is no way to write an annotation for the self argument, so we - // cannot do anything to check it. - auto argIt = begin(inferredFtv->argTypes); - if (fn->self) - ++argIt; + std::string prop = indexName->index.value; - for (const auto& arg : fn->args) + std::shared_ptr norm = normalizer.normalize(indexedType); + if (!norm) { - if (argIt == end(inferredFtv->argTypes)) - break; - - if (arg->annotation) - { - TypeId inferredArgTy = *argIt; - TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + reportError(NormalizationTooComplex{}, lhs->location); + return; + } - if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) - { - reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); - } - } + // if the type is error suppressing, we don't actually have any work left to do. + if (norm->shouldSuppressErrors()) + return; - ++argIt; - } + const auto propTypes = lookupProp(norm.get(), prop, ValueContext::LValue, lhs->location, builtinTypes->stringType, module->errors); - visit(fn->body); + reportError(CannotAssignToNever{rhsType, propTypes.typesOfProp, CannotAssignToNever::Reason::PropertyNarrowed}, lhs->location); } +} - void visit(AstExprTable* expr) - { - // TODO! - for (const AstExprTable::Item& item : expr->items) - { - if (item.key) - visit(item.key, LValue); - visit(item.value, RValue); - } - } +void TypeChecker2::visit(AstStatAssign* assign) +{ + size_t count = std::min(assign->vars.size, assign->values.size); - void visit(AstExprUnary* expr) + for (size_t i = 0; i < count; ++i) { - visit(expr->expr, RValue); + AstExpr* lhs = assign->vars.data[i]; + visit(lhs, ValueContext::LValue); + TypeId lhsType = lookupType(lhs); - NotNull scope = stack.back(); - TypeId operandType = lookupType(expr->expr); - TypeId resultType = lookupType(expr); - - if (get(operandType) || get(operandType) || get(operandType)) - return; + AstExpr* rhs = assign->values.data[i]; + visit(rhs, ValueContext::RValue); + TypeId rhsType = lookupType(rhs); - if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) + if (get(lhsType)) { - std::optional mm = findMetatableEntry(builtinTypes, module->errors, operandType, it->second, expr->location); - if (mm) - { - if (const FunctionType* ftv = get(follow(*mm))) - { - if (std::optional ret = first(ftv->retTypes)) - { - if (expr->op == AstExprUnary::Op::Len) - { - reportErrors(tryUnify(scope, expr->location, follow(*ret), builtinTypes->numberType)); - } - } - else - { - reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); - } - - std::optional firstArg = first(ftv->argTypes); - if (!firstArg) - { - reportError(GenericError{"__unm metamethod must accept one argument"}, expr->location); - return; - } - - TypePackId expectedArgs = testArena.addTypePack({operandType}); - TypePackId expectedRet = testArena.addTypePack({resultType}); - - TypeId expectedFunction = testArena.addType(FunctionType{expectedArgs, expectedRet}); - - ErrorVec errors = tryUnify(scope, expr->location, *mm, expectedFunction); - if (!errors.empty()) - { - reportError(TypeMismatch{*firstArg, operandType}, expr->location); - return; - } - } - - return; - } + reportErrorsFromAssigningToNever(lhs, rhsType); + continue; } - if (expr->op == AstExprUnary::Op::Len) - { - DenseHashSet seen{nullptr}; - int recursionCount = 0; + bool ok = testIsSubtype(rhsType, lhsType, rhs->location); - if (!hasLength(operandType, seen, &recursionCount)) - { - reportError(NotATable{operandType}, expr->location); - } - } - else if (expr->op == AstExprUnary::Op::Minus) - { - reportErrors(tryUnify(scope, expr->location, operandType, builtinTypes->numberType)); - } - else if (expr->op == AstExprUnary::Op::Not) + // If rhsType bindingType = getBindingType(lhs); + if (bindingType) + testIsSubtype(rhsType, *bindingType, rhs->location); } } +} - TypeId visit(AstExprBinary* expr, AstNode* overrideKey = nullptr) - { - visit(expr->left, LValue); - visit(expr->right, LValue); +void TypeChecker2::visit(AstStatCompoundAssign* stat) +{ + AstExprBinary fake{stat->location, stat->op, stat->var, stat->value}; + visit(&fake, stat); - NotNull scope = stack.back(); + TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat); + LUAU_ASSERT(resultTy); + TypeId varTy = lookupType(stat->var); - bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe; - bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe; - bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or; + testIsSubtype(*resultTy, varTy, stat->location); +} - TypeId leftType = lookupType(expr->left); - TypeId rightType = lookupType(expr->right); +void TypeChecker2::visit(AstStatFunction* stat) +{ + visit(stat->name, ValueContext::LValue); + visit(stat->func); +} - if (expr->op == AstExprBinary::Op::Or) - { - leftType = stripNil(builtinTypes, testArena, leftType); - } +void TypeChecker2::visit(AstStatLocalFunction* stat) +{ + visit(stat->func); +} - bool isStringOperation = isString(leftType) && isString(rightType); +void TypeChecker2::visit(const AstTypeList* typeList) +{ + for (AstType* ty : typeList->types) + visit(ty); - if (get(leftType) || get(leftType)) - return leftType; - else if (get(rightType) || get(rightType)) - return rightType; + if (typeList->tailType) + visit(typeList->tailType); +} - if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) - { - auto name = getIdentifierOfBaseVar(expr->left); - reportError(CannotInferBinaryOperation{expr->op, name, - isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, - expr->location); - return leftType; +void TypeChecker2::visit(AstStatTypeAlias* stat) +{ + visitGenerics(stat->generics, stat->genericPacks); + visit(stat->type); +} + +void TypeChecker2::visit(AstStatTypeFunction* stat) +{ + // TODO: add type checking for user-defined type functions + if (!FFlag::LuauUserDefinedTypeFunctions2) + reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); +} + +void TypeChecker2::visit(AstTypeList types) +{ + for (AstType* type : types.types) + visit(type); + if (types.tailType) + visit(types.tailType); +} + +void TypeChecker2::visit(AstStatDeclareFunction* stat) +{ + visitGenerics(stat->generics, stat->genericPacks); + visit(stat->params); + visit(stat->retTypes); +} + +void TypeChecker2::visit(AstStatDeclareGlobal* stat) +{ + visit(stat->type); +} + +void TypeChecker2::visit(AstStatDeclareClass* stat) +{ + for (const AstDeclaredClassProp& prop : stat->props) + visit(prop.ty); +} + +void TypeChecker2::visit(AstStatError* stat) +{ + for (AstExpr* expr : stat->expressions) + visit(expr, ValueContext::RValue); + + for (AstStat* s : stat->statements) + visit(s); +} + +void TypeChecker2::visit(AstExpr* expr, ValueContext context) +{ + auto StackPusher = pushStack(expr); + + if (auto e = expr->as()) + return visit(e, context); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e, context); + else if (auto e = expr->as()) + return visit(e, context); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + { + visit(e); + return; + } + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else + LUAU_ASSERT(!"TypeChecker2 encountered an unknown expression type"); +} + +void TypeChecker2::visit(AstExprGroup* expr, ValueContext context) +{ + visit(expr->expr, context); +} + +void TypeChecker2::visit(AstExprConstantNil* expr) +{ +#if defined(LUAU_ENABLE_ASSERT) + TypeId actualType = lookupType(expr); + TypeId expectedType = builtinTypes->nilType; + NotNull scope{findInnermostScope(expr->location)}; + + SubtypingResult r = subtyping->isSubtype(actualType, expectedType, scope); + LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, actualType)); +#endif +} + +void TypeChecker2::visit(AstExprConstantBool* expr) +{ + // booleans use specialized inference logic for singleton types, which can lead to real type errors here. + + const TypeId bestType = expr->value ? builtinTypes->trueType : builtinTypes->falseType; + const TypeId inferredType = lookupType(expr); + NotNull scope{findInnermostScope(expr->location)}; + + const SubtypingResult r = subtyping->isSubtype(bestType, inferredType, scope); + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); +} + +void TypeChecker2::visit(AstExprConstantNumber* expr) +{ +#if defined(LUAU_ENABLE_ASSERT) + const TypeId bestType = builtinTypes->numberType; + const TypeId inferredType = lookupType(expr); + NotNull scope{findInnermostScope(expr->location)}; + + const SubtypingResult r = subtyping->isSubtype(bestType, inferredType, scope); + LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType)); +#endif +} + +void TypeChecker2::visit(AstExprConstantString* expr) +{ + // strings use specialized inference logic for singleton types, which can lead to real type errors here. + + const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}}); + const TypeId inferredType = lookupType(expr); + NotNull scope{findInnermostScope(expr->location)}; + + const SubtypingResult r = subtyping->isSubtype(bestType, inferredType, scope); + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); +} + +void TypeChecker2::visit(AstExprLocal* expr) +{ + // TODO! +} + +void TypeChecker2::visit(AstExprGlobal* expr) +{ + NotNull scope = stack.back(); + if (!scope->lookup(expr->name)) + reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); +} + +void TypeChecker2::visit(AstExprVarargs* expr) +{ + // TODO! +} + +void TypeChecker2::visitCall(AstExprCall* call) +{ + TypePack args; + std::vector argExprs; + NotNull scope{findInnermostScope(call->location)}; + argExprs.reserve(call->args.size + 1); + + TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func); + TypeId* selectedOverloadTy = module->astOverloadResolvedTypes.find(call); + if (!originalCallTy) + return; + + TypeId fnTy = follow(*originalCallTy); + + + if (get(fnTy) || get(fnTy) || get(fnTy)) + return; + else if (isOptional(fnTy)) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, fnTy)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, call->func->location); + [[fallthrough]]; + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{fnTy}, call->func->location); + } + return; + } + + if (selectedOverloadTy) + { + SubtypingResult result = subtyping->isSubtype(*originalCallTy, *selectedOverloadTy, scope); + if (result.isSubtype) + fnTy = follow(*selectedOverloadTy); + + if (result.normalizationTooComplex) + { + reportError(NormalizationTooComplex{}, call->func->location); + return; } + } + + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice->ice("method call expression has no 'self'"); + + args.head.push_back(lookupType(indexExpr->expr)); + argExprs.push_back(indexExpr->expr); + } - if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) + for (size_t i = 0; i < call->args.size; ++i) + { + AstExpr* arg = call->args.data[i]; + argExprs.push_back(arg); + TypeId* argTy = module->astTypes.find(arg); + if (argTy) + args.head.push_back(*argTy); + else if (i == call->args.size - 1) { - std::optional leftMt = getMetatable(leftType, builtinTypes); - std::optional rightMt = getMetatable(rightType, builtinTypes); - bool matches = leftMt == rightMt; - if (isEquality && !matches) + if (auto argTail = module->astTypePacks.find(arg)) { - auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional otherMt) { - for (TypeId option : utv) + auto [head, tail] = flatten(*argTail); + args.head.insert(args.head.end(), head.begin(), head.end()); + args.tail = tail; + } + else + args.tail = builtinTypes->anyTypePack; + } + else + args.head.push_back(builtinTypes->anyType); + } + + TypePackId argsTp = module->internalTypes.addTypePack(args); + if (auto ftv = get(follow(*originalCallTy))) + { + if (ftv->dcrMagicTypeCheck) + { + ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); + return; + } + } + + + OverloadResolver resolver{ + builtinTypes, + NotNull{&module->internalTypes}, + NotNull{&normalizer}, + typeFunctionRuntime, + NotNull{stack.back()}, + ice, + limits, + call->location, + }; + resolver.resolve(fnTy, &args, call->func, &argExprs); + + auto norm = normalizer.normalize(fnTy); + if (!norm) + reportError(NormalizationTooComplex{}, call->func->location); + auto isInhabited = normalizer.isInhabited(norm.get()); + if (isInhabited == NormalizationResult::HitLimits) + reportError(NormalizationTooComplex{}, call->func->location); + + if (norm && norm->shouldSuppressErrors()) + return; // error suppressing function type! + else if (!resolver.ok.empty()) + return; // We found a call that works, so this is ok. + else if (!norm || isInhabited == NormalizationResult::False) + return; // Ok. Calling an uninhabited type is no-op. + else if (!resolver.nonviableOverloads.empty()) + { + if (resolver.nonviableOverloads.size() == 1 && !isErrorSuppressing(call->func->location, resolver.nonviableOverloads.front().first)) + reportErrors(resolver.nonviableOverloads.front().second); + else + { + std::string s = "None of the overloads for function that accept "; + s += std::to_string(args.head.size()); + s += " arguments are compatible."; + reportError(GenericError{std::move(s)}, call->location); + } + } + else if (!resolver.arityMismatches.empty()) + { + if (resolver.arityMismatches.size() == 1) + reportErrors(resolver.arityMismatches.front().second); + else + { + std::string s = "No overload for function accepts "; + s += std::to_string(args.head.size()); + s += " arguments."; + reportError(GenericError{std::move(s)}, call->location); + } + } + else if (!resolver.nonFunctions.empty()) + reportError(CannotCallNonFunction{fnTy}, call->func->location); + else + LUAU_ASSERT(!"Generating the best possible error from this function call resolution was inexhaustive?"); + + if (resolver.nonviableOverloads.size() <= 1 && resolver.arityMismatches.size() <= 1) + return; + + std::string s = "Available overloads: "; + + std::vector overloads; + if (resolver.nonviableOverloads.empty()) + { + for (const auto& [ty, p] : resolver.resolution) + { + if (p.first == OverloadResolver::TypeIsNotAFunction) + continue; + + overloads.push_back(ty); + } + } + else + { + for (const auto& [ty, _] : resolver.nonviableOverloads) + overloads.push_back(ty); + } + + if (overloads.size() <= 1) + return; + + for (size_t i = 0; i < overloads.size(); ++i) + { + if (i > 0) + s += (i == overloads.size() - 1) ? "; and " : "; "; + + s += toString(overloads[i]); + } + + reportError(ExtraInformation{std::move(s)}, call->func->location); +} + +void TypeChecker2::visit(AstExprCall* call) +{ + visit(call->func, ValueContext::RValue); + + for (AstExpr* arg : call->args) + visit(arg, ValueContext::RValue); + + visitCall(call); +} + +std::optional TypeChecker2::tryStripUnionFromNil(TypeId ty) +{ + if (const UnionType* utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : module->internalTypes.addType(UnionType{std::move(result)}); + } + + return std::nullopt; +} + +TypeId TypeChecker2::stripFromNilAndReport(TypeId ty, const Location& location) +{ + ty = follow(ty); + + if (auto utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + } + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, ty)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, location); + [[fallthrough]]; + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{ty}, location); + } + + return follow(*strippedUnion); + } + + return ty; +} + +void TypeChecker2::visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy) +{ + visit(expr, ValueContext::RValue); + TypeId leftType = stripFromNilAndReport(lookupType(expr), location); + checkIndexTypeFromType(leftType, propName, context, location, astIndexExprTy); +} + +void TypeChecker2::visit(AstExprIndexName* indexName, ValueContext context) +{ + // If we're indexing like _.foo - foo could either be a prop or a string. + visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType); +} + +void TypeChecker2::indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const MetatableType* metaTable, TypeId exprType, TypeId indexType) +{ + if (auto tt = get(follow(metaTable->table)); tt && tt->indexer) + testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); + else if (auto mt = get(follow(metaTable->table))) + indexExprMetatableHelper(indexExpr, mt, exprType, indexType); + else if (auto tmt = get(follow(metaTable->metatable)); tmt && tmt->indexer) + testIsSubtype(indexType, tmt->indexer->indexType, indexExpr->index->location); + else if (auto mtmt = get(follow(metaTable->metatable))) + indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType); + else + { + if (!(DFInt::LuauTypeSolverRelease >= 647)) + { + LUAU_ASSERT(tt || get(follow(metaTable->table))); + } + // CLI-122161: We're not handling unions correctly (probably). + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } +} + +void TypeChecker2::visit(AstExprIndexExpr* indexExpr, ValueContext context) +{ + if (auto str = indexExpr->index->as()) + { + TypeId astIndexExprType = lookupType(indexExpr->index); + const std::string stringValue(str->value.data, str->value.size); + visitExprName(indexExpr->expr, indexExpr->location, stringValue, context, astIndexExprType); + return; + } + + visit(indexExpr->expr, ValueContext::RValue); + visit(indexExpr->index, ValueContext::RValue); + + TypeId exprType = follow(lookupType(indexExpr->expr)); + TypeId indexType = follow(lookupType(indexExpr->index)); + + if (auto tt = get(exprType)) + { + if (tt->indexer) + testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); + else + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } + else if (auto mt = get(exprType)) + { + return indexExprMetatableHelper(indexExpr, mt, exprType, indexType); + } + else if (auto cls = get(exprType)) + { + if (cls->indexer) + testIsSubtype(indexType, cls->indexer->indexType, indexExpr->index->location); + else + reportError(DynamicPropertyLookupOnClassesUnsafe{exprType}, indexExpr->location); + } + else if (get(exprType) && isOptional(exprType)) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, exprType)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, indexExpr->location); + [[fallthrough]]; + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{exprType}, indexExpr->location); + } + } + else if (auto ut = get(exprType)) + { + // if all of the types are a table type, the union must be a table, and so we shouldn't error. + if (!std::all_of(begin(ut), end(ut), getTableType)) + reportError(NotATable{exprType}, indexExpr->location); + } + else if (auto it = get(exprType)) + { + // if any of the types are a table type, the intersection must be a table, and so we shouldn't error. + if (!std::any_of(begin(it), end(it), getTableType)) + reportError(NotATable{exprType}, indexExpr->location); + } + else if (get(exprType) || isErrorSuppressing(indexExpr->location, exprType)) + { + // Nothing + } + else + reportError(NotATable{exprType}, indexExpr->location); +} + +void TypeChecker2::visit(AstExprFunction* fn) +{ + auto StackPusher = pushStack(fn); + + visitGenerics(fn->generics, fn->genericPacks); + + TypeId inferredFnTy = lookupType(fn); + functionDeclStack.push_back(inferredFnTy); + + std::shared_ptr normalizedFnTy = normalizer.normalize(inferredFnTy); + if (!normalizedFnTy) + { + reportError(CodeTooComplex{}, fn->location); + } + else if (get(normalizedFnTy->errors)) + { + // Nothing + } + else if (!normalizedFnTy->hasFunctions()) + { + ice->ice("Internal error: Lambda has non-function type " + toString(inferredFnTy), fn->location); + } + else + { + if (1 != normalizedFnTy->functions.parts.size()) + ice->ice("Unexpected: Lambda has unexpected type " + toString(inferredFnTy), fn->location); + + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); + + // There is no way to write an annotation for the self argument, so we + // cannot do anything to check it. + auto argIt = begin(inferredFtv->argTypes); + if (fn->self) + ++argIt; + + for (const auto& arg : fn->args) + { + if (argIt == end(inferredFtv->argTypes)) + break; + + TypeId inferredArgTy = *argIt; + + if (arg->annotation) + { + // we need to typecheck any argument annotations themselves. + visit(arg->annotation); + + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + + testIsSubtype(inferredArgTy, annotatedArgTy, arg->location); + } + + // Some Luau constructs can result in an argument type being + // reduced to never by inference. In this case, we want to + // report an error at the function, instead of reporting an + // error at every callsite. + if (is(follow(inferredArgTy))) + { + // If the annotation simplified to never, we don't want to + // even look at contributors. + bool explicitlyNever = false; + if (arg->annotation) + { + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + explicitlyNever = is(annotatedArgTy); + } + + // Not following here is deliberate: the contribution map is + // keyed by type pointer, but that type pointer has, at some + // point, been transmuted to a bound type pointing to never. + if (const auto contributors = module->upperBoundContributors.find(inferredArgTy); contributors && !explicitlyNever) + { + // It's unfortunate that we can't link error messages + // together. For now, this will work. + reportError( + GenericError{format( + "Parameter '%s' has been reduced to never. This function is not callable with any possible value.", arg->name.value + )}, + arg->location + ); + for (const auto& [site, component] : *contributors) + reportError( + ExtraInformation{ + format("Parameter '%s' is required to be a subtype of '%s' here.", arg->name.value, toString(component).c_str()) + }, + site + ); + } + } + + ++argIt; + } + + // we need to typecheck the vararg annotation, if it exists. + if (fn->vararg && fn->varargAnnotation) + visit(fn->varargAnnotation); + + bool reachesImplicitReturn = getFallthrough(fn->body) != nullptr; + if (reachesImplicitReturn && !allowsNoReturnValues(follow(inferredFtv->retTypes))) + reportError(FunctionExitsWithoutReturning{inferredFtv->retTypes}, getEndLocation(fn)); + } + + visit(fn->body); + + // we need to typecheck the return annotation itself, if it exists. + if (fn->returnAnnotation) + visit(*fn->returnAnnotation); + + + // If the function type has a function annotation, we need to see if we can suggest an annotation + if (normalizedFnTy) + { + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); + + TypeFunctionReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}}; + for (TypeId retTy : inferredFtv->retTypes) + { + if (get(follow(retTy))) + { + TypeFunctionReductionGuessResult result = guesser.guessTypeFunctionReductionForFunctionExpr(*fn, inferredFtv, retTy); + if (result.shouldRecommendAnnotation && !get(result.guessedReturnType)) + reportError( + ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, fn->location + ); + } + } + } + + functionDeclStack.pop_back(); +} + +void TypeChecker2::visit(AstExprTable* expr) +{ + // TODO! + for (const AstExprTable::Item& item : expr->items) + { + if (item.key) + visit(item.key, ValueContext::LValue); + visit(item.value, ValueContext::RValue); + } +} + +void TypeChecker2::visit(AstExprUnary* expr) +{ + visit(expr->expr, ValueContext::RValue); + + TypeId operandType = lookupType(expr->expr); + TypeId resultType = lookupType(expr); + + if (isErrorSuppressing(expr->expr->location, operandType)) + return; + + if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) + { + std::optional mm = findMetatableEntry(builtinTypes, module->errors, operandType, it->second, expr->location); + if (mm) + { + if (const FunctionType* ftv = get(follow(*mm))) + { + if (std::optional ret = first(ftv->retTypes)) + { + if (expr->op == AstExprUnary::Op::Len) { - if (getMetatable(follow(option), builtinTypes) == otherMt) - { - matches = true; - break; - } + testIsSubtype(follow(*ret), builtinTypes->numberType, expr->location); } - }; + } + else + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + + std::optional firstArg = first(ftv->argTypes); + if (!firstArg) + { + reportError(GenericError{"__unm metamethod must accept one argument"}, expr->location); + return; + } + + TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); + TypePackId expectedRet = module->internalTypes.addTypePack({resultType}); + + TypeId expectedFunction = module->internalTypes.addType(FunctionType{expectedArgs, expectedRet}); + + bool success = testIsSubtype(*mm, expectedFunction, expr->location); + if (!success) + return; + } + + return; + } + } + + if (expr->op == AstExprUnary::Op::Len) + { + DenseHashSet seen{nullptr}; + int recursionCount = 0; + std::shared_ptr nty = normalizer.normalize(operandType); + + if (nty && nty->shouldSuppressErrors()) + return; + + switch (normalizer.isInhabited(nty.get())) + { + case NormalizationResult::True: + break; + case NormalizationResult::False: + return; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + return; + } + + if (!hasLength(operandType, seen, &recursionCount)) + { + if (isOptional(operandType)) + reportError(OptionalValueAccess{operandType}, expr->location); + else + reportError(NotATable{operandType}, expr->location); + } + } + else if (expr->op == AstExprUnary::Op::Minus) + { + testIsSubtype(operandType, builtinTypes->numberType, expr->location); + } + else if (expr->op == AstExprUnary::Op::Not) + { + } + else + { + LUAU_ASSERT(!"Unhandled unary operator"); + } +} + +TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey) +{ + visit(expr->left, ValueContext::RValue); + visit(expr->right, ValueContext::RValue); + + NotNull scope = stack.back(); + + bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe; + bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe; + bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or; + + TypeId leftType = follow(lookupType(expr->left)); + TypeId rightType = follow(lookupType(expr->right)); + TypeId expectedResult = follow(lookupType(expr)); + + if (get(expectedResult)) + { + checkForInternalTypeFunction(expectedResult, expr->location); + return expectedResult; + } + + if (expr->op == AstExprBinary::Op::Or) + { + leftType = stripNil(builtinTypes, module->internalTypes, leftType); + } + + std::shared_ptr normLeft = normalizer.normalize(leftType); + std::shared_ptr normRight = normalizer.normalize(rightType); + + bool isStringOperation = + (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType)); + leftType = follow(leftType); + if (get(leftType) || get(leftType) || get(leftType)) + return leftType; + else if (get(rightType) || get(rightType) || get(rightType)) + return rightType; + else if ((normLeft && normLeft->shouldSuppressErrors()) || (normRight && normRight->shouldSuppressErrors())) + return builtinTypes->anyType; // we can't say anything better if it's error suppressing but not any or error alone. + + if ((get(leftType) || get(leftType) || get(leftType)) && !isEquality && !isLogical) + { + auto name = getIdentifierOfBaseVar(expr->left); + reportError( + CannotInferBinaryOperation{ + expr->op, name, isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation + }, + expr->location + ); + return leftType; + } + + NormalizationResult typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType); + if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) + { + std::optional leftMt = getMetatable(leftType, builtinTypes); + std::optional rightMt = getMetatable(rightType, builtinTypes); + bool matches = leftMt == rightMt; + - if (const UnionType* utv = get(leftType); utv && rightMt) + if (isEquality && !matches) + { + auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional otherMt) + { + for (TypeId option : utv) { - testUnion(utv, rightMt); + if (getMetatable(follow(option), builtinTypes) == otherMt) + { + matches = true; + break; + } } + }; - if (const UnionType* utv = get(rightType); utv && leftMt && !matches) - { - testUnion(utv, leftMt); - } + if (const UnionType* utv = get(leftType); utv && rightMt) + { + testUnion(utv, rightMt); } - if (!matches && isComparison) + if (const UnionType* utv = get(rightType); utv && leftMt && !matches) { - reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, - expr->location); - - return builtinTypes->errorRecoveryType(); + testUnion(utv, leftMt); } + } + + // If we're working with things that are not tables, the metatable comparisons above are a little excessive + // It's ok for one type to have a meta table and the other to not. In that case, we should fall back on + // checking if the intersection of the types is inhabited. If `typesHaveIntersection` failed due to limits, + // TODO: Maybe add more checks here (e.g. for functions, classes, etc) + if (!(get(leftType) || get(rightType))) + if (!leftMt.has_value() || !rightMt.has_value()) + matches = matches || typesHaveIntersection != NormalizationResult::False; + + if (!matches && isComparison) + { + reportError( + GenericError{format( + "Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str() + )}, + expr->location + ); + + return builtinTypes->errorRecoveryType(); + } + + std::optional mm; + if (std::optional leftMm = findMetatableEntry(builtinTypes, module->errors, leftType, it->second, expr->left->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(builtinTypes, module->errors, rightType, it->second, expr->right->location)) + { + mm = rightMm; + std::swap(leftType, rightType); + } + + if (mm) + { + AstNode* key = expr; + if (overrideKey != nullptr) + key = overrideKey; - std::optional mm; - if (std::optional leftMm = findMetatableEntry(builtinTypes, module->errors, leftType, it->second, expr->left->location)) - mm = leftMm; - else if (std::optional rightMm = findMetatableEntry(builtinTypes, module->errors, rightType, it->second, expr->right->location)) + TypeId* selectedOverloadTy = module->astOverloadResolvedTypes.find(key); + if (!selectedOverloadTy) { - mm = rightMm; - std::swap(leftType, rightType); + // reportError(CodeTooComplex{}, expr->location); + // was handled by a type function + return expectedResult; } - if (mm) + else if (const FunctionType* ftv = get(follow(*selectedOverloadTy))) { - AstNode* key = expr; - if (overrideKey != nullptr) - key = overrideKey; + TypePackId expectedArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) + { + expectedArgs = module->internalTypes.addTypePack({rightType, leftType}); + } + else + { + expectedArgs = module->internalTypes.addTypePack({leftType, rightType}); + } + + TypePackId expectedRets; + if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || + expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) + { + expectedRets = module->internalTypes.addTypePack({builtinTypes->booleanType}); + } + else + { + expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})}); + } - TypeId instantiatedMm = module->astOverloadResolvedTypes[key]; - if (!instantiatedMm) - reportError(CodeTooComplex{}, expr->location); + TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets)); - else if (const FunctionType* ftv = get(follow(instantiatedMm))) + testIsSubtype(follow(*mm), expectedTy, expr->location); + + std::optional ret = first(ftv->retTypes); + if (ret) { - TypePackId expectedArgs; - // For >= and > we invoke __lt and __le respectively with - // swapped argument ordering. - if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) + if (isComparison) { - expectedArgs = testArena.addTypePack({rightType, leftType}); + if (!isBoolean(follow(*ret))) + { + reportError(GenericError{format("Metamethod '%s' must return a boolean", it->second)}, expr->location); + } + + return builtinTypes->booleanType; } else { - expectedArgs = testArena.addTypePack({leftType, rightType}); + return follow(*ret); } - - TypePackId expectedRets; - if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || - expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) + } + else + { + if (isComparison) { - expectedRets = testArena.addTypePack({builtinTypes->booleanType}); + reportError(GenericError{format("Metamethod '%s' must return a boolean", it->second)}, expr->location); } else { - expectedRets = testArena.addTypePack({testArena.freshType(scope, TypeLevel{})}); + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); } - TypeId expectedTy = testArena.addType(FunctionType(expectedArgs, expectedRets)); - - reportErrors(tryUnify(scope, expr->location, follow(*mm), expectedTy)); - - std::optional ret = first(ftv->retTypes); - if (ret) - { - if (isComparison) - { - if (!isBoolean(follow(*ret))) - { - reportError(GenericError{format("Metamethod '%s' must return a boolean", it->second)}, expr->location); - } + return builtinTypes->errorRecoveryType(); + } + } + else + { + reportError(CannotCallNonFunction{*mm}, expr->location); + } - return builtinTypes->booleanType; - } - else - { - return follow(*ret); - } - } - else - { - if (isComparison) - { - reportError(GenericError{format("Metamethod '%s' must return a boolean", it->second)}, expr->location); - } - else - { - reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); - } + return builtinTypes->errorRecoveryType(); + } + // If this is a string comparison, or a concatenation of strings, we + // want to fall through to primitive behavior. + else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison))) + { + if ((leftMt && !isString(leftType)) || (rightMt && !isString(rightType))) + { + if (isComparison) + { + reportError( + GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str(), + it->second + )}, + expr->location + ); + } + else + { + reportError( + GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", + toString(expr->op).c_str(), + toString(leftType).c_str(), + toString(rightType).c_str(), + it->second + )}, + expr->location + ); + } - return builtinTypes->errorRecoveryType(); - } + return builtinTypes->errorRecoveryType(); + } + else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) + { + if (isComparison) + { + reportError( + GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str() + )}, + expr->location + ); } else { - reportError(CannotCallNonFunction{*mm}, expr->location); + reportError( + GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", + toString(expr->op).c_str(), + toString(leftType).c_str(), + toString(rightType).c_str() + )}, + expr->location + ); } return builtinTypes->errorRecoveryType(); } - // If this is a string comparison, or a concatenation of strings, we - // want to fall through to primitive behavior. - else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison))) + } + } + + switch (expr->op) + { + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::FloorDiv: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + testIsSubtype(leftType, builtinTypes->numberType, expr->left->location); + testIsSubtype(rightType, builtinTypes->numberType, expr->right->location); + + return builtinTypes->numberType; + case AstExprBinary::Op::Concat: + testIsSubtype(leftType, builtinTypes->stringType, expr->left->location); + testIsSubtype(rightType, builtinTypes->stringType, expr->right->location); + + return builtinTypes->stringType; + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + { + if (normLeft && normLeft->shouldSuppressErrors()) + return builtinTypes->booleanType; + + // if we're comparing against an uninhabited type, it's unobservable that the comparison did not run + if (normLeft && normalizer.isInhabited(normLeft.get()) == NormalizationResult::False) + return builtinTypes->booleanType; + + if (normLeft && normLeft->isExactlyNumber()) + { + testIsSubtype(rightType, builtinTypes->numberType, expr->right->location); + return builtinTypes->booleanType; + } + + if (normLeft && normLeft->isSubtypeOfString()) + { + testIsSubtype(rightType, builtinTypes->stringType, expr->right->location); + return builtinTypes->booleanType; + } + + reportError( + GenericError{format( + "Types '%s' and '%s' cannot be compared with relational operator %s", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str() + )}, + expr->location + ); + return builtinTypes->errorRecoveryType(); + } + + case AstExprBinary::Op::And: + case AstExprBinary::Op::Or: + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + // Ugly case: we don't care about this possibility, because a + // compound assignment will never exist with one of these operators. + return builtinTypes->anyType; + default: + // Unhandled AstExprBinary::Op possibility. + LUAU_ASSERT(false); + return builtinTypes->errorRecoveryType(); + } +} + +void TypeChecker2::visit(AstExprTypeAssertion* expr) +{ + visit(expr->expr, ValueContext::RValue); + visit(expr->annotation); + + TypeId annotationType = lookupAnnotation(expr->annotation); + TypeId computedType = lookupType(expr->expr); + + switch (shouldSuppressErrors(NotNull{&normalizer}, computedType).orElse(shouldSuppressErrors(NotNull{&normalizer}, annotationType))) + { + case ErrorSuppression::Suppress: + return; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, expr->location); + return; + case ErrorSuppression::DoNotSuppress: + break; + } + + switch (normalizer.isInhabited(computedType)) + { + case NormalizationResult::True: + break; + case NormalizationResult::False: + return; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + return; + } + + switch (normalizer.isIntersectionInhabited(computedType, annotationType)) + { + case NormalizationResult::True: + return; + case NormalizationResult::False: + reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + break; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + break; + } +} + +void TypeChecker2::visit(AstExprIfElse* expr) +{ + // TODO! + visit(expr->condition, ValueContext::RValue); + visit(expr->trueExpr, ValueContext::RValue); + visit(expr->falseExpr, ValueContext::RValue); +} + +void TypeChecker2::visit(AstExprInterpString* interpString) +{ + for (AstExpr* expr : interpString->expressions) + visit(expr, ValueContext::RValue); +} + +void TypeChecker2::visit(AstExprError* expr) +{ + // TODO! + for (AstExpr* e : expr->expressions) + visit(e, ValueContext::RValue); +} + +TypeId TypeChecker2::flattenPack(TypePackId pack) +{ + pack = follow(pack); + + if (auto fst = first(pack, /*ignoreHiddenVariadics*/ false)) + return *fst; + else if (auto ftp = get(pack)) + { + TypeId result = module->internalTypes.addType(FreeType{ftp->scope}); + TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); + + TypePack* resultPack = emplaceTypePack(asMutable(pack)); + resultPack->head.assign(1, result); + resultPack->tail = freeTail; + + return result; + } + else if (get(pack)) + return builtinTypes->errorRecoveryType(); + else if (finite(pack) && size(pack) == 0) + return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` + else + ice->ice("flattenPack got a weird pack!"); +} + +void TypeChecker2::visitGenerics(AstArray generics, AstArray genericPacks) +{ + DenseHashSet seen{AstName{}}; + + for (const auto& g : generics) + { + if (seen.contains(g.name)) + reportError(DuplicateGenericParameter{g.name.value}, g.location); + else + seen.insert(g.name); + + if (g.defaultValue) + visit(g.defaultValue); + } + + for (const auto& g : genericPacks) + { + if (seen.contains(g.name)) + reportError(DuplicateGenericParameter{g.name.value}, g.location); + else + seen.insert(g.name); + + if (g.defaultValue) + visit(g.defaultValue); + } +} + +void TypeChecker2::visit(AstType* ty) +{ + TypeId* resolvedTy = module->astResolvedTypes.find(ty); + if (resolvedTy) + checkForTypeFunctionInhabitance(follow(*resolvedTy), ty->location); + + if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); +} + +void TypeChecker2::visit(AstTypeReference* ty) +{ + // No further validation is necessary in this case. The main logic for + // _luau_print is contained in lookupAnnotation. + if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print") + return; + + for (const AstTypeOrPack& param : ty->parameters) + { + if (param.type) + visit(param.type); + else + visit(param.typePack); + } + + Scope* scope = findInnermostScope(ty->location); + LUAU_ASSERT(scope); + + std::optional alias = (ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value); + + if (alias.has_value()) + { + size_t typesRequired = alias->typeParams.size(); + size_t packsRequired = alias->typePackParams.size(); + + bool hasDefaultTypes = std::any_of( + alias->typeParams.begin(), + alias->typeParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); + + bool hasDefaultPacks = std::any_of( + alias->typePackParams.begin(), + alias->typePackParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); + + if (!ty->hasParameterList) + { + if ((!alias->typeParams.empty() && !hasDefaultTypes) || (!alias->typePackParams.empty() && !hasDefaultPacks)) + { + reportError(GenericError{"Type parameter list is required"}, ty->location); + } + } + + size_t typesProvided = 0; + size_t extraTypes = 0; + size_t packsProvided = 0; + + for (const AstTypeOrPack& p : ty->parameters) + { + if (p.type) { - if ((leftMt && !isString(leftType)) || (rightMt && !isString(rightType))) + if (packsProvided != 0) { - if (isComparison) - { - reportError(GenericError{format( - "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", - toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)}, - expr->location); - } - else - { - reportError(GenericError{format( - "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", - toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)}, - expr->location); - } + reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); + continue; + } - return builtinTypes->errorRecoveryType(); + if (typesProvided < typesRequired) + { + typesProvided += 1; } - else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) + else { - if (isComparison) - { - reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", - toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, - expr->location); - } - else - { - reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", - toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())}, - expr->location); - } + extraTypes += 1; + } + } + else if (p.typePack) + { + std::optional tp = lookupPackAnnotation(p.typePack); + if (!tp.has_value()) + continue; - return builtinTypes->errorRecoveryType(); + if (typesProvided < typesRequired && size(*tp) == 1 && finite(*tp) && first(*tp)) + { + typesProvided += 1; + } + else + { + packsProvided += 1; } } } - switch (expr->op) - { - case AstExprBinary::Op::Add: - case AstExprBinary::Op::Sub: - case AstExprBinary::Op::Mul: - case AstExprBinary::Op::Div: - case AstExprBinary::Op::Pow: - case AstExprBinary::Op::Mod: - reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->numberType)); - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); - - return builtinTypes->numberType; - case AstExprBinary::Op::Concat: - reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->stringType)); - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); - - return builtinTypes->stringType; - case AstExprBinary::Op::CompareGe: - case AstExprBinary::Op::CompareGt: - case AstExprBinary::Op::CompareLe: - case AstExprBinary::Op::CompareLt: - if (isNumber(leftType)) + if (extraTypes != 0 && packsProvided == 0) + { + // Extra types are only collected into a pack if a pack is expected + if (packsRequired != 0) + packsProvided += 1; + else + typesProvided += extraTypes; + } + + for (size_t i = typesProvided; i < typesRequired; ++i) + { + if (alias->typeParams[i].defaultValue) { - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); - return builtinTypes->numberType; + typesProvided += 1; } - else if (isString(leftType)) + } + + for (size_t i = packsProvided; i < packsRequired; ++i) + { + if (alias->typePackParams[i].defaultValue) { - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); - return builtinTypes->stringType; + packsProvided += 1; } - else + } + + if (extraTypes == 0 && packsProvided + 1 == packsRequired) + { + packsProvided += 1; + } + + if (typesProvided != typesRequired || packsProvided != packsRequired) + { + reportError( + IncorrectGenericParameterCount{ + /* name */ ty->name.value, + /* typeFun */ *alias, + /* actualParameters */ typesProvided, + /* actualPackParameters */ packsProvided, + }, + ty->location + ); + } + } + else + { + if (scope->lookupPack(ty->name.value)) + { + reportError( + SwappedGenericTypeParameter{ + ty->name.value, + SwappedGenericTypeParameter::Kind::Type, + }, + ty->location + ); + } + else + { + std::string symbol = ""; + if (ty->prefix) { - reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), - toString(rightType).c_str(), toString(expr->op).c_str())}, - expr->location); - return builtinTypes->errorRecoveryType(); + symbol += (*(ty->prefix)).value; + symbol += "."; } - case AstExprBinary::Op::And: - case AstExprBinary::Op::Or: - case AstExprBinary::Op::CompareEq: - case AstExprBinary::Op::CompareNe: - // Ugly case: we don't care about this possibility, because a - // compound assignment will never exist with one of these operators. - return builtinTypes->anyType; - default: - // Unhandled AstExprBinary::Op possibility. - LUAU_ASSERT(false); - return builtinTypes->errorRecoveryType(); + symbol += ty->name.value; + + reportError(UnknownSymbol{symbol, UnknownSymbol::Context::Type}, ty->location); } } +} + +void TypeChecker2::visit(AstTypeTable* table) +{ + // TODO! + + for (const AstTableProp& prop : table->props) + visit(prop.type); - void visit(AstExprTypeAssertion* expr) + if (table->indexer) { - visit(expr->expr, RValue); - visit(expr->annotation); + visit(table->indexer->indexType); + visit(table->indexer->resultType); + } +} + +void TypeChecker2::visit(AstTypeFunction* ty) +{ + visitGenerics(ty->generics, ty->genericPacks); + visit(ty->argTypes); + visit(ty->returnTypes); +} - TypeId annotationType = lookupAnnotation(expr->annotation); - TypeId computedType = lookupType(expr->expr); +void TypeChecker2::visit(AstTypeTypeof* ty) +{ + visit(ty->expr, ValueContext::RValue); +} - // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back())) - return; +void TypeChecker2::visit(AstTypeUnion* ty) +{ + // TODO! + for (AstType* type : ty->types) + visit(type); +} + +void TypeChecker2::visit(AstTypeIntersection* ty) +{ + // TODO! + for (AstType* type : ty->types) + visit(type); +} + +void TypeChecker2::visit(AstTypePack* pack) +{ + if (auto p = pack->as()) + return visit(p); + else if (auto p = pack->as()) + return visit(p); + else if (auto p = pack->as()) + return visit(p); +} + +void TypeChecker2::visit(AstTypePackExplicit* tp) +{ + // TODO! + for (AstType* type : tp->typeList.types) + visit(type); + + if (tp->typeList.tailType) + visit(tp->typeList.tailType); +} - if (isSubtype(computedType, annotationType, stack.back())) - return; +void TypeChecker2::visit(AstTypePackVariadic* tp) +{ + // TODO! + visit(tp->variadicType); +} - reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); - } +void TypeChecker2::visit(AstTypePackGeneric* tp) +{ + Scope* scope = findInnermostScope(tp->location); + LUAU_ASSERT(scope); - void visit(AstExprIfElse* expr) + std::optional alias = scope->lookupPack(tp->genericName.value); + if (!alias.has_value()) { - // TODO! - visit(expr->condition, RValue); - visit(expr->trueExpr, RValue); - visit(expr->falseExpr, RValue); + if (scope->lookupType(tp->genericName.value)) + { + reportError( + SwappedGenericTypeParameter{ + tp->genericName.value, + SwappedGenericTypeParameter::Kind::Pack, + }, + tp->location + ); + } + else + { + reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location); + } } +} - void visit(AstExprInterpString* interpString) - { - for (AstExpr* expr : interpString->expressions) - visit(expr, RValue); - } +template +Reasonings TypeChecker2::explainReasonings_(TID subTy, TID superTy, Location location, const SubtypingResult& r) +{ + if (r.reasoning.empty()) + return {}; - void visit(AstExprError* expr) + std::vector reasons; + bool suppressed = true; + for (const SubtypingReasoning& reasoning : r.reasoning) { - // TODO! - for (AstExpr* e : expr->expressions) - visit(e, RValue); - } + if (reasoning.subPath.empty() && reasoning.superPath.empty()) + continue; - /** Extract a TypeId for the first type of the provided pack. - * - * Note that this may require modifying some types. I hope this doesn't cause problems! - */ - TypeId flattenPack(TypePackId pack) - { - pack = follow(pack); + std::optional optSubLeaf = traverse(subTy, reasoning.subPath, builtinTypes); + std::optional optSuperLeaf = traverse(superTy, reasoning.superPath, builtinTypes); - if (auto fst = first(pack, /*ignoreHiddenVariadics*/ false)) - return *fst; - else if (auto ftp = get(pack)) - { - TypeId result = testArena.addType(FreeType{ftp->scope}); - TypePackId freeTail = testArena.addTypePack(FreeTypePack{ftp->scope}); + if (!optSubLeaf || !optSuperLeaf) + ice->ice("Subtyping test returned a reasoning with an invalid path", location); - TypePack& resultPack = asMutable(pack)->ty.emplace(); - resultPack.head.assign(1, result); - resultPack.tail = freeTail; + const TypeOrPack& subLeaf = *optSubLeaf; + const TypeOrPack& superLeaf = *optSuperLeaf; - return result; - } - else if (get(pack)) - return builtinTypes->errorRecoveryType(); - else if (finite(pack) && size(pack) == 0) - return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` - else - ice.ice("flattenPack got a weird pack!"); - } + auto subLeafTy = get(subLeaf); + auto superLeafTy = get(superLeaf); - void visitGenerics(AstArray generics, AstArray genericPacks) - { - DenseHashSet seen{AstName{}}; + auto subLeafTp = get(subLeaf); + auto superLeafTp = get(superLeaf); - for (const auto& g : generics) - { - if (seen.contains(g.name)) - reportError(DuplicateGenericParameter{g.name.value}, g.location); - else - seen.insert(g.name); + if (!subLeafTy && !superLeafTy && !subLeafTp && !superLeafTp) + ice->ice("Subtyping test returned a reasoning where one path ends at a type and the other ends at a pack.", location); - if (g.defaultValue) - visit(g.defaultValue); - } + std::string relation = "a subtype of"; + if (reasoning.variance == SubtypingVariance::Invariant) + relation = "exactly"; + else if (reasoning.variance == SubtypingVariance::Contravariant) + relation = "a supertype of"; + + std::string reason; + if (reasoning.subPath == reasoning.superPath) + reason = "at " + toString(reasoning.subPath) + ", " + toString(subLeaf) + " is not " + relation + " " + toString(superLeaf); + else + reason = "type " + toString(subTy) + toString(reasoning.subPath, /* prefixDot */ true) + " (" + toString(subLeaf) + ") is not " + + relation + " " + toString(superTy) + toString(reasoning.superPath, /* prefixDot */ true) + " (" + toString(superLeaf) + ")"; + + reasons.push_back(reason); - for (const auto& g : genericPacks) + // if we haven't already proved this isn't suppressing, we have to keep checking. + if (suppressed) { - if (seen.contains(g.name)) - reportError(DuplicateGenericParameter{g.name.value}, g.location); + if (subLeafTy && superLeafTy) + suppressed &= isErrorSuppressing(location, *subLeafTy) || isErrorSuppressing(location, *superLeafTy); else - seen.insert(g.name); - - if (g.defaultValue) - visit(g.defaultValue); + suppressed &= isErrorSuppressing(location, *subLeafTp) || isErrorSuppressing(location, *superLeafTp); } } - void visit(AstType* ty) + return {std::move(reasons), suppressed}; +} + +Reasonings TypeChecker2::explainReasonings(TypeId subTy, TypeId superTy, Location location, const SubtypingResult& r) +{ + return explainReasonings_(subTy, superTy, location, r); +} + +Reasonings TypeChecker2::explainReasonings(TypePackId subTp, TypePackId superTp, Location location, const SubtypingResult& r) +{ + return explainReasonings_(subTp, superTp, location, r); +} + +void TypeChecker2::explainError(TypeId subTy, TypeId superTy, Location location, const SubtypingResult& result) +{ + switch (shouldSuppressErrors(NotNull{&normalizer}, subTy).orElse(shouldSuppressErrors(NotNull{&normalizer}, superTy))) { - if (auto t = ty->as()) - return visit(t); - else if (auto t = ty->as()) - return visit(t); - else if (auto t = ty->as()) - return visit(t); - else if (auto t = ty->as()) - return visit(t); - else if (auto t = ty->as()) - return visit(t); - else if (auto t = ty->as()) - return visit(t); + case ErrorSuppression::Suppress: + return; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, location); + break; + case ErrorSuppression::DoNotSuppress: + break; } - void visit(AstTypeReference* ty) + Reasonings reasonings = explainReasonings(subTy, superTy, location, result); + + if (!reasonings.suppressed) + reportError(TypeMismatch{superTy, subTy, reasonings.toString()}, location); +} + +void TypeChecker2::explainError(TypePackId subTy, TypePackId superTy, Location location, const SubtypingResult& result) +{ + switch (shouldSuppressErrors(NotNull{&normalizer}, subTy).orElse(shouldSuppressErrors(NotNull{&normalizer}, superTy))) { - // No further validation is necessary in this case. The main logic for - // _luau_print is contained in lookupAnnotation. - if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print" && ty->parameters.size > 0) - return; + case ErrorSuppression::Suppress: + return; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, location); + break; + case ErrorSuppression::DoNotSuppress: + break; + } - for (const AstTypeOrPack& param : ty->parameters) - { - if (param.type) - visit(param.type); - else - visit(param.typePack); - } + Reasonings reasonings = explainReasonings(subTy, superTy, location, result); - Scope* scope = findInnermostScope(ty->location); - LUAU_ASSERT(scope); + if (!reasonings.suppressed) + reportError(TypePackMismatch{superTy, subTy, reasonings.toString()}, location); +} - std::optional alias = - (ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value); +bool TypeChecker2::testIsSubtype(TypeId subTy, TypeId superTy, Location location) +{ + NotNull scope{findInnermostScope(location)}; + SubtypingResult r = subtyping->isSubtype(subTy, superTy, scope); - if (alias.has_value()) - { - size_t typesRequired = alias->typeParams.size(); - size_t packsRequired = alias->typePackParams.size(); + if (r.normalizationTooComplex) + reportError(NormalizationTooComplex{}, location); - bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); + if (!r.isSubtype) + explainError(subTy, superTy, location, r); - bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); + return r.isSubtype; +} - if (!ty->hasParameterList) - { - if ((!alias->typeParams.empty() && !hasDefaultTypes) || (!alias->typePackParams.empty() && !hasDefaultPacks)) - { - reportError(GenericError{"Type parameter list is required"}, ty->location); - } - } +bool TypeChecker2::testIsSubtype(TypePackId subTy, TypePackId superTy, Location location) +{ + NotNull scope{findInnermostScope(location)}; + SubtypingResult r = subtyping->isSubtype(subTy, superTy, scope); - size_t typesProvided = 0; - size_t extraTypes = 0; - size_t packsProvided = 0; + if (r.normalizationTooComplex) + reportError(NormalizationTooComplex{}, location); - for (const AstTypeOrPack& p : ty->parameters) - { - if (p.type) - { - if (packsProvided != 0) - { - reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); - } + if (!r.isSubtype) + explainError(subTy, superTy, location, r); - if (typesProvided < typesRequired) - { - typesProvided += 1; - } - else - { - extraTypes += 1; - } - } - else if (p.typePack) - { - TypePackId tp = lookupPackAnnotation(p.typePack); + return r.isSubtype; +} - if (typesProvided < typesRequired && size(tp) == 1 && finite(tp) && first(tp)) - { - typesProvided += 1; - } - else - { - packsProvided += 1; - } - } - } +void TypeChecker2::reportError(TypeErrorData data, const Location& location) +{ + if (auto utk = get_if(&data)) + diagnoseMissingTableKey(utk, data); - if (extraTypes != 0 && packsProvided == 0) - { - packsProvided += 1; - } + module->errors.emplace_back(location, module->name, std::move(data)); - for (size_t i = typesProvided; i < typesRequired; ++i) - { - if (alias->typeParams[i].defaultValue) - { - typesProvided += 1; - } - } + if (logger) + logger->captureTypeCheckError(module->errors.back()); +} - for (size_t i = packsProvided; i < packsRequired; ++i) - { - if (alias->typePackParams[i].defaultValue) - { - packsProvided += 1; - } - } +void TypeChecker2::reportError(TypeError e) +{ + reportError(std::move(e.data), e.location); +} - if (extraTypes == 0 && packsProvided + 1 == packsRequired) - { - packsProvided += 1; - } +void TypeChecker2::reportErrors(ErrorVec errors) +{ + for (TypeError e : errors) + reportError(std::move(e)); +} - if (typesProvided != typesRequired || packsProvided != packsRequired) - { - reportError(IncorrectGenericParameterCount{ - /* name */ ty->name.value, - /* typeFun */ *alias, - /* actualParameters */ typesProvided, - /* actualPackParameters */ packsProvided, - }, - ty->location); - } - } - else - { - if (scope->lookupPack(ty->name.value)) - { - reportError( - SwappedGenericTypeParameter{ - ty->name.value, - SwappedGenericTypeParameter::Kind::Type, - }, - ty->location); - } - else - { - reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); - } - } - } +/* A helper for checkIndexTypeFromType. + * + * Returns a pair: + * * A boolean indicating that at least one of the constituent types + * contains the prop, and + * * A vector of types that do not contain the prop. + */ +PropertyTypes TypeChecker2::lookupProp( + const NormalizedType* norm, + const std::string& prop, + ValueContext context, + const Location& location, + TypeId astIndexExprType, + std::vector& errors +) +{ + std::vector typesOfProp; + std::vector typesMissingTheProp; - void visit(AstTypeTable* table) + // this is `false` if we ever hit the resource limits during any of our uses of `fetch`. + bool normValid = true; + + auto fetch = [&](TypeId ty) { - // TODO! + NormalizationResult result = normalizer.isInhabited(ty); + if (result == NormalizationResult::HitLimits) + normValid = false; + if (result != NormalizationResult::True) + return; - for (const AstTableProp& prop : table->props) - visit(prop.type); + DenseHashSet seen{nullptr}; + PropertyType res = hasIndexTypeFromType(ty, prop, context, location, seen, astIndexExprType, errors); - if (table->indexer) + if (res.present == NormalizationResult::HitLimits) { - visit(table->indexer->indexType); - visit(table->indexer->resultType); + normValid = false; + return; } - } - void visit(AstTypeFunction* ty) - { - visitGenerics(ty->generics, ty->genericPacks); - visit(ty->argTypes); - visit(ty->returnTypes); - } + if (res.present == NormalizationResult::True && res.result) + typesOfProp.emplace_back(*res.result); - void visit(AstTypeTypeof* ty) - { - visit(ty->expr, RValue); - } + if (res.present == NormalizationResult::False) + typesMissingTheProp.push_back(ty); + }; - void visit(AstTypeUnion* ty) - { - // TODO! - for (AstType* type : ty->types) - visit(type); - } + if (normValid) + fetch(norm->tops); + if (normValid) + fetch(norm->booleans); - void visit(AstTypeIntersection* ty) + if (normValid) { - // TODO! - for (AstType* type : ty->types) - visit(type); - } + for (const auto& [ty, _negations] : norm->classes.classes) + { + fetch(ty); - void visit(AstTypePack* pack) - { - if (auto p = pack->as()) - return visit(p); - else if (auto p = pack->as()) - return visit(p); - else if (auto p = pack->as()) - return visit(p); + if (!normValid) + break; + } } - void visit(AstTypePackExplicit* tp) + if (normValid) + fetch(norm->errors); + if (normValid) + fetch(norm->nils); + if (normValid) + fetch(norm->numbers); + if (normValid && !norm->strings.isNever()) + fetch(builtinTypes->stringType); + if (normValid) + fetch(norm->threads); + if (normValid) + fetch(norm->buffers); + + if (normValid) { - // TODO! - for (AstType* type : tp->typeList.types) - visit(type); + for (TypeId ty : norm->tables) + { + fetch(ty); - if (tp->typeList.tailType) - visit(tp->typeList.tailType); + if (!normValid) + break; + } } - void visit(AstTypePackVariadic* tp) + if (normValid && norm->functions.isTop) + fetch(builtinTypes->functionType); + else if (normValid && !norm->functions.isNever()) { - // TODO! - visit(tp->variadicType); + if (norm->functions.parts.size() == 1) + fetch(norm->functions.parts.front()); + else + { + std::vector parts; + parts.insert(parts.end(), norm->functions.parts.begin(), norm->functions.parts.end()); + fetch(module->internalTypes.addType(IntersectionType{std::move(parts)})); + } } - void visit(AstTypePackGeneric* tp) + if (normValid) { - Scope* scope = findInnermostScope(tp->location); - LUAU_ASSERT(scope); - - std::optional alias = scope->lookupPack(tp->genericName.value); - if (!alias.has_value()) + for (const auto& [tyvar, intersect] : norm->tyvars) { - if (scope->lookupType(tp->genericName.value)) + if (get(intersect->tops)) { - reportError( - SwappedGenericTypeParameter{ - tp->genericName.value, - SwappedGenericTypeParameter::Kind::Pack, - }, - tp->location); + TypeId ty = normalizer.typeFromNormal(*intersect); + fetch(module->internalTypes.addType(IntersectionType{{tyvar, ty}})); } else - { - reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location); - } + fetch(follow(tyvar)); + + if (!normValid) + break; } } - void reduceTypes() - { - if (FFlag::DebugLuauDontReduceTypes) - return; + return {typesOfProp, typesMissingTheProp}; +} - for (auto [_, scope] : module->scopes) - { - for (auto& [_, b] : scope->bindings) - { - if (auto reduced = module->reduction->reduce(b.typeId)) - b.typeId = *reduced; - } - if (auto reduced = module->reduction->reduce(scope->returnType)) - scope->returnType = *reduced; +void TypeChecker2::checkIndexTypeFromType( + TypeId tableTy, + const std::string& prop, + ValueContext context, + const Location& location, + TypeId astIndexExprType +) +{ + std::shared_ptr norm = normalizer.normalize(tableTy); + if (!norm) + { + reportError(NormalizationTooComplex{}, location); + return; + } - if (scope->varargPack) - { - if (auto reduced = module->reduction->reduce(*scope->varargPack)) - scope->varargPack = *reduced; - } + // if the type is error suppressing, we don't actually have any work left to do. + if (norm->shouldSuppressErrors()) + return; - auto reduceMap = [this](auto& map) { - for (auto& [_, tf] : map) - { - if (auto reduced = module->reduction->reduce(tf)) - tf = *reduced; - } - }; + std::vector dummy; + const auto propTypes = lookupProp(norm.get(), prop, context, location, astIndexExprType, module->errors); - reduceMap(scope->exportedTypeBindings); - reduceMap(scope->privateTypeBindings); - reduceMap(scope->privateTypePackBindings); - for (auto& [_, space] : scope->importedTypeBindings) - reduceMap(space); + if (propTypes.foundMissingProp()) + { + if (propTypes.foundOneProp()) + reportError(MissingUnionProperty{tableTy, propTypes.missingProp, prop}, location); + // For class LValues, we don't want to report an extension error, + // because classes come into being with full knowledge of their + // shape. We instead want to report the unknown property error of + // the `else` branch. + else if (context == ValueContext::LValue && !get(tableTy)) + { + const auto lvPropTypes = lookupProp(norm.get(), prop, ValueContext::RValue, location, astIndexExprType, dummy); + if (lvPropTypes.foundOneProp() && lvPropTypes.noneMissingProp()) + reportError(PropertyAccessViolation{tableTy, prop, PropertyAccessViolation::CannotWrite}, location); + else if (get(tableTy) || get(tableTy)) + reportError(NotATable{tableTy}, location); + else + reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); } + else if (context == ValueContext::RValue && !get(tableTy)) + { + const auto rvPropTypes = lookupProp(norm.get(), prop, ValueContext::LValue, location, astIndexExprType, dummy); + if (rvPropTypes.foundOneProp() && rvPropTypes.noneMissingProp()) + reportError(PropertyAccessViolation{tableTy, prop, PropertyAccessViolation::CannotRead}, location); + else + reportError(UnknownProperty{tableTy, prop}, location); + } + else + reportError(UnknownProperty{tableTy, prop}, location); + } +} - auto reduceOrError = [this](auto& map) { - for (auto [ast, t] : map) - { - if (!t) - continue; // Reminder: this implies that the recursion limit was exceeded. - else if (auto reduced = module->reduction->reduce(t)) - map[ast] = *reduced; - else - reportError(NormalizationTooComplex{}, ast->location); - } - }; - - module->astOriginalResolvedTypes = module->astResolvedTypes; +PropertyType TypeChecker2::hasIndexTypeFromType( + TypeId ty, + const std::string& prop, + ValueContext context, + const Location& location, + DenseHashSet& seen, + TypeId astIndexExprType, + std::vector& errors +) +{ + // If we have already encountered this type, we must assume that some + // other codepath will do the right thing and signal false if the + // property is not present. + if (seen.contains(ty)) + return {NormalizationResult::True, {}}; + seen.insert(ty); - // Both [`Module::returnType`] and [`Module::exportedTypeBindings`] are empty here, and - // is populated by [`Module::clonePublicInterface`] in the future, so by that point these - // two aforementioned fields will only contain types that are irreducible. - reduceOrError(module->astTypes); - reduceOrError(module->astTypePacks); - reduceOrError(module->astExpectedTypes); - reduceOrError(module->astOriginalCallTypes); - reduceOrError(module->astOverloadResolvedTypes); - reduceOrError(module->astResolvedTypes); - reduceOrError(module->astResolvedTypePacks); - } + if (get(ty) || get(ty) || get(ty)) + return {NormalizationResult::True, {ty}}; - template - bool isSubtype(TID subTy, TID superTy, NotNull scope) + if (isString(ty)) { - TypeArena arena; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; - - u.tryUnify(subTy, superTy); - const bool ok = u.errors.empty() && u.log.empty(); - return ok; + std::optional mtIndex = Luau::findMetatableEntry(builtinTypes, errors, builtinTypes->stringType, "__index", location); + LUAU_ASSERT(mtIndex); + ty = *mtIndex; } - template - ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy, CountMismatch::Context context = CountMismatch::Arg) + if (auto tt = getTableType(ty)) { - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; - u.ctx = context; - u.useScopes = true; - u.tryUnify(subTy, superTy); + if (auto resTy = findTablePropertyRespectingMeta(builtinTypes, errors, ty, prop, context, location)) + return {NormalizationResult::True, resTy}; - return std::move(u.errors); - } + if (tt->indexer) + { + TypeId indexType = follow(tt->indexer->indexType); + TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); + if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice)) + return {NormalizationResult::True, {tt->indexer->indexResultType}}; + } - void reportError(TypeErrorData data, const Location& location) - { - module->errors.emplace_back(location, sourceModule->name, std::move(data)); - if (logger) - logger->captureTypeCheckError(module->errors.back()); + // if we are in a conditional context, we treat the property as present and `unknown` because + // we may be _refining_ `tableTy` to include that property. we will want to revisit this a bit + // in the future once luau has support for exact tables since this only applies when inexact. + return {inConditional(typeContext) ? NormalizationResult::True : NormalizationResult::False, {builtinTypes->unknownType}}; } - - void reportError(TypeError e) + else if (const ClassType* cls = get(ty)) { - reportError(std::move(e.data), e.location); + // If the property doesn't exist on the class, we consult the indexer + // We need to check if the type of the index expression foo (x[foo]) + // is compatible with the indexer's indexType + // Construct the intersection and test inhabitedness! + if (auto property = lookupClassProp(cls, prop)) + return {NormalizationResult::True, context == ValueContext::LValue ? property->writeTy : property->readTy}; + if (cls->indexer) + { + TypeId inhabitatedTestType = module->internalTypes.addType(IntersectionType{{cls->indexer->indexType, astIndexExprType}}); + return {normalizer.isInhabited(inhabitatedTestType), {cls->indexer->indexResultType}}; + } + return {NormalizationResult::False, {}}; } - - void reportErrors(ErrorVec errors) + else if (const UnionType* utv = get(ty)) { - for (TypeError e : errors) - reportError(std::move(e)); - } + std::vector parts; + parts.reserve(utv->options.size()); - void checkIndexTypeFromType(TypeId tableTy, const NormalizedType& norm, const std::string& prop, const Location& location, ValueContext context) - { - bool foundOneProp = false; - std::vector typesMissingTheProp; + for (TypeId part : utv) + { + PropertyType result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); - auto fetch = [&](TypeId ty) { - if (!normalizer.isInhabited(ty)) - return; + if (result.present != NormalizationResult::True) + return {result.present, {}}; + if (result.result) + parts.emplace_back(*result.result); + } - bool found = hasIndexTypeFromType(ty, prop, location); - foundOneProp |= found; - if (!found) - typesMissingTheProp.push_back(ty); - }; + if (parts.size() == 0) + return {NormalizationResult::False, {}}; - fetch(norm.tops); - fetch(norm.booleans); + if (parts.size() == 1) + return {NormalizationResult::True, {parts[0]}}; - if (FFlag::LuauNegatedClassTypes) - { - for (const auto& [ty, _negations] : norm.classes.classes) - { - fetch(ty); - } - } + TypeId propTy; + if (context == ValueContext::LValue) + propTy = module->internalTypes.addType(IntersectionType{parts}); else - { - for (TypeId ty : norm.DEPRECATED_classes) - fetch(ty); - } - fetch(norm.errors); - fetch(norm.nils); - fetch(norm.numbers); - if (!norm.strings.isNever()) - fetch(builtinTypes->stringType); - fetch(norm.threads); - for (TypeId ty : norm.tables) - fetch(ty); - if (norm.functions.isTop) - fetch(builtinTypes->functionType); - else if (!norm.functions.isNever()) - { - if (norm.functions.parts->size() == 1) - fetch(norm.functions.parts->front()); - else - { - std::vector parts; - parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); - fetch(testArena.addType(IntersectionType{std::move(parts)})); - } - } - for (const auto& [tyvar, intersect] : norm.tyvars) - { - if (get(intersect->tops)) - { - TypeId ty = normalizer.typeFromNormal(*intersect); - fetch(testArena.addType(IntersectionType{{tyvar, ty}})); - } - else - fetch(tyvar); - } + propTy = module->internalTypes.addType(UnionType{parts}); - if (!typesMissingTheProp.empty()) + return {NormalizationResult::True, propTy}; + } + else if (const IntersectionType* itv = get(ty)) + { + for (TypeId part : itv) { - if (foundOneProp) - reportError(MissingUnionProperty{tableTy, typesMissingTheProp, prop}, location); - else if (context == LValue) - reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); - else - reportError(UnknownProperty{tableTy, prop}, location); + PropertyType result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); + if (result.present != NormalizationResult::False) + return result; } + + return {NormalizationResult::False, {}}; } + else if (const PrimitiveType* pt = get(ty)) + return {(inConditional(typeContext) && pt->type == PrimitiveType::Table) ? NormalizationResult::True : NormalizationResult::False, {ty}}; + else + return {NormalizationResult::False, {}}; +} - bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location) - { - if (get(ty) || get(ty) || get(ty)) - return true; - if (isString(ty)) + +void TypeChecker2::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const +{ + std::string_view sv(utk->key); + std::set candidates; + + auto accumulate = [&](const TableType::Props& props) + { + for (const auto& [name, ty] : props) { - std::optional mtIndex = Luau::findMetatableEntry(builtinTypes, module->errors, builtinTypes->stringType, "__index", location); - LUAU_ASSERT(mtIndex); - ty = *mtIndex; + if (sv != name && equalsLower(sv, name)) + candidates.insert(name); } + }; - if (auto tt = getTableType(ty)) + if (auto ttv = getTableType(utk->table)) + accumulate(ttv->props); + else if (auto ctv = get(follow(utk->table))) + { + while (ctv) { - if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) - return true; + accumulate(ctv->props); - else if (tt->indexer && isPrim(tt->indexer->indexType, PrimitiveType::String)) - return true; + if (!ctv->parent) + break; - else - return false; - } - else if (const ClassType* cls = get(ty)) - return bool(lookupClassProp(cls, prop)); - else if (const UnionType* utv = get(ty)) - ice.ice("getIndexTypeFromTypeHelper cannot take a UnionType"); - else if (const IntersectionType* itv = get(ty)) - return std::any_of(begin(itv), end(itv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location); - }); - else - return false; + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } } -}; -void check(NotNull builtinTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module) + if (!candidates.empty()) + data = TypeErrorData(UnknownPropButFoundLikeProp{utk->table, utk->key, candidates}); +} + +bool TypeChecker2::isErrorSuppressing(Location loc, TypeId ty) { - TypeChecker2 typeChecker{builtinTypes, logger, &sourceModule, module}; - typeChecker.reduceTypes(); - typeChecker.visit(sourceModule.root); + switch (shouldSuppressErrors(NotNull{&normalizer}, ty)) + { + case ErrorSuppression::DoNotSuppress: + return false; + case ErrorSuppression::Suppress: + return true; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, loc); + return false; + }; - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); + LUAU_ASSERT(false); + return false; // UNREACHABLE +} + +bool TypeChecker2::isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2) +{ + return isErrorSuppressing(loc1, ty1) || isErrorSuppressing(loc2, ty2); +} + +bool TypeChecker2::isErrorSuppressing(Location loc, TypePackId tp) +{ + switch (shouldSuppressErrors(NotNull{&normalizer}, tp)) + { + case ErrorSuppression::DoNotSuppress: + return false; + case ErrorSuppression::Suppress: + return true; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, loc); + return false; + }; + + LUAU_ASSERT(false); + return false; // UNREACHABLE } +bool TypeChecker2::isErrorSuppressing(Location loc1, TypePackId tp1, Location loc2, TypePackId tp2) +{ + return isErrorSuppressing(loc1, tp1) || isErrorSuppressing(loc2, tp2); +} + + } // namespace Luau diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp new file mode 100644 index 000000000..d5eac1f20 --- /dev/null +++ b/Analysis/src/TypeFunction.cpp @@ -0,0 +1,2680 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunction.h" + +#include "Luau/BytecodeBuilder.h" +#include "Luau/Common.h" +#include "Luau/Compiler.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/DenseHash.h" +#include "Luau/Instantiation.h" +#include "Luau/Normalize.h" +#include "Luau/NotNull.h" +#include "Luau/OverloadResolution.h" +#include "Luau/Set.h" +#include "Luau/Simplify.h" +#include "Luau/Subtyping.h" +#include "Luau/TimeTrace.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeFunctionReductionGuesser.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypeFunctionRuntimeBuilder.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" +#include "Luau/VecDeque.h" +#include "Luau/VisitType.h" + +#include "lua.h" +#include "lualib.h" + +#include +#include +#include + +// used to control emitting CodeTooComplex warnings on type function reduction +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); + +// used to control the limits of type function application over union type arguments +// e.g. `mul` blows up into `mul | mul | mul | mul` +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'000); + +// used to control falling back to a more conservative reduction based on guessing +// when this value is set to a negative value, guessing will be totally disabled. +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); + +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions2, false) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) +LUAU_FASTFLAG(LuauUserTypeFunFixRegister) +LUAU_FASTFLAG(LuauRemoveNotAnyHack) + +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + +namespace Luau +{ + +using TypeOrTypePackIdSet = DenseHashSet; + +struct InstanceCollector : TypeOnceVisitor +{ + VecDeque tys; + VecDeque tps; + TypeOrTypePackIdSet shouldGuess{nullptr}; + std::vector cyclicInstance; + + bool visit(TypeId ty, const TypeFunctionInstanceType&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + + if (DFInt::LuauTypeFamilyUseGuesserDepth >= 0 && typeFunctionDepth > DFInt::LuauTypeFamilyUseGuesserDepth) + shouldGuess.insert(ty); + + tys.push_front(ty); + + return true; + } + + void cycle(TypeId ty) override + { + /// Detected cyclic type pack + TypeId t = follow(ty); + if (get(t)) + cyclicInstance.push_back(t); + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + + if (DFInt::LuauTypeFamilyUseGuesserDepth >= 0 && typeFunctionDepth > DFInt::LuauTypeFamilyUseGuesserDepth) + shouldGuess.insert(tp); + + tps.push_front(tp); + + return true; + } +}; + +struct TypeFunctionReducer +{ + TypeFunctionContext ctx; + + VecDeque queuedTys; + VecDeque queuedTps; + TypeOrTypePackIdSet shouldGuess; + std::vector cyclicTypeFunctions; + TypeOrTypePackIdSet irreducible{nullptr}; + FunctionGraphReductionResult result; + bool force = false; + + // Local to the constraint being reduced. + Location location; + + TypeFunctionReducer( + VecDeque queuedTys, + VecDeque queuedTps, + TypeOrTypePackIdSet shouldGuess, + std::vector cyclicTypes, + Location location, + TypeFunctionContext ctx, + bool force = false + ) + : ctx(ctx) + , queuedTys(std::move(queuedTys)) + , queuedTps(std::move(queuedTps)) + , shouldGuess(std::move(shouldGuess)) + , cyclicTypeFunctions(std::move(cyclicTypes)) + , force(force) + , location(location) + { + } + + enum class SkipTestResult + { + CyclicTypeFunction, + Irreducible, + Defer, + Okay, + }; + + SkipTestResult testForSkippability(TypeId ty) + { + ty = follow(ty); + + if (is(ty)) + { + for (auto t : cyclicTypeFunctions) + { + if (ty == t) + return SkipTestResult::CyclicTypeFunction; + } + + if (!irreducible.contains(ty)) + return SkipTestResult::Defer; + + return SkipTestResult::Irreducible; + } + else if (is(ty)) + { + return SkipTestResult::Irreducible; + } + + return SkipTestResult::Okay; + } + + SkipTestResult testForSkippability(TypePackId ty) const + { + ty = follow(ty); + + if (is(ty)) + { + if (!irreducible.contains(ty)) + return SkipTestResult::Defer; + else + return SkipTestResult::Irreducible; + } + else if (is(ty)) + { + return SkipTestResult::Irreducible; + } + + return SkipTestResult::Okay; + } + + template + void replace(T subject, T replacement) + { + if (subject->owningArena != ctx.arena.get()) + { + result.errors.emplace_back(location, InternalError{"Attempting to modify a type function instance from another arena"}); + return; + } + + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s -> %s\n", toString(subject, {true}).c_str(), toString(replacement, {true}).c_str()); + + asMutable(subject)->ty.template emplace>(replacement); + + if constexpr (std::is_same_v) + result.reducedTypes.insert(subject); + else if constexpr (std::is_same_v) + result.reducedPacks.insert(subject); + } + + template + void handleTypeFunctionReduction(T subject, TypeFunctionReductionResult reduction) + { + if (reduction.result) + replace(subject, *reduction.result); + else + { + irreducible.insert(subject); + + if (reduction.error.has_value()) + result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error}); + + if (reduction.uninhabited || force) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s is uninhabited\n", toString(subject, {true}).c_str()); + + if constexpr (std::is_same_v) + result.errors.emplace_back(location, UninhabitedTypeFunction{subject}); + else if constexpr (std::is_same_v) + result.errors.emplace_back(location, UninhabitedTypePackFunction{subject}); + } + else if (!reduction.uninhabited && !force) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf( + "%s is irreducible; blocked on %zu types, %zu packs\n", + toString(subject, {true}).c_str(), + reduction.blockedTypes.size(), + reduction.blockedPacks.size() + ); + + for (TypeId b : reduction.blockedTypes) + result.blockedTypes.insert(b); + + for (TypePackId b : reduction.blockedPacks) + result.blockedPacks.insert(b); + } + } + } + + bool done() const + { + return queuedTys.empty() && queuedTps.empty(); + } + + template + bool testParameters(T subject, const I* tfit) + { + for (TypeId p : tfit->typeArguments) + { + SkipTestResult skip = testForSkippability(p); + + if (skip == SkipTestResult::Irreducible) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + irreducible.insert(subject); + return false; + } + else if (skip == SkipTestResult::Defer) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + if constexpr (std::is_same_v) + queuedTys.push_back(subject); + else if constexpr (std::is_same_v) + queuedTps.push_back(subject); + + return false; + } + } + + for (TypePackId p : tfit->packArguments) + { + SkipTestResult skip = testForSkippability(p); + + if (skip == SkipTestResult::Irreducible) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + irreducible.insert(subject); + return false; + } + else if (skip == SkipTestResult::Defer) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + if constexpr (std::is_same_v) + queuedTys.push_back(subject); + else if constexpr (std::is_same_v) + queuedTps.push_back(subject); + + return false; + } + } + + return true; + } + + template + inline bool tryGuessing(TID subject) + { + if (shouldGuess.contains(subject)) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Flagged %s for reduction with guesser.\n", toString(subject, {true}).c_str()); + + TypeFunctionReductionGuesser guesser{ctx.arena, ctx.builtins, ctx.normalizer}; + auto guessed = guesser.guess(subject); + + if (guessed) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Selected %s as the guessed result type.\n", toString(*guessed, {true}).c_str()); + + replace(subject, *guessed); + return true; + } + + if (FFlag::DebugLuauLogTypeFamilies) + printf("Failed to produce a guess for the result of %s.\n", toString(subject, {true}).c_str()); + } + + return false; + } + + + void stepType() + { + TypeId subject = follow(queuedTys.front()); + queuedTys.pop_front(); + + if (irreducible.contains(subject)) + return; + + if (FFlag::DebugLuauLogTypeFamilies) + printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); + + if (const TypeFunctionInstanceType* tfit = get(subject)) + { + SkipTestResult testCyclic = testForSkippability(subject); + + if (!testParameters(subject, tfit) && testCyclic != SkipTestResult::CyclicTypeFunction) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Irreducible due to irreducible/pending and a non-cyclic function\n"); + + return; + } + + if (tryGuessing(subject)) + return; + + ctx.userFuncName = tfit->userFuncName; + + TypeFunctionReductionResult result = tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + handleTypeFunctionReduction(subject, result); + } + } + + void stepPack() + { + TypePackId subject = follow(queuedTps.front()); + queuedTps.pop_front(); + + if (irreducible.contains(subject)) + return; + + if (FFlag::DebugLuauLogTypeFamilies) + printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); + + if (const TypeFunctionInstanceTypePack* tfit = get(subject)) + { + if (!testParameters(subject, tfit)) + return; + + if (tryGuessing(subject)) + return; + + TypeFunctionReductionResult result = + tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + handleTypeFunctionReduction(subject, result); + } + } + + void step() + { + if (!queuedTys.empty()) + stepType(); + else if (!queuedTps.empty()) + stepPack(); + } +}; + +struct LuauTempThreadPopper +{ + explicit LuauTempThreadPopper(lua_State* L) + : L(L) + { + } + ~LuauTempThreadPopper() + { + lua_pop(L, 1); + } + + lua_State* L = nullptr; +}; + +static FunctionGraphReductionResult reduceFunctionsInternal( + VecDeque queuedTys, + VecDeque queuedTps, + TypeOrTypePackIdSet shouldGuess, + std::vector cyclics, + Location location, + TypeFunctionContext ctx, + bool force +) +{ + TypeFunctionReducer reducer{std::move(queuedTys), std::move(queuedTps), std::move(shouldGuess), std::move(cyclics), location, ctx, force}; + int iterationCount = 0; + + while (!reducer.done()) + { + reducer.step(); + + ++iterationCount; + if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + { + reducer.result.errors.emplace_back(location, CodeTooComplex{}); + break; + } + } + + return std::move(reducer.result); +} + +FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location location, TypeFunctionContext ctx, bool force) +{ + InstanceCollector collector; + + try + { + collector.traverse(entrypoint); + } + catch (RecursionLimitException&) + { + return FunctionGraphReductionResult{}; + } + + if (collector.tys.empty() && collector.tps.empty()) + return {}; + + return reduceFunctionsInternal( + std::move(collector.tys), + std::move(collector.tps), + std::move(collector.shouldGuess), + std::move(collector.cyclicInstance), + location, + ctx, + force + ); +} + +FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location location, TypeFunctionContext ctx, bool force) +{ + InstanceCollector collector; + + try + { + collector.traverse(entrypoint); + } + catch (RecursionLimitException&) + { + return FunctionGraphReductionResult{}; + } + + if (collector.tys.empty() && collector.tps.empty()) + return {}; + + return reduceFunctionsInternal( + std::move(collector.tys), + std::move(collector.tps), + std::move(collector.shouldGuess), + std::move(collector.cyclicInstance), + location, + ctx, + force + ); +} + +bool isPending(TypeId ty, ConstraintSolver* solver) +{ + return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); +} + +template +static std::optional> tryDistributeTypeFunctionApp( + F f, + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + Args&&... args +) +{ + // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) + bool uninhabited = false; + std::vector blockedTypes; + std::vector results; + size_t cartesianProductSize = 1; + + const UnionType* firstUnion = nullptr; + size_t unionIndex = 0; + + std::vector arguments = typeParams; + for (size_t i = 0; i < arguments.size(); ++i) + { + const UnionType* ut = get(follow(arguments[i])); + if (!ut) + continue; + + // We want to find the first union type in the set of arguments to distribute that one and only that one union. + // The function `f` we have is recursive, so `arguments[unionIndex]` will be updated in-place for each option in + // the union we've found in this context, so that index will no longer be a union type. Any other arguments at + // index + 1 or after will instead be distributed, if those are a union, which will be subjected to the same rules. + if (!firstUnion && ut) + { + firstUnion = ut; + unionIndex = i; + } + + cartesianProductSize *= std::distance(begin(ut), end(ut)); + + // TODO: We'd like to report that the type function application is too complex here. + if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= cartesianProductSize) + return {{std::nullopt, true, {}, {}}}; + } + + if (!firstUnion) + { + // If we couldn't find any union type argument, we're not distributing. + return std::nullopt; + } + + for (TypeId option : firstUnion) + { + arguments[unionIndex] = option; + + TypeFunctionReductionResult result = f(instance, arguments, packParams, ctx, args...); + blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); + uninhabited |= result.uninhabited; + + if (result.uninhabited || !result.result) + break; + else + results.push_back(*result.result); + } + + if (uninhabited || !blockedTypes.empty()) + return {{std::nullopt, uninhabited, blockedTypes, {}}}; + + if (!results.empty()) + { + if (results.size() == 1) + return {{results[0], false, {}, {}}}; + + TypeId resultTy = ctx->arena->addType(TypeFunctionInstanceType{ + NotNull{&builtinTypeFunctions().unionFunc}, + std::move(results), + {}, + }); + + return {{resultTy, false, {}, {}}}; + } + + return std::nullopt; +} + +TypeFunctionReductionResult userDefinedTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (!ctx->userFuncName) + { + ctx->ice->ice("all user-defined type functions must have an associated function definition"); + return {std::nullopt, true, {}, {}}; + } + + if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) + { + // If type functions cannot be evaluated because of errors in the code, we do not generate any additional ones + if (!ctx->typeFunctionRuntime->allowEvaluation) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + } + + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); + + // block if we need to + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + } + + AstName name = *ctx->userFuncName; + + lua_State* global = ctx->typeFunctionRuntime->state.get(); + + if (global == nullptr) + return {std::nullopt, true, {}, {}, format("'%s' type function: cannot be evaluated in this context", name.value)}; + + // Separate sandboxed thread for individual execution and private globals + lua_State* L = lua_newthread(global); + LuauTempThreadPopper popper(global); + + lua_getglobal(global, name.value); + lua_xmove(global, L, 1); + + // Push serialized arguments onto the stack + + // Since there aren't any new class types being created in type functions, there isn't a deserialization function + // class types. Instead, we can keep this map and return the mapping as the "deserialized value" + std::unique_ptr runtimeBuilder = std::make_unique(ctx); + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); + // This is checked at the top of the function, and should still be true. + LUAU_ASSERT(!isPending(ty, ctx->solver)); + + TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get()); + // Check if there were any errors while serializing + if (runtimeBuilder->errors.size() != 0) + return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + + allocTypeUserData(L, serializedTy->type); + } + + // Set up an interrupt handler for type functions to respect type checking limits and LSP cancellation requests. + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) + { + auto ctx = static_cast(lua_getthreaddata(lua_mainthread(L))); + if (ctx->limits->finishTime && TimeTrace::getClock() > *ctx->limits->finishTime) + throw TimeLimitError(ctx->ice->moduleName); + + if (ctx->limits->cancellationToken && ctx->limits->cancellationToken->requested()) + throw UserCancelError(ctx->ice->moduleName); + }; + + if (auto error = checkResultForError(L, name.value, lua_pcall(L, int(typeParams.size()), 1, 0))) + return {std::nullopt, true, {}, {}, error}; + + // If the return value is not a type userdata, return with error message + if (!isTypeUserData(L, 1)) + return {std::nullopt, true, {}, {}, format("'%s' type function: returned a non-type value", name.value)}; + + TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1); + + // No errors should be present here since we should've returned already if any were raised during serialization. + LUAU_ASSERT(runtimeBuilder->errors.size() == 0); + + TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); + + // At least 1 error occured while deserializing + if (runtimeBuilder->errors.size() > 0) + return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + + return {retTypeId, false, {}, {}}; +} + +TypeFunctionReductionResult notTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("not type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId ty = follow(typeParams.at(0)); + + if (ty == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // `not` operates on anything and returns a `boolean` always. + return {ctx->builtins->booleanType, false, {}, {}}; +} + +TypeFunctionReductionResult lenTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("len type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId operandTy = follow(typeParams.at(0)); + + if (operandTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // check to see if the operand type is resolved enough, and wait to reduce if not + // the use of `typeFromNormal` later necessitates blocking on local types. + if (isPending(operandTy, ctx->solver)) + return {std::nullopt, false, {operandTy}, {}}; + + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true); + if (!maybeGeneralized) + return {std::nullopt, false, {operandTy}, {}}; + operandTy = *maybeGeneralized; + } + + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); + NormalizationResult inhabited = ctx->normalizer->isInhabited(normTy.get()); + + // if the type failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normTy || inhabited == NormalizationResult::HitLimits) + return {std::nullopt, false, {}, {}}; + + // if the operand type is error suppressing, we can immediately reduce to `number`. + if (normTy->shouldSuppressErrors()) + return {ctx->builtins->numberType, false, {}, {}}; + + // # always returns a number, even if its operand is never. + // if we're checking the length of a string, that works! + if (inhabited == NormalizationResult::False || normTy->isSubtypeOfString()) + return {ctx->builtins->numberType, false, {}, {}}; + + // we use the normalized operand here in case there was an intersection or union. + TypeId normalizedOperand = + DFInt::LuauTypeSolverRelease >= 646 ? follow(ctx->normalizer->typeFromNormal(*normTy)) : ctx->normalizer->typeFromNormal(*normTy); + if (normTy->hasTopTable() || get(normalizedOperand)) + return {ctx->builtins->numberType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__len", Location{}); + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + // `len` must return a `number`. + return {ctx->builtins->numberType, false, {}, {}}; +} + +TypeFunctionReductionResult unmTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("unm type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId operandTy = follow(typeParams.at(0)); + + if (operandTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // check to see if the operand type is resolved enough, and wait to reduce if not + if (isPending(operandTy, ctx->solver)) + return {std::nullopt, false, {operandTy}, {}}; + + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); + if (!maybeGeneralized) + return {std::nullopt, false, {operandTy}, {}}; + operandTy = *maybeGeneralized; + } + + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); + + // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normTy) + return {std::nullopt, false, {}, {}}; + + // if the operand is error suppressing, we can just go ahead and reduce. + if (normTy->shouldSuppressErrors()) + return {operandTy, false, {}, {}}; + + // if we have a `never`, we can never observe that the operation didn't work. + if (is(operandTy)) + return {ctx->builtins->neverType, false, {}, {}}; + + // If the type is exactly `number`, we can reduce now. + if (normTy->isExactlyNumber()) + return {ctx->builtins->numberType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__unm", Location{}); + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + if (std::optional ret = first(instantiatedMmFtv->retTypes)) + return {*ret, false, {}, {}}; + else + return {std::nullopt, true, {}, {}}; +} + +void dummyStateClose(lua_State*) {} + +TypeFunctionRuntime::TypeFunctionRuntime(NotNull ice, NotNull limits) + : ice(ice) + , limits(limits) + , state(nullptr, dummyStateClose) +{ +} + +TypeFunctionRuntime::~TypeFunctionRuntime() {} + +std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunction* function) +{ + if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) + { + // If evaluation is disabled, we do not generate additional error messages + if (!allowEvaluation) + return std::nullopt; + } + + prepareState(); + + AstName name = function->name; + + // Construct ParseResult containing the type function + Allocator allocator; + AstNameTable names(allocator); + + AstExpr* exprFunction = function->body; + AstArray exprReturns{&exprFunction, 1}; + AstStatReturn stmtReturn{Location{}, exprReturns}; + AstStat* stmtArray[] = {&stmtReturn}; + AstArray stmts{stmtArray, 1}; + AstStatBlock exec{Location{}, stmts}; + ParseResult parseResult{&exec, 1}; + + BytecodeBuilder builder; + try + { + compileOrThrow(builder, parseResult, names); + } + catch (CompileError& e) + { + return format("'%s' type function failed to compile with error message: %s", name.value, e.what()); + } + + std::string bytecode = builder.getBytecode(); + + lua_State* global = state.get(); + + // Separate sandboxed thread for individual execution and private globals + lua_State* L = lua_newthread(global); + LuauTempThreadPopper popper(global); + + // Create individual environment for the type function + luaL_sandboxthread(L); + + // Do not allow global writes to that environment + lua_pushvalue(L, LUA_GLOBALSINDEX); + lua_setreadonly(L, -1, true); + lua_pop(L, 1); + + // Load bytecode into Luau state + if (auto error = checkResultForError(L, name.value, luau_load(L, name.value, bytecode.data(), bytecode.size(), 0))) + return error; + + // Execute the global function which should return our user-defined type function + if (auto error = checkResultForError(L, name.value, lua_resume(L, nullptr, 0))) + return error; + + if (!lua_isfunction(L, -1)) + { + lua_pop(L, 1); + return format("Could not find '%s' type function in the global scope", name.value); + } + + // Store resulting function in the global environment + lua_xmove(L, global, 1); + lua_setglobal(global, name.value); + + return std::nullopt; +} + +void TypeFunctionRuntime::prepareState() +{ + if (state) + return; + + state = StateRef(lua_newstate(typeFunctionAlloc, nullptr), lua_close); + lua_State* L = state.get(); + + lua_setthreaddata(L, this); + + setTypeFunctionEnvironment(L); + + registerTypeUserData(L); + + if (FFlag::LuauUserTypeFunFixRegister) + registerTypesLibrary(L); + + luaL_sandbox(L); + luaL_sandboxthread(L); +} + +TypeFunctionContext::TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) + : arena(cs->arena) + , builtins(cs->builtinTypes) + , scope(scope) + , normalizer(cs->normalizer) + , typeFunctionRuntime(cs->typeFunctionRuntime) + , ice(NotNull{&cs->iceReporter}) + , limits(NotNull{&cs->limits}) + , solver(cs.get()) + , constraint(constraint.get()) +{ +} + +NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) const +{ + LUAU_ASSERT(solver); + NotNull newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); + + // Every constraint that is blocked on the current constraint must also be + // blocked on this new one. + if (constraint) + solver->inheritBlocks(NotNull{constraint}, newConstraint); + + return newConstraint; +} + +TypeFunctionReductionResult numericBinopTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + const std::string metamethod +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // if we have a `never`, we can never observe that the math operator is unreachable. + if (is(lhsTy) || is(rhsTy)) + return {ctx->builtins->neverType, false, {}, {}}; + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // TODO: Normalization needs to remove cyclic type functions from a `NormalizedType`. + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->anyType, false, {}, {}}; + + // if we're adding two `number` types, the result is `number`. + if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) + return {ctx->builtins->numberType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(numericBinopTypeFunction, instance, typeParams, packParams, ctx, metamethod)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, location); + bool reversed = false; + if (!mmType) + { + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, location); + reversed = true; + } + + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + TypePackId argPack = ctx->arena->addTypePack({lhsTy, rhsTy}); + SolveResult solveResult; + + if (!reversed) + solveResult = solveFunctionCall( + ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ); + else + { + TypePack* p = getMutable(argPack); + std::swap(p->head.front(), p->head.back()); + solveResult = solveFunctionCall( + ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ); + } + + if (!solveResult.typePackId.has_value()) + return {std::nullopt, true, {}, {}}; + + TypePack extracted = extendTypePack(*ctx->arena, ctx->builtins, *solveResult.typePackId, 1); + if (extracted.head.empty()) + return {std::nullopt, true, {}, {}}; + + return {extracted.head.front(), false, {}, {}}; +} + +TypeFunctionReductionResult addTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("add type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__add"); +} + +TypeFunctionReductionResult subTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("sub type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__sub"); +} + +TypeFunctionReductionResult mulTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("mul type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__mul"); +} + +TypeFunctionReductionResult divTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("div type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__div"); +} + +TypeFunctionReductionResult idivTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("integer div type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__idiv"); +} + +TypeFunctionReductionResult powTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("pow type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__pow"); +} + +TypeFunctionReductionResult modTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("modulo type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__mod"); +} + +TypeFunctionReductionResult concatTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("concat type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->anyType, false, {}, {}}; + + // if we have a `never`, we can never observe that the numeric operator didn't work. + if (is(lhsTy) || is(rhsTy)) + return {ctx->builtins->neverType, false, {}, {}}; + + // if we're concatenating two elements that are either strings or numbers, the result is `string`. + if ((normLhsTy->isSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber())) + return {ctx->builtins->stringType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(concatTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, "__concat", Location{}); + bool reversed = false; + if (!mmType) + { + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, "__concat", Location{}); + reversed = true; + } + + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + std::vector inferredArgs; + if (!reversed) + inferredArgs = {lhsTy, rhsTy}; + else + inferredArgs = {rhsTy, lhsTy}; + + TypePackId inferredArgPack = ctx->arena->addTypePack(std::move(inferredArgs)); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + return {ctx->builtins->stringType, false, {}, {}}; +} + +TypeFunctionReductionResult andTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("and type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // t1 = and ~> lhs + if (follow(rhsTy) == instance && lhsTy != rhsTy) + return {lhsTy, false, {}, {}}; + // t1 = and ~> rhs + if (follow(lhsTy) == instance && lhsTy != rhsTy) + return {rhsTy, false, {}, {}}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is truthy. + SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->falsyType); + SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result); + std::vector blockedTypes{}; + for (auto ty : filteredLhs.blockedTypes) + blockedTypes.push_back(ty); + for (auto ty : overallResult.blockedTypes) + blockedTypes.push_back(ty); + return {overallResult.result, false, std::move(blockedTypes), {}}; +} + +TypeFunctionReductionResult orTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("or type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // t1 = or ~> lhs + if (follow(rhsTy) == instance && lhsTy != rhsTy) + return {lhsTy, false, {}, {}}; + // t1 = or ~> rhs + if (follow(lhsTy) == instance && lhsTy != rhsTy) + return {rhsTy, false, {}, {}}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // Or evalutes to the LHS type if the LHS is truthy, and the RHS type if LHS is falsy. + SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->truthyType); + SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result); + std::vector blockedTypes{}; + for (auto ty : filteredLhs.blockedTypes) + blockedTypes.push_back(ty); + for (auto ty : overallResult.blockedTypes) + blockedTypes.push_back(ty); + return {overallResult.result, false, std::move(blockedTypes), {}}; +} + +static TypeFunctionReductionResult comparisonTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + const std::string metamethod +) +{ + + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // Algebra Reduction Rules for comparison type functions + // Note that comparing to never tells you nothing about the other operand + // lt< 'a , never> -> continue + // lt< never, 'a> -> continue + // lt< 'a, t> -> 'a is t - we'll solve the constraint, return and solve lt -> bool + // lt< t, 'a> -> same as above + bool canSubmitConstraint = ctx->solver && ctx->constraint; + bool lhsFree = get(lhsTy) != nullptr; + bool rhsFree = get(rhsTy) != nullptr; + if (canSubmitConstraint) + { + // Implement injective type functions for comparison type functions + // lt implies t is number + // lt implies t is number + if (lhsFree && isNumber(rhsTy)) + emplaceType(asMutable(lhsTy), ctx->builtins->numberType); + else if (rhsFree && isNumber(lhsTy)) + emplaceType(asMutable(rhsTy), ctx->builtins->numberType); + else if (lhsFree && ctx->normalizer->isInhabited(rhsTy) != NormalizationResult::False) + { + auto c1 = ctx->pushConstraint(EqualityConstraint{lhsTy, rhsTy}); + const_cast(ctx->constraint)->dependencies.emplace_back(c1); + } + else if (rhsFree && ctx->normalizer->isInhabited(lhsTy) != NormalizationResult::False) + { + auto c1 = ctx->pushConstraint(EqualityConstraint{rhsTy, lhsTy}); + const_cast(ctx->constraint)->dependencies.emplace_back(c1); + } + } + + // The above might have caused the operand types to be rebound, we need to follow them again + lhsTy = follow(lhsTy); + rhsTy = follow(rhsTy); + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // check to see if both operand types are resolved enough, and wait to reduce if not + + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + NormalizationResult lhsInhabited = ctx->normalizer->isInhabited(normLhsTy.get()); + NormalizationResult rhsInhabited = ctx->normalizer->isInhabited(normRhsTy.get()); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can just go ahead and reduce. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->booleanType, false, {}, {}}; + + // if we have an uninhabited type (e.g. `never`), we can never observe that the comparison didn't work. + if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) + return {ctx->builtins->booleanType, false, {}, {}}; + + // If both types are some strict subset of `string`, we can reduce now. + if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) + return {ctx->builtins->booleanType, false, {}, {}}; + + // If both types are exactly `number`, we can reduce now. + if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) + return {ctx->builtins->booleanType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(comparisonTypeFunction, instance, typeParams, packParams, ctx, metamethod)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, Location{}); + if (!mmType) + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, Location{}); + + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + return {ctx->builtins->booleanType, false, {}, {}}; +} + +TypeFunctionReductionResult ltTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("lt type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return comparisonTypeFunction(instance, typeParams, packParams, ctx, "__lt"); +} + +TypeFunctionReductionResult leTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("le type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return comparisonTypeFunction(instance, typeParams, packParams, ctx, "__le"); +} + +TypeFunctionReductionResult eqTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("eq type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + NormalizationResult lhsInhabited = ctx->normalizer->isInhabited(normLhsTy.get()); + NormalizationResult rhsInhabited = ctx->normalizer->isInhabited(normRhsTy.get()); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can just go ahead and reduce. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->booleanType, false, {}, {}}; + + // if we have a `never`, we can never observe that the comparison didn't work. + if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) + return {ctx->builtins->booleanType, false, {}, {}}; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, "__eq", Location{}); + if (!mmType) + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, "__eq", Location{}); + + // if neither type has a metatable entry for `__eq`, then we'll check for inhabitance of the intersection! + NormalizationResult intersectInhabited = ctx->normalizer->isIntersectionInhabited(lhsTy, rhsTy); + if (!mmType) + { + if (intersectInhabited == NormalizationResult::True) + return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! + + // we might be in a case where we still want to accept the comparison... + if (intersectInhabited == NormalizationResult::False) + { + // if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) + return {ctx->builtins->falseType, false, {}, {}}; + + // if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans()) + return {ctx->builtins->falseType, false, {}, {}}; + } + + return {std::nullopt, true, {}, {}}; // if it's not, then this type function is irreducible! + } + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + return {ctx->builtins->booleanType, false, {}, {}}; +} + +// Collect types that prevent us from reducing a particular refinement. +struct FindRefinementBlockers : TypeOnceVisitor +{ + DenseHashSet found{nullptr}; + bool visit(TypeId ty, const BlockedType&) override + { + found.insert(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + found.insert(ty); + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } +}; + +TypeFunctionReductionResult refineTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() < 2 || !packParams.empty()) + { + ctx->ice->ice("refine type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId targetTy = follow(typeParams.at(0)); + std::vector discriminantTypes; + for (size_t i = 1; i < typeParams.size(); i++) + discriminantTypes.push_back(follow(typeParams.at(i))); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(targetTy, ctx->solver)) + return {std::nullopt, false, {targetTy}, {}}; + else + { + for (auto t : discriminantTypes) + { + if (isPending(t, ctx->solver)) + return {std::nullopt, false, {t}, {}}; + } + } + // Refine a target type and a discriminant one at a time. + // Returns result : TypeId, toBlockOn : vector + auto stepRefine = [&ctx](TypeId target, TypeId discriminant) -> std::pair> + { + std::vector toBlock; + if (ctx->solver) + { + std::optional targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, target); + std::optional discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminant); + + if (!targetMaybeGeneralized) + return std::pair>{nullptr, {target}}; + else if (!discriminantMaybeGeneralized) + return std::pair>{nullptr, {discriminant}}; + + target = *targetMaybeGeneralized; + discriminant = *discriminantMaybeGeneralized; + } + + // we need a more complex check for blocking on the discriminant in particular + FindRefinementBlockers frb; + frb.traverse(discriminant); + + if (!frb.found.empty()) + return {nullptr, {frb.found.begin(), frb.found.end()}}; + + /* HACK: Refinements sometimes produce a type T & ~any under the assumption + * that ~any is the same as any. This is so so weird, but refinements needs + * some way to say "I may refine this, but I'm not sure." + * + * It does this by refining on a blocked type and deferring the decision + * until it is unblocked. + * + * Refinements also get negated, so we wind up with types like T & ~*blocked* + * + * We need to treat T & ~any as T in this case. + */ + if (auto nt = get(discriminant)) + { + if (FFlag::LuauRemoveNotAnyHack) + { + if (get(follow(nt->ty))) + return {target, {}}; + } + else + { + if (get(follow(nt->ty))) + return {target, {}}; + } + } + + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(target)) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, target, discriminant); + if (!result.blockedTypes.empty()) + return {nullptr, {result.blockedTypes.begin(), result.blockedTypes.end()}}; + + return {result.result, {}}; + } + + // In the general case, we'll still use normalization though. + TypeId intersection = ctx->arena->addType(IntersectionType{{target, discriminant}}); + std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); + std::shared_ptr normType = ctx->normalizer->normalize(target); + + // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normIntersection || !normType) + return {nullptr, {}}; + + TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); + // include the error type if the target type is error-suppressing and the intersection we computed is not + if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) + resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); + + return {resultTy, {}}; + }; + + // refine target with each discriminant type in sequence (reverse of insertion order) + // If we cannot proceed, block. If all discriminant types refine successfully, return + // the result + TypeId target = targetTy; + while (!discriminantTypes.empty()) + { + TypeId discriminant = discriminantTypes.back(); + auto [refined, blocked] = stepRefine(target, discriminant); + + if (blocked.empty() && refined == nullptr) + return {std::nullopt, false, {}, {}}; + + if (!blocked.empty()) + return {std::nullopt, false, blocked, {}}; + + target = refined; + discriminantTypes.pop_back(); + } + return {target, false, {}, {}}; +} + +TypeFunctionReductionResult singletonTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("singleton type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId type = follow(typeParams.at(0)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(type, ctx->solver)) + return {std::nullopt, false, {type}, {}}; + + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); + if (!maybeGeneralized) + return {std::nullopt, false, {type}, {}}; + type = *maybeGeneralized; + } + + TypeId followed = type; + // we want to follow through a negation here as well. + if (auto negation = get(followed)) + followed = follow(negation->ty); + + // if we have a singleton type or `nil`, which is its own singleton type... + if (get(followed) || isNil(followed)) + return {type, false, {}, {}}; + + // otherwise, we'll return the top type, `unknown`. + return {ctx->builtins->unknownType, false, {}, {}}; +} + +TypeFunctionReductionResult unionTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (!packParams.empty()) + { + ctx->ice->ice("union type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + // if we only have one parameter, there's nothing to do. + if (typeParams.size() == 1) + return {follow(typeParams[0]), false, {}, {}}; + + // we need to follow all of the type parameters. + std::vector types; + types.reserve(typeParams.size()); + for (auto ty : typeParams) + types.emplace_back(follow(ty)); + + // unfortunately, we need this short-circuit: if all but one type is `never`, we will return that one type. + // this also will early return if _everything_ is `never`, since we already have to check that. + std::optional lastType = std::nullopt; + for (auto ty : types) + { + // if we have a previous type and it's not `never` and the current type isn't `never`... + if (lastType && !get(lastType) && !get(ty)) + { + // we know we are not taking the short-circuited path. + lastType = std::nullopt; + break; + } + + if (get(ty)) + continue; + lastType = ty; + } + + // if we still have a `lastType` at the end, we're taking the short-circuit and reducing early. + if (lastType) + return {lastType, false, {}, {}}; + + // check to see if the operand types are resolved enough, and wait to reduce if not + for (auto ty : types) + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + + // fold over the types with `simplifyUnion` + TypeId resultTy = ctx->builtins->neverType; + for (auto ty : types) + { + SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + resultTy = result.result; + } + + return {resultTy, false, {}, {}}; +} + + +TypeFunctionReductionResult intersectTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (!packParams.empty()) + { + ctx->ice->ice("intersect type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + // if we only have one parameter, there's nothing to do. + if (typeParams.size() == 1) + return {follow(typeParams[0]), false, {}, {}}; + + // we need to follow all of the type parameters. + std::vector types; + types.reserve(typeParams.size()); + for (auto ty : typeParams) + types.emplace_back(follow(ty)); + + if (FFlag::LuauRemoveNotAnyHack) + { + // if we only have two parameters and one is `*no-refine*`, we're all done. + if (types.size() == 2 && get(types[1])) + return {types[0], false, {}, {}}; + else if (types.size() == 2 && get(types[0])) + return {types[1], false, {}, {}}; + } + + // check to see if the operand types are resolved enough, and wait to reduce if not + // if any of them are `never`, the intersection will always be `never`, so we can reduce directly. + for (auto ty : types) + { + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + else if (get(ty)) + return {ctx->builtins->neverType, false, {}, {}}; + } + + // fold over the types with `simplifyIntersection` + TypeId resultTy = ctx->builtins->unknownType; + for (auto ty : types) + { + // skip any `*no-refine*` types. + if (FFlag::LuauRemoveNotAnyHack && get(ty)) + continue; + + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + resultTy = result.result; + } + + // if the intersection simplifies to `never`, this gives us bad autocomplete. + // we'll just produce the intersection plainly instead, but this might be revisitable + // if we ever give `never` some kind of "explanation" trail. + if (get(resultTy)) + { + TypeId intersection = ctx->arena->addType(IntersectionType{typeParams}); + return {intersection, false, {}, {}}; + } + + return {resultTy, false, {}, {}}; +} + +// computes the keys of `ty` into `result` +// `isRaw` parameter indicates whether or not we should follow __index metamethods +// returns `false` if `result` should be ignored because the answer is "all strings" +bool computeKeysOf(TypeId ty, Set& result, DenseHashSet& seen, bool isRaw, NotNull ctx) +{ + // if the type is the top table type, the answer is just "all strings" + if (get(ty)) + return false; + + // if we've already seen this type, we can do nothing + if (seen.contains(ty)) + return true; + seen.insert(ty); + + // if we have a particular table type, we can insert the keys + if (auto tableTy = get(ty)) + { + if (tableTy->indexer) + { + // if we have a string indexer, the answer is, again, "all strings" + if (isString(tableTy->indexer->indexType)) + return false; + } + + for (auto [key, _] : tableTy->props) + result.insert(key); + return true; + } + + // otherwise, we have a metatable to deal with + if (auto metatableTy = get(ty)) + { + bool res = true; + + if (!isRaw) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, ty, "__index", Location{}); + if (mmType) + res = res && computeKeysOf(*mmType, result, seen, isRaw, ctx); + } + + res = res && computeKeysOf(metatableTy->table, result, seen, isRaw, ctx); + + return res; + } + + if (auto classTy = get(ty)) + { + for (auto [key, _] : classTy->props) + result.insert(key); + + bool res = true; + if (classTy->metatable && !isRaw) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, ty, "__index", Location{}); + if (mmType) + res = res && computeKeysOf(*mmType, result, seen, isRaw, ctx); + } + + if (classTy->parent) + res = res && computeKeysOf(follow(*classTy->parent), result, seen, isRaw, ctx); + + return res; + } + + // this should not be reachable since the type should be a valid tables or classes part from normalization. + LUAU_ASSERT(false); + return false; +} + +TypeFunctionReductionResult keyofFunctionImpl( + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + bool isRaw +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("keyof type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId operandTy = follow(typeParams.at(0)); + + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); + + // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normTy) + return {std::nullopt, false, {}, {}}; + + // if we don't have either just tables or just classes, we've got nothing to get keys of (at least until a future version perhaps adds classes + // as well) + if (normTy->hasTables() == normTy->hasClasses()) + return {std::nullopt, true, {}, {}}; + + // this is sort of atrocious, but we're trying to reject any type that has not normalized to a table or a union of tables. + if (normTy->hasTops() || normTy->hasBooleans() || normTy->hasErrors() || normTy->hasNils() || normTy->hasNumbers() || normTy->hasStrings() || + normTy->hasThreads() || normTy->hasBuffers() || normTy->hasFunctions() || normTy->hasTyvars()) + return {std::nullopt, true, {}, {}}; + + // we're going to collect the keys in here + Set keys{{}}; + + // computing the keys for classes + if (normTy->hasClasses()) + { + LUAU_ASSERT(!normTy->hasTables()); + + // seen set for key computation for classes + DenseHashSet seen{{}}; + + auto classesIter = normTy->classes.ordering.begin(); + auto classesIterEnd = normTy->classes.ordering.end(); + LUAU_ASSERT(classesIter != classesIterEnd); // should be guaranteed by the `hasClasses` check earlier + + // collect all the properties from the first class type + if (!computeKeysOf(*classesIter, keys, seen, isRaw, ctx)) + return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have a top type! + + // we need to look at each class to remove any keys that are not common amongst them all + while (++classesIter != classesIterEnd) + { + seen.clear(); // we'll reuse the same seen set + + Set localKeys{{}}; + + // we can skip to the next class if this one is a top type + if (!computeKeysOf(*classesIter, localKeys, seen, isRaw, ctx)) + continue; + + for (auto& key : keys) + { + // remove any keys that are not present in each class + if (!localKeys.contains(key)) + keys.erase(key); + } + } + } + + // computing the keys for tables + if (normTy->hasTables()) + { + LUAU_ASSERT(!normTy->hasClasses()); + + // seen set for key computation for tables + DenseHashSet seen{{}}; + + auto tablesIter = normTy->tables.begin(); + LUAU_ASSERT(tablesIter != normTy->tables.end()); // should be guaranteed by the `hasTables` check earlier + + // collect all the properties from the first table type + if (!computeKeysOf(*tablesIter, keys, seen, isRaw, ctx)) + return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have the top table type! + + // we need to look at each tables to remove any keys that are not common amongst them all + while (++tablesIter != normTy->tables.end()) + { + seen.clear(); // we'll reuse the same seen set + + Set localKeys{{}}; + + // we can skip to the next table if this one is the top table type + if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx)) + continue; + + for (auto& key : keys) + { + // remove any keys that are not present in each table + if (!localKeys.contains(key)) + keys.erase(key); + } + } + } + + // if the set of keys is empty, `keyof` is `never` + if (keys.empty()) + return {ctx->builtins->neverType, false, {}, {}}; + + // everything is validated, we need only construct our big union of singletons now! + std::vector singletons; + singletons.reserve(keys.size()); + + for (std::string key : keys) + singletons.push_back(ctx->arena->addType(SingletonType{StringSingleton{key}})); + + // If there's only one entry, we don't need a UnionType. + // We can take straight take it from the first entry + // because it was added into the type arena already. + if (singletons.size() == 1) + return {singletons.front(), false, {}, {}}; + + return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; +} + +TypeFunctionReductionResult keyofTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("keyof type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return keyofFunctionImpl(typeParams, packParams, ctx, /* isRaw */ false); +} + +TypeFunctionReductionResult rawkeyofTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("rawkeyof type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return keyofFunctionImpl(typeParams, packParams, ctx, /* isRaw */ true); +} + +/* Searches through table's or class's props/indexer to find the property of `ty` + If found, appends that property to `result` and returns true + Else, returns false */ +bool searchPropsAndIndexer( + TypeId ty, + TableType::Props tblProps, + std::optional tblIndexer, + DenseHashSet& result, + NotNull ctx +) +{ + ty = follow(ty); + + // index into tbl's properties + if (auto stringSingleton = get(get(ty))) + { + if (tblProps.find(stringSingleton->value) != tblProps.end()) + { + TypeId propTy = follow(tblProps.at(stringSingleton->value).type()); + + // property is a union type -> we need to extend our reduction type + if (auto propUnionTy = get(propTy)) + { + for (TypeId option : propUnionTy->options) + result.insert(option); + } + else // property is a singular type or intersection type -> we can simply append + result.insert(propTy); + + return true; + } + } + + // index into tbl's indexer + if (tblIndexer) + { + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + { + TypeId idxResultTy = follow(tblIndexer->indexResultType); + + // indexResultType is a union type -> we need to extend our reduction type + if (auto idxResUnionTy = get(idxResultTy)) + { + for (TypeId option : idxResUnionTy->options) + result.insert(option); + } + else // indexResultType is a singular type or intersection type -> we can simply append + result.insert(idxResultTy); + + return true; + } + } + + return false; +} + +/* Handles recursion / metamethods of tables/classes + `isRaw` parameter indicates whether or not we should follow __index metamethods + returns false if property of `ty` could not be found */ +bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet& result, NotNull ctx, bool isRaw) +{ + indexer = follow(indexer); + indexee = follow(indexee); + + // we have a table type to try indexing + if (auto tableTy = get(indexee)) + { + return searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx); + } + + // we have a metatable type to try indexing + if (auto metatableTy = get(indexee)) + { + if (auto tableTy = get(metatableTy->table)) + { + + // try finding all properties within the current scope of the table + if (searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx)) + return true; + } + + // if the code reached here, it means we weren't able to find all properties -> look into __index metamethod + if (!isRaw) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, indexee, "__index", Location{}); + if (mmType) + return tblIndexInto(indexer, *mmType, result, ctx, isRaw); + } + } + + return false; +} + +/* Vocabulary note: indexee refers to the type that contains the properties, + indexer refers to the type that is used to access indexee + Example: index => `Person` is the indexee and `"name"` is the indexer */ +TypeFunctionReductionResult indexFunctionImpl( + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + bool isRaw +) +{ + TypeId indexeeTy = follow(typeParams.at(0)); + std::shared_ptr indexeeNormTy = ctx->normalizer->normalize(indexeeTy); + + // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!indexeeNormTy) + return {std::nullopt, false, {}, {}}; + + // if we don't have either just tables or just classes, we've got nothing to index into + if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses()) + return {std::nullopt, true, {}, {}}; + + // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes. + if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() || + indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() || + indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars()) + return {std::nullopt, true, {}, {}}; + + TypeId indexerTy = follow(typeParams.at(1)); + + if (isPending(indexerTy, ctx->solver)) + return {std::nullopt, false, {indexerTy}, {}}; + + std::shared_ptr indexerNormTy = ctx->normalizer->normalize(indexerTy); + + // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!indexerNormTy) + return {std::nullopt, false, {}, {}}; + + // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer) + if (indexerNormTy->hasTops() || indexerNormTy->hasErrors()) + return {std::nullopt, true, {}, {}}; + + // indexer can be a union —> break them down into a vector + const std::vector* typesToFind = nullptr; + const std::vector singleType{indexerTy}; + if (auto unionTy = get(indexerTy)) + typesToFind = &unionTy->options; + else + typesToFind = &singleType; + + DenseHashSet properties{{}}; // vector of types that will be returned + + if (indexeeNormTy->hasClasses()) + { + LUAU_ASSERT(!indexeeNormTy->hasTables()); + + if (isRaw) // rawget should never reduce for classes (to match the behavior of the rawget global function) + return {std::nullopt, true, {}, {}}; + + // at least one class is guaranteed to be in the iterator by .hasClasses() + for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) + { + auto classTy = get(*classesIter); + if (!classTy) + { + LUAU_ASSERT(false); // this should not be possible according to normalization's spec + return {std::nullopt, true, {}, {}}; + } + + for (TypeId ty : *typesToFind) + { + // Search for all instances of indexer in class->props and class->indexer + if (searchPropsAndIndexer(ty, classTy->props, classTy->indexer, properties, ctx)) + continue; // Indexer was found in this class, so we can move on to the next + + auto parent = classTy->parent; + bool foundInParent = false; + while (parent && !foundInParent) + { + auto parentClass = get(follow(*parent)); + foundInParent = searchPropsAndIndexer(ty, parentClass->props, parentClass->indexer, properties, ctx); + parent = parentClass->parent; + } + + // we move on to the next type if any of the parents we went through had the property. + if (foundInParent) + continue; + + // If code reaches here,that means the property not found -> check in the metatable's __index + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{}); + if (!mmType) // if a metatable does not exist, there is no where else to look + return {std::nullopt, true, {}, {}}; + + if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce + return {std::nullopt, true, {}, {}}; + } + } + } + + if (indexeeNormTy->hasTables()) + { + LUAU_ASSERT(!indexeeNormTy->hasClasses()); + + // at least one table is guaranteed to be in the iterator by .hasTables() + for (auto tablesIter = indexeeNormTy->tables.begin(); tablesIter != indexeeNormTy->tables.end(); ++tablesIter) + { + for (TypeId ty : *typesToFind) + if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw)) + return {std::nullopt, true, {}, {}}; + } + } + + // Call `follow()` on each element to resolve all Bound types before returning + std::transform( + properties.begin(), + properties.end(), + properties.begin(), + [](TypeId ty) + { + return follow(ty); + } + ); + + // If the type being reduced to is a single type, no need to union + if (properties.size() == 1) + return {*properties.begin(), false, {}, {}}; + + return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; +} + +TypeFunctionReductionResult indexTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("index type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return indexFunctionImpl(typeParams, packParams, ctx, /* isRaw */ false); +} + +TypeFunctionReductionResult rawgetTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("rawget type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return indexFunctionImpl(typeParams, packParams, ctx, /* isRaw */ true); +} + +BuiltinTypeFunctions::BuiltinTypeFunctions() + : userFunc{"user", userDefinedTypeFunction} + , notFunc{"not", notTypeFunction} + , lenFunc{"len", lenTypeFunction} + , unmFunc{"unm", unmTypeFunction} + , addFunc{"add", addTypeFunction} + , subFunc{"sub", subTypeFunction} + , mulFunc{"mul", mulTypeFunction} + , divFunc{"div", divTypeFunction} + , idivFunc{"idiv", idivTypeFunction} + , powFunc{"pow", powTypeFunction} + , modFunc{"mod", modTypeFunction} + , concatFunc{"concat", concatTypeFunction} + , andFunc{"and", andTypeFunction} + , orFunc{"or", orTypeFunction} + , ltFunc{"lt", ltTypeFunction} + , leFunc{"le", leTypeFunction} + , eqFunc{"eq", eqTypeFunction} + , refineFunc{"refine", refineTypeFunction} + , singletonFunc{"singleton", singletonTypeFunction} + , unionFunc{"union", unionTypeFunction} + , intersectFunc{"intersect", intersectTypeFunction} + , keyofFunc{"keyof", keyofTypeFunction} + , rawkeyofFunc{"rawkeyof", rawkeyofTypeFunction} + , indexFunc{"index", indexTypeFunction} + , rawgetFunc{"rawget", rawgetTypeFunction} +{ +} + +void BuiltinTypeFunctions::addToScope(NotNull arena, NotNull scope) const +{ + // make a type function for a one-argument type function + auto mkUnaryTypeFunction = [&](const TypeFunction* tf) + { + TypeId t = arena->addType(GenericType{"T"}); + GenericTypeDefinition genericT{t}; + + return TypeFun{{genericT}, arena->addType(TypeFunctionInstanceType{NotNull{tf}, {t}, {}})}; + }; + + // make a type function for a two-argument type function + auto mkBinaryTypeFunction = [&](const TypeFunction* tf) + { + TypeId t = arena->addType(GenericType{"T"}); + TypeId u = arena->addType(GenericType{"U"}); + GenericTypeDefinition genericT{t}; + GenericTypeDefinition genericU{u, {t}}; + + return TypeFun{{genericT, genericU}, arena->addType(TypeFunctionInstanceType{NotNull{tf}, {t, u}, {}})}; + }; + + scope->exportedTypeBindings[lenFunc.name] = mkUnaryTypeFunction(&lenFunc); + scope->exportedTypeBindings[unmFunc.name] = mkUnaryTypeFunction(&unmFunc); + + scope->exportedTypeBindings[addFunc.name] = mkBinaryTypeFunction(&addFunc); + scope->exportedTypeBindings[subFunc.name] = mkBinaryTypeFunction(&subFunc); + scope->exportedTypeBindings[mulFunc.name] = mkBinaryTypeFunction(&mulFunc); + scope->exportedTypeBindings[divFunc.name] = mkBinaryTypeFunction(&divFunc); + scope->exportedTypeBindings[idivFunc.name] = mkBinaryTypeFunction(&idivFunc); + scope->exportedTypeBindings[powFunc.name] = mkBinaryTypeFunction(&powFunc); + scope->exportedTypeBindings[modFunc.name] = mkBinaryTypeFunction(&modFunc); + scope->exportedTypeBindings[concatFunc.name] = mkBinaryTypeFunction(&concatFunc); + + scope->exportedTypeBindings[ltFunc.name] = mkBinaryTypeFunction(<Func); + scope->exportedTypeBindings[leFunc.name] = mkBinaryTypeFunction(&leFunc); + scope->exportedTypeBindings[eqFunc.name] = mkBinaryTypeFunction(&eqFunc); + + scope->exportedTypeBindings[keyofFunc.name] = mkUnaryTypeFunction(&keyofFunc); + scope->exportedTypeBindings[rawkeyofFunc.name] = mkUnaryTypeFunction(&rawkeyofFunc); + + scope->exportedTypeBindings[indexFunc.name] = mkBinaryTypeFunction(&indexFunc); + scope->exportedTypeBindings[rawgetFunc.name] = mkBinaryTypeFunction(&rawgetFunc); +} + +const BuiltinTypeFunctions& builtinTypeFunctions() +{ + static std::unique_ptr result = std::make_unique(); + + return *result; +} + +} // namespace Luau diff --git a/Analysis/src/TypeFunctionReductionGuesser.cpp b/Analysis/src/TypeFunctionReductionGuesser.cpp new file mode 100644 index 000000000..389a797d7 --- /dev/null +++ b/Analysis/src/TypeFunctionReductionGuesser.cpp @@ -0,0 +1,455 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeFunctionReductionGuesser.h" + +#include "Luau/DenseHash.h" +#include "Luau/Normalize.h" +#include "Luau/ToString.h" +#include "Luau/TypeFunction.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/VecDeque.h" +#include "Luau/VisitType.h" + +#include +#include +#include + +namespace Luau +{ +struct InstanceCollector2 : TypeOnceVisitor +{ + VecDeque tys; + VecDeque tps; + DenseHashSet cyclicInstance{nullptr}; + DenseHashSet instanceArguments{nullptr}; + + bool visit(TypeId ty, const TypeFunctionInstanceType& it) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tys.push_front(ty); + for (auto t : it.typeArguments) + instanceArguments.insert(follow(t)); + return true; + } + + void cycle(TypeId ty) override + { + /// Detected cyclic type pack + TypeId t = follow(ty); + if (get(t)) + cyclicInstance.insert(t); + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tps.push_front(tp); + return true; + } +}; + + + +TypeFunctionReductionGuesser::TypeFunctionReductionGuesser(NotNull arena, NotNull builtins, NotNull normalizer) + : arena(arena) + , builtins(builtins) + , normalizer(normalizer) +{ +} + +bool TypeFunctionReductionGuesser::isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet& argsUsed) +{ + bool sameSize = ftv.generics.size() == argsUsed.size(); + bool allGenericsAppear = true; + for (auto gt : ftv.generics) + allGenericsAppear = allGenericsAppear || argsUsed.contains(gt); + return sameSize && allGenericsAppear; +} + +void TypeFunctionReductionGuesser::dumpGuesses() +{ + for (auto [tf, t] : functionReducesTo) + printf("Type family %s ~~> %s\n", toString(tf).c_str(), toString(t).c_str()); + for (auto [t, t_] : substitutable) + printf("Substitute %s for %s\n", toString(t).c_str(), toString(t_).c_str()); +} + +std::optional TypeFunctionReductionGuesser::guess(TypeId typ) +{ + std::optional guessedType = guessType(typ); + + if (!guessedType.has_value()) + return {}; + + TypeId guess = follow(*guessedType); + if (get(guess)) + return {}; + + return guess; +} + +std::optional TypeFunctionReductionGuesser::guess(TypePackId tp) +{ + auto [head, tail] = flatten(tp); + + std::vector guessedHead; + guessedHead.reserve(head.size()); + + for (auto typ : head) + { + std::optional guessedType = guessType(typ); + + if (!guessedType.has_value()) + return {}; + + TypeId guess = follow(*guessedType); + if (get(guess)) + return {}; + + guessedHead.push_back(*guessedType); + } + + return arena->addTypePack(TypePack{guessedHead, tail}); +} + +TypeFunctionReductionGuessResult TypeFunctionReductionGuesser::guessTypeFunctionReductionForFunctionExpr( + const AstExprFunction& expr, + const FunctionType* ftv, + TypeId retTy +) +{ + InstanceCollector2 collector; + collector.traverse(retTy); + toInfer = std::move(collector.tys); + cyclicInstances = std::move(collector.cyclicInstance); + + if (isFunctionGenericsSaturated(*ftv, collector.instanceArguments)) + return TypeFunctionReductionGuessResult{{}, nullptr, false}; + infer(); + + std::vector> results; + std::vector args; + for (TypeId t : ftv->argTypes) + args.push_back(t); + + // Submit a guess for arg types + for (size_t i = 0; i < expr.args.size; i++) + { + TypeId argTy; + AstLocal* local = expr.args.data[i]; + if (i >= args.size()) + continue; + + argTy = args[i]; + std::optional guessedType = guessType(argTy); + if (!guessedType.has_value()) + continue; + TypeId guess = follow(*guessedType); + if (get(guess)) + continue; + + results.push_back({local->name.value, guess}); + } + + // Submit a guess for return types + TypeId recommendedAnnotation; + std::optional guessedReturnType = guessType(retTy); + if (!guessedReturnType.has_value()) + recommendedAnnotation = builtins->unknownType; + else + recommendedAnnotation = follow(*guessedReturnType); + if (auto t = get(recommendedAnnotation)) + recommendedAnnotation = builtins->unknownType; + + toInfer.clear(); + cyclicInstances.clear(); + functionReducesTo.clear(); + substitutable.clear(); + + return TypeFunctionReductionGuessResult{results, recommendedAnnotation}; +} + +std::optional TypeFunctionReductionGuesser::guessType(TypeId arg) +{ + TypeId t = follow(arg); + if (substitutable.contains(t)) + { + TypeId subst = follow(substitutable[t]); + if (subst == t || substitutable.contains(subst)) + return subst; + else if (!get(subst)) + return subst; + else + return guessType(subst); + } + if (get(t)) + { + if (functionReducesTo.contains(t)) + return functionReducesTo[t]; + } + return {}; +} + +bool TypeFunctionReductionGuesser::isNumericBinopFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "add" || instance.function->name == "sub" || instance.function->name == "mul" || + instance.function->name == "div" || instance.function->name == "idiv" || instance.function->name == "pow" || + instance.function->name == "mod"; +} + +bool TypeFunctionReductionGuesser::isComparisonFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "lt" || instance.function->name == "le" || instance.function->name == "eq"; +} + +bool TypeFunctionReductionGuesser::isOrAndFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "or" || instance.function->name == "and"; +} + +bool TypeFunctionReductionGuesser::isNotFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "not"; +} + +bool TypeFunctionReductionGuesser::isLenFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "len"; +} + +bool TypeFunctionReductionGuesser::isUnaryMinus(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "unm"; +} + +// Operand is assignable if it looks like a cyclic function instance, or a generic type +bool TypeFunctionReductionGuesser::operandIsAssignable(TypeId ty) +{ + if (get(ty)) + return true; + if (get(ty)) + return true; + if (cyclicInstances.contains(ty)) + return true; + return false; +} + +std::shared_ptr TypeFunctionReductionGuesser::normalize(TypeId ty) +{ + return normalizer->normalize(ty); +} + + +std::optional TypeFunctionReductionGuesser::tryAssignOperandType(TypeId ty) +{ + // Because we collect innermost instances first, if we see a type function instance as an operand, + // We try to check if we guessed a type for it + if (auto tfit = get(ty)) + { + if (functionReducesTo.contains(ty)) + return {functionReducesTo[ty]}; + } + + // If ty is a generic, we need to check if we inferred a substitution + if (auto gt = get(ty)) + { + if (substitutable.contains(ty)) + return {substitutable[ty]}; + } + + // If we cannot substitute a type for this value, we return an empty optional + return {}; +} + +void TypeFunctionReductionGuesser::step() +{ + TypeId t = toInfer.front(); + toInfer.pop_front(); + t = follow(t); + if (auto tf = get(t)) + inferTypeFunctionSubstitutions(t, tf); +} + +void TypeFunctionReductionGuesser::infer() +{ + while (!done()) + step(); +} + +bool TypeFunctionReductionGuesser::done() +{ + return toInfer.empty(); +} + +void TypeFunctionReductionGuesser::inferTypeFunctionSubstitutions(TypeId ty, const TypeFunctionInstanceType* instance) +{ + + TypeFunctionInferenceResult result; + LUAU_ASSERT(instance); + // TODO: Make an inexhaustive version of this warn in the compiler? + if (isNumericBinopFunction(*instance)) + result = inferNumericBinopFunction(instance); + else if (isComparisonFunction(*instance)) + result = inferComparisonFunction(instance); + else if (isOrAndFunction(*instance)) + result = inferOrAndFunction(instance); + else if (isNotFunction(*instance)) + result = inferNotFunction(instance); + else if (isLenFunction(*instance)) + result = inferLenFunction(instance); + else if (isUnaryMinus(*instance)) + result = inferUnaryMinusFunction(instance); + else + result = {{}, builtins->unknownType}; + + TypeId resultInference = follow(result.functionResultInference); + if (!functionReducesTo.contains(resultInference)) + functionReducesTo[ty] = resultInference; + + for (size_t i = 0; i < instance->typeArguments.size(); i++) + { + if (i < result.operandInference.size()) + { + TypeId arg = follow(instance->typeArguments[i]); + TypeId inference = follow(result.operandInference[i]); + if (auto tfit = get(arg)) + { + if (!functionReducesTo.contains(arg)) + functionReducesTo.try_insert(arg, inference); + } + else if (auto gt = get(arg)) + substitutable[arg] = inference; + } + } +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferNumericBinopFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 2); + TypeFunctionInferenceResult defaultNumericBinopInference{{builtins->numberType, builtins->numberType}, builtins->numberType}; + return defaultNumericBinopInference; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferComparisonFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 2); + // Comparison functions are lt/le/eq. + // Heuristic: these are type functions from t -> t -> bool + + TypeId lhsTy = follow(instance->typeArguments[0]); + TypeId rhsTy = follow(instance->typeArguments[1]); + + auto comparisonInference = [&](TypeId op) -> TypeFunctionInferenceResult + { + return TypeFunctionInferenceResult{{op, op}, builtins->booleanType}; + }; + + if (std::optional ty = tryAssignOperandType(lhsTy)) + lhsTy = follow(*ty); + if (std::optional ty = tryAssignOperandType(rhsTy)) + rhsTy = follow(*ty); + if (operandIsAssignable(lhsTy) && !operandIsAssignable(rhsTy)) + return comparisonInference(rhsTy); + if (operandIsAssignable(rhsTy) && !operandIsAssignable(lhsTy)) + return comparisonInference(lhsTy); + return comparisonInference(builtins->numberType); +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferOrAndFunction(const TypeFunctionInstanceType* instance) +{ + + LUAU_ASSERT(instance->typeArguments.size() == 2); + + TypeId lhsTy = follow(instance->typeArguments[0]); + TypeId rhsTy = follow(instance->typeArguments[1]); + + if (std::optional ty = tryAssignOperandType(lhsTy)) + lhsTy = follow(*ty); + if (std::optional ty = tryAssignOperandType(rhsTy)) + rhsTy = follow(*ty); + TypeFunctionInferenceResult defaultAndOrInference{{builtins->unknownType, builtins->unknownType}, builtins->booleanType}; + + std::shared_ptr lty = normalize(lhsTy); + std::shared_ptr rty = normalize(lhsTy); + bool lhsTruthy = lty ? lty->isTruthy() : false; + bool rhsTruthy = rty ? rty->isTruthy() : false; + // If at the end, we still don't have good substitutions, return the default type + if (instance->function->name == "or") + { + if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy)) + return defaultAndOrInference; + if (operandIsAssignable(lhsTy)) + return TypeFunctionInferenceResult{{builtins->unknownType, rhsTy}, rhsTy}; + if (operandIsAssignable(rhsTy)) + return TypeFunctionInferenceResult{{lhsTy, builtins->unknownType}, lhsTy}; + if (lhsTruthy) + return {{lhsTy, rhsTy}, lhsTy}; + if (rhsTruthy) + return {{builtins->unknownType, rhsTy}, rhsTy}; + } + + if (instance->function->name == "and") + { + + if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy)) + return defaultAndOrInference; + if (operandIsAssignable(lhsTy)) + return TypeFunctionInferenceResult{{}, rhsTy}; + if (operandIsAssignable(rhsTy)) + return TypeFunctionInferenceResult{{}, lhsTy}; + if (lhsTruthy) + return {{lhsTy, rhsTy}, rhsTy}; + else + return {{lhsTy, rhsTy}, lhsTy}; + } + + return defaultAndOrInference; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferNotFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + return {{opTy}, builtins->booleanType}; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferLenFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + return {{opTy}, builtins->numberType}; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferUnaryMinusFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + if (isNumber(opTy)) + return {{builtins->numberType}, builtins->numberType}; + return {{builtins->unknownType}, builtins->numberType}; +} + + +} // namespace Luau diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp new file mode 100644 index 000000000..84fa0feae --- /dev/null +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -0,0 +1,2324 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntime.h" + +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/TypeFunction.h" + +#include "lua.h" +#include "lualib.h" + +#include +#include +#include + +LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister, false) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite, false) + +namespace Luau +{ + +constexpr int kTypeUserdataTag = 42; + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + ::operator delete(ptr); + return nullptr; + } + else if (osize == 0) + { + return ::operator new(nsize); + } + else + { + void* data = ::operator new(nsize); + memcpy(data, ptr, nsize < osize ? nsize : osize); + + ::operator delete(ptr); + + return data; + } +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult) +{ + switch (luaResult) + { + case LUA_OK: + return std::nullopt; + case LUA_YIELD: + case LUA_BREAK: + return format("'%s' type function errored: unexpected yield or break", typeFunctionName); + default: + if (!lua_gettop(L)) + return format("'%s' type function errored unexpectedly", typeFunctionName); + + if (lua_isstring(L, -1)) + return format("'%s' type function errored at runtime: %s", typeFunctionName, lua_tostring(L, -1)); + + return format("'%s' type function errored at runtime: raised an error of type %s", typeFunctionName, lua_typename(L, -1)); + } +} + +static TypeFunctionRuntime* getTypeFunctionRuntime(lua_State* L) +{ + return static_cast(lua_getthreaddata(lua_mainthread(L))); +} + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type) +{ + auto ctx = getTypeFunctionRuntime(L); + return ctx->typeArena.allocate(std::move(type)); +} + +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type) +{ + auto ctx = getTypeFunctionRuntime(L); + return ctx->typePackArena.allocate(std::move(type)); +} + +// Pushes a new type userdata onto the stack +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type) +{ + // allocate a new type userdata + TypeFunctionTypeId* ptr = static_cast(lua_newuserdatatagged(L, sizeof(TypeFunctionTypeId), kTypeUserdataTag)); + *ptr = allocateTypeFunctionType(L, std::move(type)); + + // set the new userdata's metatable to type metatable + luaL_getmetatable(L, "type"); + lua_setmetatable(L, -2); +} + +void deallocTypeUserData(lua_State* L, void* data) +{ + // only non-owning pointers into an arena is stored +} + +bool isTypeUserData(lua_State* L, int idx) +{ + if (!lua_isuserdata(L, idx)) + return false; + + return lua_touserdatatagged(L, idx, kTypeUserdataTag) != nullptr; +} + +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx) +{ + if (auto typ = static_cast(lua_touserdatatagged(L, idx, kTypeUserdataTag))) + return *typ; + + luaL_typeerrorL(L, idx, "type"); +} + +std::optional optionalTypeUserData(lua_State* L, int idx) +{ + if (lua_isnoneornil(L, idx)) + return std::nullopt; + else + return getTypeUserData(L, idx); +} + +// returns a string tag of TypeFunctionTypeId +static std::string getTag(lua_State* L, TypeFunctionTypeId ty) +{ + if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::NilType) + return "nil"; + else if (auto b = get(ty); b && b->type == TypeFunctionPrimitiveType::Type::Boolean) + return "boolean"; + else if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::Number) + return "number"; + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::String) + return "string"; + else if (get(ty)) + return "unknown"; + else if (get(ty)) + return "never"; + else if (get(ty)) + return "any"; + else if (auto s = get(ty)) + return "singleton"; + else if (get(ty)) + return "negation"; + else if (get(ty)) + return "union"; + else if (get(ty)) + return "intersection"; + else if (get(ty)) + return "table"; + else if (get(ty)) + return "function"; + else if (get(ty)) + return "class"; + + LUAU_UNREACHABLE(); + luaL_error(L, "VM encountered unexpected type variant when determining tag"); +} + +// Luau: `type.unknown` +// Returns the type instance representing unknown +static int createUnknown(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionUnknownType{}); + + return 1; +} + +// Luau: `type.never` +// Returns the type instance representing never +static int createNever(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionNeverType{}); + + return 1; +} + +// Luau: `type.any` +// Returns the type instance representing any +static int createAny(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionAnyType{}); + + return 1; +} + +// Luau: `type.boolean` +// Returns the type instance representing boolean +static int createBoolean(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Boolean}); + + return 1; +} + +// Luau: `type.number` +// Returns the type instance representing number +static int createNumber(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Number}); + + return 1; +} + +// Luau: `type.string` +// Returns the type instance representing string +static int createString(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::String}); + + return 1; +} + +// Luau: `type.singleton(value: string | boolean | nil) -> type` +// Returns the type instance representing string or boolean singleton or nil +static int createSingleton(lua_State* L) +{ + if (lua_isboolean(L, 1)) // Create boolean singleton + { + bool value = luaL_checkboolean(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionBooleanSingleton{value}}); + + return 1; + } + + // n.b. we cannot use lua_isstring here because lua committed the cardinal sin of calling a number a string + if (lua_type(L, 1) == LUA_TSTRING) // Create string singleton + { + const char* value = luaL_checkstring(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{value}}); + + return 1; + } + + if (lua_isnil(L, 1)) + { + allocTypeUserData(L, TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + + return 1; + } + + luaL_error(L, "types.singleton: can't create singleton from `%s` type", lua_typename(L, 1)); +} + +// Luau: `self:value() -> type` +// Returns the value of a singleton +static int getSingletonValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.value: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfpt = get(self)) + { + if (tfpt->type != TypeFunctionPrimitiveType::NilType) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + lua_pushnil(L); + return 1; + } + + auto tfst = get(self); + if (!tfst) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + if (auto tfbst = get(tfst)) + { + lua_pushboolean(L, tfbst->value); + return 1; + } + + if (auto tfsst = get(tfst)) + { + lua_pushlstring(L, tfsst->value.c_str(), tfsst->value.length()); + return 1; + } + + luaL_error(L, "type.value: can't call `value` method on `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.unionof(...: type) -> type` +// Returns the type instance representing union +static int createUnion(lua_State* L) +{ + // get the number of arguments for union + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.unionof: expected at least 2 types to union, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionUnionType{components}); + + return 1; +} + +// Luau: `types.intersectionof(...: type) -> type` +// Returns the type instance representing intersection +static int createIntersection(lua_State* L) +{ + // get the number of arguments for intersection + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.intersectionof: expected at least 2 types to intersection, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionIntersectionType{components}); + + return 1; +} + +// Luau: `self:components() -> {type}` +// Returns the components of union or intersection +static int getComponents(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.components: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfut = get(self); + if (tfut) + { + int argSize = int(tfut->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfut->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + auto tfit = get(self); + if (tfit) + { + int argSize = int(tfit->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfit->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + luaL_error(L, "type.components: cannot call components of `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.negationof(arg: type) -> type` +// Returns the type instance representing negation +static int createNegation(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.negationof: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + if (get(arg) || get(arg)) + luaL_error(L, "types.negationof: cannot perform negation on `%s` type", getTag(L, arg).c_str()); + + allocTypeUserData(L, TypeFunctionNegationType{arg}); + + return 1; +} + +// Luau: `self:inner() -> type` +// Returns the type instance being negated +static int getNegatedValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfnt = get(self); !tfnt) + allocTypeUserData(L, tfnt->type->type); + else + luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); + + return 1; +} + +// Luau: `types.newtable(props: {[type]: type | { read: type, write: type }}?, indexer: {index: type, readresult: type, writeresult: type}?, +// metatable: type?) -> type` Returns the type instance representing table +static int createTable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "types.newtable: expected 0-3 arguments, but got %d", argumentCount); + + // Parse prop + TypeFunctionTableType::Props props{}; + if (lua_istable(L, 1)) + { + lua_pushnil(L); + while (lua_next(L, 1) != 0) + { + TypeFunctionTypeId key = getTypeUserData(L, -2); + + auto tfst = get(key); + if (!tfst) + luaL_error(L, "types.newtable: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "types.newtable: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (lua_istable(L, -1)) + { + lua_getfield(L, -1, "read"); + std::optional readTy; + if (!lua_isnil(L, -1)) + readTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, -1, "write"); + std::optional writeTy; + if (!lua_isnil(L, -1)) + writeTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + props[tfsst->value] = TypeFunctionProperty{readTy, writeTy}; + } + else + { + TypeFunctionTypeId value = getTypeUserData(L, -1); + props[tfsst->value] = TypeFunctionProperty::rw(value); + } + + lua_pop(L, 1); + } + } + else if (!lua_isnoneornil(L, 1)) + luaL_typeerrorL(L, 1, "table"); + + // Parse indexer + std::optional indexer; + if (lua_istable(L, 2)) + { + // Parse keyType and valueType + lua_getfield(L, 2, "index"); + TypeFunctionTypeId keyType = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, 2, "readresult"); + TypeFunctionTypeId valueType = getTypeUserData(L, -1); + lua_pop(L, 1); + + indexer = TypeFunctionTableIndexer(keyType, valueType); + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + // Parse metatable + std::optional metatable = optionalTypeUserData(L, 3); + if (metatable && !get(*metatable)) + luaL_error(L, "types.newtable: expected to be given a table type as a metatable, but got %s instead", getTag(L, *metatable).c_str()); + + allocTypeUserData(L, TypeFunctionTableType{props, indexer, metatable}); + return 1; +} + +// Luau: `self:setproperty(key: type, value: type?)` +// Sets the properties of a table +static int setTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + tftt->props.erase(tfsst->value); + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + tftt->props[tfsst->value] = TypeFunctionProperty::rw(value, value); + + return 0; +} + +// Luau: `self:setreadproperty(key: type, value: type?)` +// Sets the properties of a table +static int setReadTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreadproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setreadproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setreadproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setreadproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's read-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isReadOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the read type. + else if (iter != tftt->props.end()) + iter->second.readTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::readonly(value); + else + iter->second.readTy = value; + + return 0; +} + +// Luau: `self:setwriteproperty(key: type, value: type?)` +// Sets the properties of a table +static int setWriteTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setwriteproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setwriteproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setwriteproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setwriteproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's write-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isWriteOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the write type. + else if (iter != tftt->props.end()) + iter->second.writeTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::writeonly(value); + else + iter->second.writeTy = value; + + return 0; +} + +// Luau: `self:readproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int readTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.readproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.readproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.readproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.readproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.readTy) + allocTypeUserData(L, (*prop.readTy)->type); + else if (FFlag::LuauUserTypeFunFixNoReadWrite) + lua_pushnil(L); + else + luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); + + return 1; +} +// +// Luau: `self:writeproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int writeTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.writeproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.writeproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.writeproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.writeproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.writeTy) + allocTypeUserData(L, (*prop.writeTy)->type); + else if (FFlag::LuauUserTypeFunFixNoReadWrite) + lua_pushnil(L); + else + luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); + + return 1; +} + +// Luau: `self:setindexer(key: type, value: type)` +// Sets the indexer of the table, if the key type is `never`, the indexer is removed +static int setTableIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 3) + luaL_error(L, "type.setindexer: expected 3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setindexer: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + TypeFunctionTypeId value = getTypeUserData(L, 3); + + if (DFInt::LuauTypeSolverRelease >= 646) + { + if (auto tfnt = get(key)) + { + tftt->indexer = std::nullopt; + return 0; + } + } + + tftt->indexer = TypeFunctionTableIndexer{key, value}; + return 0; +} + +// Luau: `self:setreadindexer(key: type, value: type)` +// Sets the read indexer of the table +static int setTableReadIndexer(lua_State* L) +{ + luaL_error(L, "type.setreadindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setwriteindexer(key: type, value: type)` +// Sets the write indexer of the table +static int setTableWriteIndexer(lua_State* L) +{ + luaL_error(L, "type.setwriteindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setmetatable(arg: type)` +// Sets the metatable of the table +static int setTableMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.setmetatable: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setmetatable: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId arg = getTypeUserData(L, 2); + if (!get(arg)) + luaL_error(L, "type.setmetatable: expected the argument to be a table, but got %s instead", getTag(L, self).c_str()); + + tftt->metatable = arg; + + return 0; +} + +// Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}) -> type` +// Returns the type instance representing a function +static int createFunction(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 2) + luaL_error(L, "types.newfunction: expected 0-2 arguments, but got %d", argumentCount); + + TypeFunctionTypePackId argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + if (lua_istable(L, 1)) + { + std::vector head{}; + lua_getfield(L, 1, "head"); + if (lua_istable(L, -1)) + { + int argSize = lua_objlen(L, -1); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + lua_pop(L, 1); // Pop the "head" field + + std::optional tail; + lua_getfield(L, 1, "tail"); + if (auto type = optionalTypeUserData(L, -1)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + lua_pop(L, 1); // Pop the "tail" field + + if (head.size() == 0 && tail.has_value()) + argTypes = *tail; + else + argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } + else if (!lua_isnoneornil(L, 1)) + luaL_typeerrorL(L, 1, "table"); + + TypeFunctionTypePackId retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + if (lua_istable(L, 2)) + { + std::vector head{}; + lua_getfield(L, 2, "head"); + if (lua_istable(L, -1)) + { + int argSize = lua_objlen(L, -1); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + lua_pop(L, 1); // Pop the "head" field + + std::optional tail; + lua_getfield(L, 2, "tail"); + if (auto type = optionalTypeUserData(L, -1)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + lua_pop(L, 1); // Pop the "tail" field + + if (head.size() == 0 && tail.has_value()) + retTypes = *tail; + else + retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + allocTypeUserData(L, TypeFunctionFunctionType{argTypes, retTypes}); + + return 1; +} + +// Luau: `self:setparameters(head: {type}?, tail: type?)` +// Sets the parameters of the function +static int setFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3 || argumentCount < 1) + luaL_error(L, "type.setparameters: expected 1-3, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make argTypes a variadic type pack + tfft->argTypes = *tail; + else // Make argTypes a type pack + tfft->argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return 0; +} + +// Luau: `self:parameters() -> {head: {type}?, tail: type?}` +// Returns the parameters of the function +static int getFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parameters: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + if (auto tftp = get(tfft->argTypes)) + { + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) + { + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; + } + + if (auto tfvp = get(tfft->argTypes)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + + return 1; + } + + lua_createtable(L, 0, 0); + return 1; +} + +// Luau: `self:setreturns(head: {type}?, tail: type?)` +// Sets the returns of the function +static int setFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreturns: expected 1-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make retTypes a variadic type pack + tfft->retTypes = *tail; + else // Make retTypes a type pack + tfft->retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return 0; +} + +// Luau: `self:returns() -> {head: {type}?, tail: type?}` +// Returns the returns of the function +static int getFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.returns: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + if (auto tftp = get(tfft->retTypes)) + { + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) + { + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; + } + + if (auto tfvp = get(tfft->retTypes)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + + return 1; + } + + lua_createtable(L, 0, 0); + return 1; +} + +// Luau: `self:parent() -> type` +// Returns the parent of a class type +static int getClassParent(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfct = get(self); + if (!tfct) + luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str()); + + // If the parent does not exist, we should return nil + if (!tfct->parent) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->parent)->type); + + return 1; +} + +// Luau: `self:properties() -> {[type]: { read: type?, write: type? }}` +// Returns the properties of a table or class type +static int getProps(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.properties: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + lua_createtable(L, int(tftt->props.size()), 0); + for (auto& [name, prop] : tftt->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + if (auto tfct = get(self)) + { + lua_createtable(L, int(tfct->props.size()), 0); + for (auto& [name, prop] : tfct->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + luaL_error(L, "type.properties: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:indexer() -> {index: type, readresult: type, writeresult: type}?` +// Returns the indexer of a table or class type +static int getIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.indexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + luaL_error(L, "type.indexer: self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:readindexer() -> {index: type, result: type}?` +// Returns the read indexer of a table or class type +static int getReadIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.readindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.readindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:writeindexer() -> {index: type, result: type}?` +// Returns the write indexer of a table or class type +static int getWriteIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.writeindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.writeindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:metatable() -> type?` +// Returns the metatable of a table or class type +static int getMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.metatable: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfmt = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfmt->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfmt->metatable)->type); + + return 1; + } + + if (auto tfct = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfct->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->metatable)->type); + + return 1; + } + + luaL_error(L, "type.metatable: expected self to be a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:is(arg: string) -> boolean` +// Returns true if given argument is a tag of self +static int checkTag(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.is: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + std::string arg = luaL_checkstring(L, 2); + + lua_pushboolean(L, getTag(L, self) == arg); + return 1; +} + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty); // Forward declaration + +// Luau: `types.copy(arg: string) -> type` +// Returns a deep copy of the argument +static int deepCopy(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.copy: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + + TypeFunctionTypeId copy = deepClone(NotNull{getTypeFunctionRuntime(L)}, arg); + allocTypeUserData(L, copy->type); + return 1; +} + +// Luau: `self == arg -> boolean` +// Used to set the __eq metamethod +static int isEqualToType(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + TypeFunctionTypeId arg = getTypeUserData(L, 2); + + lua_pushboolean(L, *self == *arg); + return 1; +} + +void registerTypesLibrary(lua_State* L) +{ + LUAU_ASSERT(FFlag::LuauUserTypeFunFixRegister); + + luaL_Reg fields[] = { + {"unknown", createUnknown}, + {"never", createNever}, + {"any", createAny}, + {"boolean", createBoolean}, + {"number", createNumber}, + {"string", createString}, + {nullptr, nullptr} + }; + + luaL_Reg methods[] = { + {"singleton", createSingleton}, + {"negationof", createNegation}, + {"unionof", createUnion}, + {"intersectionof", createIntersection}, + {"newtable", createTable}, + {"newfunction", createFunction}, + {"copy", deepCopy}, + + {nullptr, nullptr} + }; + + luaL_register(L, "types", methods); + + // Set fields for type userdata + for (luaL_Reg* l = fields; l->name; l++) + { + l->func(L); + lua_setfield(L, -2, l->name); + } + + lua_pop(L, 1); +} + +static int typeUserdataIndex(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + const char* field = luaL_checkstring(L, 2); + + if (strcmp(field, "tag") == 0) + { + lua_pushstring(L, getTag(L, self).c_str()); + return 1; + } + + lua_pushvalue(L, lua_upvalueindex(1)); + lua_getfield(L, -1, field); + return 1; +} + +void registerTypeUserData(lua_State* L) +{ + if (FFlag::LuauUserTypeFunFixRegister) + { + luaL_Reg typeUserdataMethods[] = { + {"is", checkTag}, + + // Negation type methods + {"inner", getNegatedValue}, + + // Singleton type methods + {"value", getSingletonValue}, + + // Table type methods + {"setproperty", setTableProp}, + {"setreadproperty", setReadTableProp}, + {"setwriteproperty", setWriteTableProp}, + {"readproperty", readTableProp}, + {"writeproperty", writeTableProp}, + {"properties", getProps}, + {"setindexer", setTableIndexer}, + {"setreadindexer", setTableReadIndexer}, + {"setwriteindexer", setTableWriteIndexer}, + {"indexer", getIndexer}, + {"readindexer", getReadIndexer}, + {"writeindexer", getWriteIndexer}, + {"setmetatable", setTableMetatable}, + {"metatable", getMetatable}, + + // Function type methods + {"setparameters", setFunctionParameters}, + {"parameters", getFunctionParameters}, + {"setreturns", setFunctionReturns}, + {"returns", getFunctionReturns}, + + // Union and Intersection type methods + {"components", getComponents}, + + // Class type methods + {"parent", getClassParent}, + + {nullptr, nullptr} + }; + + // Create and register metatable for type userdata + luaL_newmetatable(L, "type"); + + // Protect metatable from being changed + lua_pushstring(L, "The metatable is locked"); + lua_setfield(L, -2, "__metatable"); + + lua_pushcfunction(L, isEqualToType, "__eq"); + lua_setfield(L, -2, "__eq"); + + // Indexing will be a dynamic function because some type fields are dynamic + lua_newtable(L); + luaL_register(L, nullptr, typeUserdataMethods); + lua_setreadonly(L, -1, true); + lua_pushcclosure(L, typeUserdataIndex, "__index", 1); + lua_setfield(L, -2, "__index"); + + lua_setreadonly(L, -1, true); + lua_pop(L, 1); + } + else + { + // List of fields for type userdata + luaL_Reg typeUserdataFields[] = { + {"unknown", createUnknown}, + {"never", createNever}, + {"any", createAny}, + {"boolean", createBoolean}, + {"number", createNumber}, + {"string", createString}, + {nullptr, nullptr} + }; + + // List of methods for type userdata + luaL_Reg typeUserdataMethods[] = { + {"singleton", createSingleton}, + {"negationof", createNegation}, + {"unionof", createUnion}, + {"intersectionof", createIntersection}, + {"newtable", createTable}, + {"newfunction", createFunction}, + {"copy", deepCopy}, + + // Common methods + {"is", checkTag}, + + // Negation type methods + {"inner", getNegatedValue}, + + // Singleton type methods + {"value", getSingletonValue}, + + // Table type methods + {"setproperty", setTableProp}, + {"setreadproperty", setReadTableProp}, + {"setwriteproperty", setWriteTableProp}, + {"readproperty", readTableProp}, + {"writeproperty", writeTableProp}, + {"properties", getProps}, + {"setindexer", setTableIndexer}, + {"setreadindexer", setTableReadIndexer}, + {"setwriteindexer", setTableWriteIndexer}, + {"indexer", getIndexer}, + {"readindexer", getReadIndexer}, + {"writeindexer", getWriteIndexer}, + {"setmetatable", setTableMetatable}, + {"metatable", getMetatable}, + + // Function type methods + {"setparameters", setFunctionParameters}, + {"parameters", getFunctionParameters}, + {"setreturns", setFunctionReturns}, + {"returns", getFunctionReturns}, + + // Union and Intersection type methods + {"components", getComponents}, + + // Class type methods + {"parent", getClassParent}, + {"indexer", getIndexer}, + {nullptr, nullptr} + }; + + // Create and register metatable for type userdata + luaL_newmetatable(L, "type"); + + // Protect metatable from being fetched. + lua_pushstring(L, "The metatable is locked"); + lua_setfield(L, -2, "__metatable"); + + // Set type userdata metatable's __eq to type_equals() + lua_pushcfunction(L, isEqualToType, "__eq"); + lua_setfield(L, -2, "__eq"); + + // Set type userdata metatable's __index to itself + lua_pushvalue(L, -1); // Push a copy of type userdata metatable + lua_setfield(L, -2, "__index"); + + luaL_register(L, nullptr, typeUserdataMethods); + + // Set fields for type userdata + for (luaL_Reg* l = typeUserdataFields; l->name; l++) + { + l->func(L); + lua_setfield(L, -2, l->name); + } + + // Set types library as a global name "types" + lua_setglobal(L, "types"); + } + + // Sets up a destructor for the type userdata. + lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData); +} + +// Used to redirect all the removed global functions to say "this function is unsupported" +int unsupportedFunction(lua_State* L) +{ + luaL_errorL(L, "this function is not supported in type functions"); + return 0; +} + +// Add libraries / globals for type function environment +void setTypeFunctionEnvironment(lua_State* L) +{ + // Register math library + luaopen_math(L); + lua_pop(L, 1); + + // Register table library + luaopen_table(L); + lua_pop(L, 1); + + // Register string library + luaopen_string(L); + lua_pop(L, 1); + + // Register bit32 library + luaopen_bit32(L); + lua_pop(L, 1); + + // Register utf8 library + luaopen_utf8(L); + lua_pop(L, 1); + + // Register buffer library + luaopen_buffer(L); + lua_pop(L, 1); + + // Register base library + luaopen_base(L); + lua_pop(L, 1); + + // Remove certain global functions from the base library + static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; + for (auto& name : unavailableGlobals) + { + lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment"); + lua_setglobal(L, name.c_str()); + } +} + +/* + * Below are helper methods for __eq + * Same as one from Type.cpp + */ +using SeenSet = std::set>; +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs); +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs); + +bool seenSetContains(SeenSet& seen, const void* lhs, const void* rhs) +{ + if (lhs == rhs) + return true; + + auto p = std::make_pair(lhs, rhs); + if (seen.find(p) != seen.end()) + return true; + + seen.insert(p); + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionSingletonType& lhs, const TypeFunctionSingletonType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + { + const TypeFunctionBooleanSingleton* lp = get(&lhs); + const TypeFunctionBooleanSingleton* rp = get(&lhs); + if (lp && rp) + return lp->value == rp->value; + } + + { + const TypeFunctionStringSingleton* lp = get(&lhs); + const TypeFunctionStringSingleton* rp = get(&lhs); + if (lp && rp) + return lp->value == rp->value; + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionUnionType& lhs, const TypeFunctionUnionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionIntersectionType& lhs, const TypeFunctionIntersectionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionNegationType& lhs, const TypeFunctionNegationType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTableType& lhs, const TypeFunctionTableType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.props.size() != rhs.props.size()) + return false; + + if (bool(lhs.indexer) != bool(rhs.indexer)) + return false; + + if (lhs.indexer && rhs.indexer) + { + if (!areEqual(seen, *lhs.indexer->keyType, *rhs.indexer->keyType)) + return false; + + if (!areEqual(seen, *lhs.indexer->valueType, *rhs.indexer->valueType)) + return false; + } + + auto l = lhs.props.begin(); + auto r = rhs.props.begin(); + + while (l != lhs.props.end()) + { + if ((l->second.readTy && !r->second.readTy) || (!l->second.readTy && r->second.readTy)) + return false; + + if (l->second.readTy && r->second.readTy && !areEqual(seen, **(l->second.readTy), **(r->second.readTy))) + return false; + + if ((l->second.writeTy && !r->second.writeTy) || (!l->second.writeTy && r->second.writeTy)) + return false; + + if (l->second.writeTy && r->second.writeTy && !areEqual(seen, **(l->second.writeTy), **(r->second.writeTy))) + return false; + + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionFunctionType& lhs, const TypeFunctionFunctionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (bool(lhs.argTypes) != bool(rhs.argTypes)) + return false; + + if (lhs.argTypes && rhs.argTypes) + { + if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) + return false; + } + + if (bool(lhs.retTypes) != bool(rhs.retTypes)) + return false; + + if (lhs.retTypes && rhs.retTypes) + { + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) + return false; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionClassType& lhs, const TypeFunctionClassType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return lhs.name == rhs.name; +} + +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs) +{ + + if (lhs.type.index() != rhs.type.index()) + return false; + + { + const TypeFunctionPrimitiveType* lp = get(&lhs); + const TypeFunctionPrimitiveType* rp = get(&rhs); + if (lp && rp) + return lp->type == rp->type; + } + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + { + const TypeFunctionSingletonType* lf = get(&lhs); + const TypeFunctionSingletonType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionUnionType* lf = get(&lhs); + const TypeFunctionUnionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionIntersectionType* lf = get(&lhs); + const TypeFunctionIntersectionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionNegationType* lf = get(&lhs); + const TypeFunctionNegationType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionTableType* lt = get(&lhs); + const TypeFunctionTableType* rt = get(&rhs); + if (lt && rt) + return areEqual(seen, *lt, *rt); + } + + { + const TypeFunctionFunctionType* lf = get(&lhs); + const TypeFunctionFunctionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionClassType* lf = get(&lhs); + const TypeFunctionClassType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePack& lhs, const TypeFunctionTypePack& rhs) +{ + if (lhs.head.size() != rhs.head.size()) + return false; + + auto l = lhs.head.begin(); + auto r = rhs.head.begin(); + + while (l != lhs.head.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionVariadicTypePack& lhs, const TypeFunctionVariadicTypePack& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs) +{ + { + const TypeFunctionTypePack* lb = get(&lhs); + const TypeFunctionTypePack* rb = get(&rhs); + if (lb && rb) + return areEqual(seen, *lb, *rb); + } + + { + const TypeFunctionVariadicTypePack* lv = get(&lhs); + const TypeFunctionVariadicTypePack* rv = get(&rhs); + if (lv && rv) + return areEqual(seen, *lv, *rv); + } + + return false; +} + +bool TypeFunctionType::operator==(const TypeFunctionType& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + +bool TypeFunctionTypePackVar::operator==(const TypeFunctionTypePackVar& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + + +TypeFunctionProperty TypeFunctionProperty::readonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.readTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::writeonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.writeTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId ty) +{ + return TypeFunctionProperty::rw(ty, ty); +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId read, TypeFunctionTypeId write) +{ + TypeFunctionProperty p; + p.readTy = read; + p.writeTy = write; + return p; +} + +bool TypeFunctionProperty::isReadOnly() const +{ + return readTy && !writeTy; +} + +bool TypeFunctionProperty::isWriteOnly() const +{ + return writeTy && !readTy; +} + +/* + * Below is a helper class for type.copy() + * Forked version of Clone.cpp + */ +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& kind) +{ + return get_if(&kind); +} + +class TypeFunctionCloner +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been cloned, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be TypeFunctionPrimitiveType; `second` is trying to copy `first` + std::vector> queue; + + SeenTypes types{{}}; // Mapping of TypeFunctionTypeIds that have been shallow cloned to TypeFunctionTypeIds + SeenTypePacks packs{{}}; // Mapping of TypeFunctionTypePackIds that have been shallow cloned to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionCloner(TypeFunctionRuntime* typeFunctionRuntime) + : typeFunctionRuntime(typeFunctionRuntime) + { + } + + TypeFunctionTypeId clone(TypeFunctionTypeId ty) + { + shallowClone(ty); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId clone(TypeFunctionTypePackId tp) + { + shallowClone(tp); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + return steps + queue.size() >= (size_t)DFInt::LuauTypeFunctionSerdeIterationLimit; + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit()) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + cloneChildren(ty, tfti); + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind?"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowClone(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::Type::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case TypeFunctionPrimitiveType::Type::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case TypeFunctionPrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case TypeFunctionPrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + default: + break; + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + target = ty; // Don't copy a class since they are immutable + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowClone(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void cloneChildren(TypeFunctionTypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + cloneChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + cloneChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + cloneChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + cloneChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; t1 && t2) + cloneChildren(t1, t2); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + cloneChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + cloneChildren(c1, c2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionTypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + cloneChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + cloneChildren(vPack1, vPack2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionKind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + cloneChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + cloneChildren(*tp, *tftp); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionPrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void cloneChildren(TypeFunctionNeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void cloneChildren(TypeFunctionAnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void cloneChildren(TypeFunctionSingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnionType* u1, TypeFunctionUnionType* u2) + { + for (TypeFunctionTypeId& ty : u1->components) + u2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionIntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (TypeFunctionTypeId& ty : i1->components) + i2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionNegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowClone(n1->type); + } + + void cloneChildren(TypeFunctionTableType* t1, TypeFunctionTableType* t2) + { + for (auto& [k, p] : t1->props) + { + std::optional readTy; + if (p.readTy) + readTy = shallowClone(*p.readTy); + + std::optional writeTy; + if (p.writeTy) + writeTy = shallowClone(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer.has_value()) + t2->indexer = TypeFunctionTableIndexer(shallowClone(t1->indexer->keyType), shallowClone(t1->indexer->valueType)); + + if (t1->metatable.has_value()) + t2->metatable = shallowClone(*t1->metatable); + } + + void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->argTypes = shallowClone(f1->argTypes); + f2->retTypes = shallowClone(f1->retTypes); + } + + void cloneChildren(TypeFunctionClassType* c1, TypeFunctionClassType* c2) + { + // noop. + } + + void cloneChildren(TypeFunctionTypePack* t1, TypeFunctionTypePack* t2) + { + for (TypeFunctionTypeId& ty : t1->head) + t2->head.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowClone(v1->type); + } +}; + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty) +{ + return TypeFunctionCloner(runtime).clone(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp new file mode 100644 index 000000000..e14c37739 --- /dev/null +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -0,0 +1,788 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntimeBuilder.h" + +#include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypePack.h" +#include "Luau/ToString.h" + +#include + +// used to control the recursion limit of any operations done by user-defined type functions +// currently, controls serialization, deserialization, and `type.copy` +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); + +namespace Luau +{ + +// Forked version of Clone.cpp +class TypeFunctionSerializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been serialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is PrimitiveType, + // second must be TypeFunctionPrimitiveType; else there should be an error + std::vector> queue; + + SeenTypes types; // Mapping of TypeIds that have been shallow serialized to TypeFunctionTypeIds + SeenTypePacks packs; // Mapping of TypePackIds that have been shallow serialized to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionSerializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}) + { + } + + TypeFunctionTypeId serialize(TypeId ty) + { + shallowSerialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId serialize(TypePackId tp) + { + shallowSerialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + serializeChildren(ty, tfti); + } + } + + std::optional find(TypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(Kind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionRuntimeSerializer"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowSerialize(TypeId ty) + { + ty = follow(ty); + + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case PrimitiveType::Type::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case PrimitiveType::Type::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case PrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case PrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + case PrimitiveType::Thread: + case PrimitiveType::Function: + case PrimitiveType::Table: + case PrimitiveType::Buffer: + default: + { + std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + else + { + std::string error = format("Argument of singleton type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto m = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + { + state->classesSerialized[c->name] = ty; + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name}); + } + else + { + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowSerialize(TypePackId tp) + { + tp = follow(tp); + + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + else + { + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void serializeChildren(TypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + serializeChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + serializeChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + serializeChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + serializeChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; t1 && t2) + serializeChildren(t1, t2); + else if (auto [m1, m2] = std::tuple{getMutable(ty), getMutable(tfti)}; m1 && m2) + serializeChildren(m1, m2); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + serializeChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + serializeChildren(c1, c2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(TypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + serializeChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + serializeChildren(vPack1, vPack2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(Kind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + serializeChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + serializeChildren(*tp, *tftp); + else + state->ctx->ice->ice("Serializing user defined type function arguments: kind and tfkind do not represent the same type"); + } + + void serializeChildren(PrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void serializeChildren(UnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void serializeChildren(NeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void serializeChildren(AnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void serializeChildren(SingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void serializeChildren(UnionType* u1, TypeFunctionUnionType* u2) + { + for (TypeId& ty : u1->options) + u2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(IntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (TypeId& ty : i1->parts) + i2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(NegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowSerialize(n1->ty); + } + + void serializeChildren(TableType* t1, TypeFunctionTableType* t2) + { + for (const auto& [k, p] : t1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer) + t2->indexer = TypeFunctionTableIndexer(shallowSerialize(t1->indexer->indexType), shallowSerialize(t1->indexer->indexResultType)); + } + + void serializeChildren(MetatableType* m1, TypeFunctionTableType* m2) + { + auto tmpTable = get(shallowSerialize(m1->table)); + if (!tmpTable) + state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType"); + + m2->props = tmpTable->props; + m2->indexer = tmpTable->indexer; + + m2->metatable = shallowSerialize(m1->metatable); + } + + void serializeChildren(FunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->argTypes = shallowSerialize(f1->argTypes); + f2->retTypes = shallowSerialize(f1->retTypes); + } + + void serializeChildren(ClassType* c1, TypeFunctionClassType* c2) + { + for (const auto& [k, p] : c1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + c2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (c1->indexer) + c2->indexer = TypeFunctionTableIndexer(shallowSerialize(c1->indexer->indexType), shallowSerialize(c1->indexer->indexResultType)); + + if (c1->metatable) + c2->metatable = shallowSerialize(*c1->metatable); + + if (c1->parent) + c2->parent = shallowSerialize(*c1->parent); + } + + void serializeChildren(TypePack* t1, TypeFunctionTypePack* t2) + { + for (TypeId& ty : t1->head) + t2->head.push_back(shallowSerialize(ty)); + + if (t1->tail.has_value()) + t2->tail = shallowSerialize(*t1->tail); + } + + void serializeChildren(VariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowSerialize(v1->ty); + } +}; + +// Complete inverse of TypeFunctionSerializer +class TypeFunctionDeserializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeIds that have been deserialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be PrimitiveType; else there should be an error + std::vector> queue; + + SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds + SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds + + int steps = 0; + +public: + explicit TypeFunctionDeserializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}){}; + + TypeId deserialize(TypeFunctionTypeId ty) + { + shallowDeserialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypeId error = state->ctx->builtins->errorRecoveryType(); + types[ty] = error; + return error; + } + + return find(ty).value_or(state->ctx->builtins->errorRecoveryType()); + } + + TypePackId deserialize(TypeFunctionTypePackId tp) + { + shallowDeserialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypePackId error = state->ctx->builtins->errorRecoveryTypePack(); + packs[tp] = error; + return error; + } + + return find(tp).value_or(state->ctx->builtins->errorRecoveryTypePack()); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [tfti, ty] = queue.back(); + queue.pop_back(); + + deserializeChildren(tfti, ty); + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionDeserializer"); + return std::nullopt; + } + } + + TypeId shallowDeserialize(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow deserialization + TypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::Type::NilType: + target = state->ctx->builtins->nilType; + break; + case TypeFunctionPrimitiveType::Type::Boolean: + target = state->ctx->builtins->booleanType; + break; + case TypeFunctionPrimitiveType::Type::Number: + target = state->ctx->builtins->numberType; + break; + case TypeFunctionPrimitiveType::Type::String: + target = state->ctx->builtins->stringType; + break; + default: + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + } + else if (auto u = get(ty)) + target = state->ctx->builtins->unknownType; + else if (auto n = get(ty)) + target = state->ctx->builtins->neverType; + else if (auto a = get(ty)) + target = state->ctx->builtins->anyType; + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = state->ctx->arena->addType(SingletonType{BooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = state->ctx->arena->addType(SingletonType{StringSingleton{ss->value}}); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + else if (auto u = get(ty)) + target = state->ctx->arena->addTV(Type(UnionType{{}})); + else if (auto i = get(ty)) + target = state->ctx->arena->addTV(Type(IntersectionType{{}})); + else if (auto n = get(ty)) + target = state->ctx->arena->addType(NegationType{state->ctx->builtins->unknownType}); + else if (auto t = get(ty); t && !t->metatable.has_value()) + target = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + else if (auto m = get(ty); m && m->metatable.has_value()) + { + TypeId emptyTable = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + target = state->ctx->arena->addType(MetatableType{emptyTable, emptyTable}); + } + else if (auto f = get(ty)) + { + TypePackId emptyTypePack = state->ctx->arena->addTypePack(TypePack{}); + target = state->ctx->arena->addType(FunctionType{emptyTypePack, emptyTypePack, {}, false}); + } + else if (auto c = get(ty)) + { + if (auto result = state->classesSerialized.find(c->name)) + target = *result; + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized"); + } + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypePackId shallowDeserialize(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow deserialization + TypePackId target = {}; + if (auto tPack = get(tp)) + target = state->ctx->arena->addTypePack(TypePack{}); + else if (auto vPack = get(tp)) + target = state->ctx->arena->addTypePack(VariadicTypePack{}); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void deserializeChildren(TypeFunctionTypeId tfti, TypeId ty) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + deserializeChildren(p2, p1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + deserializeChildren(a2, a1); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + deserializeChildren(s2, s1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + deserializeChildren(i2, i1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; + t1 && t2 && !t2->metatable.has_value()) + deserializeChildren(t2, t1); + else if (auto [m1, m2] = std::tuple{getMutable(ty), getMutable(tfti)}; + m1 && m2 && m2->metatable.has_value()) + deserializeChildren(m2, m1); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + deserializeChildren(f2, f1); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + deserializeChildren(c2, c1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionTypePackId tftp, TypePackId tp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + deserializeChildren(tPack2, tPack1); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + deserializeChildren(vPack2, vPack1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionKind tfkind, Kind kind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + deserializeChildren(*tfty, *ty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + deserializeChildren(*tftp, *tp); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: tfkind and kind do not represent the same type"); + } + + void deserializeChildren(TypeFunctionPrimitiveType* p2, PrimitiveType* p1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnknownType* u2, UnknownType* u1) + { + // noop. + } + + void deserializeChildren(TypeFunctionNeverType* n2, NeverType* n1) + { + // noop. + } + + void deserializeChildren(TypeFunctionAnyType* a2, AnyType* a1) + { + // noop. + } + + void deserializeChildren(TypeFunctionSingletonType* s2, SingletonType* s1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnionType* u2, UnionType* u1) + { + for (TypeFunctionTypeId& ty : u2->components) + u1->options.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionIntersectionType* i2, IntersectionType* i1) + { + for (TypeFunctionTypeId& ty : i2->components) + i1->parts.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionNegationType* n2, NegationType* n1) + { + n1->ty = shallowDeserialize(n2->type); + } + + void deserializeChildren(TypeFunctionTableType* t2, TableType* t1) + { + for (const auto& [k, p] : t2->props) + { + if (p.readTy && p.writeTy) + t1->props[k] = Property::rw(shallowDeserialize(*p.readTy), shallowDeserialize(*p.writeTy)); + else if (p.readTy) + t1->props[k] = Property::readonly(shallowDeserialize(*p.readTy)); + else if (p.writeTy) + t1->props[k] = Property::writeonly(shallowDeserialize(*p.writeTy)); + } + + if (t2->indexer.has_value()) + t1->indexer = TableIndexer(shallowDeserialize(t2->indexer->keyType), shallowDeserialize(t2->indexer->valueType)); + } + + void deserializeChildren(TypeFunctionTableType* m2, MetatableType* m1) + { + TypeFunctionTypeId temp = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{m2->props, m2->indexer}); + m1->table = shallowDeserialize(temp); + + if (m2->metatable.has_value()) + m1->metatable = shallowDeserialize(*m2->metatable); + } + + void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1) + { + if (f2->argTypes) + f1->argTypes = shallowDeserialize(f2->argTypes); + + if (f2->retTypes) + f1->retTypes = shallowDeserialize(f2->retTypes); + } + + void deserializeChildren(TypeFunctionClassType* c2, ClassType* c1) + { + // noop. + } + + void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1) + { + for (TypeFunctionTypeId& ty : t2->head) + t1->head.push_back(shallowDeserialize(ty)); + + if (t2->tail.has_value()) + t1->tail = shallowDeserialize(*t2->tail); + } + + void deserializeChildren(TypeFunctionVariadicTypePack* v2, VariadicTypePack* v1) + { + v1->ty = shallowDeserialize(v2->type); + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionSerializer(state).serialize(ty); +} + +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionDeserializer(state).deserialize(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4774b0a12..cb1176351 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -2,12 +2,11 @@ #include "Luau/TypeInfer.h" #include "Luau/ApplyTypeFunction.h" -#include "Luau/Clone.h" +#include "Luau/Cancellation.h" #include "Luau/Common.h" #include "Luau/Instantiation.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" -#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" @@ -18,7 +17,6 @@ #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" #include "Luau/VisitType.h" @@ -26,7 +24,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTFLAGVARIABLE(LuauDontExtendUnsealedRValueTables, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) @@ -34,16 +31,8 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) -LUAU_FASTFLAG(LuauNegatedClassTypes) -LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) -LUAU_FASTFLAG(LuauUninhabitedSubAnything2) -LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration) -LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) +LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false) namespace Luau { @@ -204,7 +193,8 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" || + name == "__idiv"; } size_t HashBoolNamePair::operator()(const std::pair& pair) const @@ -212,17 +202,20 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) - : resolver(resolver) +TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) + : globalScope(globalScope) + , resolver(resolver) , builtinTypes(builtinTypes) , iceHandler(iceHandler) , unifierState(iceHandler) , normalizer(nullptr, builtinTypes, NotNull{&unifierState}) + , reusableInstantiation(TxnLog::empty(), nullptr, builtinTypes, {}, nullptr) , nilType(builtinTypes->nilType) , numberType(builtinTypes->numberType) , stringType(builtinTypes->stringType) , booleanType(builtinTypes->booleanType) , threadType(builtinTypes->threadType) + , bufferType(builtinTypes->bufferType) , anyType(builtinTypes->anyType) , unknownType(builtinTypes->unknownType) , neverType(builtinTypes->neverType) @@ -231,16 +224,6 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtin , uninhabitableTypePack(builtinTypes->uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { - globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - - globalScope->addBuiltinTypeBinding("any", TypeFun{{}, anyType}); - globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, nilType}); - globalScope->addBuiltinTypeBinding("number", TypeFun{{}, numberType}); - globalScope->addBuiltinTypeBinding("string", TypeFun{{}, stringType}); - globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, booleanType}); - globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, threadType}); - globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, unknownType}); - globalScope->addBuiltinTypeBinding("never", TypeFun{{}, neverType}); } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -260,9 +243,13 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + LUAU_TIMETRACE_ARGUMENT("name", module.humanReadableName.c_str()); currentModule.reset(new Module); - currentModule->reduction = std::make_unique(NotNull{¤tModule->internalTypes}, builtinTypes, NotNull{iceHandler}); + currentModule->name = module.name; + currentModule->humanReadableName = module.humanReadableName; + currentModule->internalTypes.owningModule = currentModule.get(); + currentModule->interfaceTypes.owningModule = currentModule.get(); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; @@ -286,10 +273,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope)); currentModule->mode = mode; - currentModuleName = module.name; - if (prepareModuleScope) - prepareModuleScope(module.name, currentModule->getModuleScope()); + prepareModuleScope(currentModule->name, currentModule->getModuleScope()); try { @@ -299,12 +284,9 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { currentModule->timeout = true; } - - if (FFlag::DebugLuauSharedSelf) + catch (const UserCancelError&) { - for (auto& [ty, scope] : deferredQuantification) - Luau::quantify(ty, scope->level); - deferredQuantification.clear(); + currentModule->cancelled = true; } if (get(follow(moduleScope->returnType))) @@ -339,42 +321,54 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo return std::move(currentModule); } -void TypeChecker::check(const ScopePtr& scope, const AstStat& program) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) { + if (finishTime && TimeTrace::getClock() > *finishTime) + throwTimeLimitError(); + if (cancellationToken && cancellationToken->requested()) + throwUserCancelError(); + if (auto block = program.as()) - check(scope, *block); + return check(scope, *block); else if (auto if_ = program.as()) - check(scope, *if_); + return check(scope, *if_); else if (auto while_ = program.as()) - check(scope, *while_); + return check(scope, *while_); else if (auto repeat = program.as()) - check(scope, *repeat); + return check(scope, *repeat); else if (program.is()) - { - } // Nothing to do + return ControlFlow::Breaks; else if (program.is()) - { - } // Nothing to do + return ControlFlow::Continues; else if (auto return_ = program.as()) - check(scope, *return_); + return check(scope, *return_); else if (auto expr = program.as()) + { checkExprPack(scope, *expr->expr); + + if (auto call = expr->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + + return ControlFlow::None; + } else if (auto local = program.as()) - check(scope, *local); + return check(scope, *local); else if (auto for_ = program.as()) - check(scope, *for_); + return check(scope, *for_); else if (auto forIn = program.as()) - check(scope, *forIn); + return check(scope, *forIn); else if (auto assign = program.as()) - check(scope, *assign); + return check(scope, *assign); else if (auto assign = program.as()) - check(scope, *assign); + return check(scope, *assign); else if (program.is()) ice("Should not be calling two-argument check() on a function statement", program.location); else if (program.is()) ice("Should not be calling two-argument check() on a function statement", program.location); else if (auto typealias = program.as()) - check(scope, *typealias); + return check(scope, *typealias); + else if (auto typefunction = program.as()) + return check(scope, *typefunction); else if (auto global = program.as()) { TypeId globalType = resolveType(scope, *global->type); @@ -382,11 +376,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) currentModule->declaredGlobals[globalName] = globalType; currentModule->getModuleScope()->bindings[global->name] = Binding{globalType, global->location}; + + return ControlFlow::None; } else if (auto global = program.as()) - check(scope, *global); + return check(scope, *global); else if (auto global = program.as()) - check(scope, *global); + return check(scope, *global); else if (auto errorStatement = program.as()) { const size_t oldSize = currentModule->errors.size(); @@ -400,37 +396,40 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) // HACK: We want to run typechecking on the contents of the AstStatError, but // we don't think the type errors will be useful most of the time. currentModule->errors.resize(oldSize); + + return ControlFlow::None; } else ice("Unknown AstStat"); - - if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(iceHandler->moduleName); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. -void TypeChecker::check(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatBlock& block) { ScopePtr child = childScope(scope, block.location); - checkBlock(child, block); + + ControlFlow flow = checkBlock(child, block); + scope->inheritRefinements(child); + + return flow; } -void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(block.location); - return; + return ControlFlow::None; } try { - checkBlockWithoutRecursionCheck(scope, block); + return checkBlockWithoutRecursionCheck(scope, block); } catch (const RecursionLimitException&) { reportErrorCodeTooComplex(block.location); - return; + return ControlFlow::None; } } @@ -483,7 +482,7 @@ struct InplaceDemoter : TypeOnceVisitor } }; -void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) { int subLevel = 0; @@ -508,7 +507,8 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A std::unordered_map> functionDecls; - auto checkBody = [&](AstStat* stat) { + auto checkBody = [&](AstStat* stat) + { if (auto fun = stat->as()) { LUAU_ASSERT(functionDecls.count(stat)); @@ -523,6 +523,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } }; + std::optional firstFlow; while (protoIter != sorted.end()) { // protoIter walks forward @@ -565,43 +566,21 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A // We do check the current element, so advance checkIter beyond it. ++checkIter; - check(scope, **protoIter); + ControlFlow flow = check(scope, **protoIter); + if (flow != ControlFlow::None && !firstFlow) + firstFlow = flow; } else if (auto fun = (*protoIter)->as()) { - std::optional selfType; + std::optional selfType; // TODO clip std::optional expectedType; - if (FFlag::DebugLuauSharedSelf) + if (!fun->func->self) { if (auto name = fun->name->as()) { - TypeId baseTy = checkExpr(scope, *name->expr).type; - tablify(baseTy); - - if (!fun->func->self) - expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false); - else if (auto ttv = getMutableTableType(baseTy)) - { - if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) - { - ttv->selfTy = anyIfNonstrict(freshType(ttv->level)); - deferredQuantification.push_back({baseTy, scope}); - } - - selfType = ttv->selfTy; - } - } - } - else - { - if (!fun->func->self) - { - if (auto name = fun->name->as()) - { - TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); - } + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); } } @@ -626,7 +605,11 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A scope->bindings[fun->name] = {funTy, fun->name->location}; } else - check(scope, **protoIter); + { + ControlFlow flow = check(scope, **protoIter); + if (flow != ControlFlow::None && !firstFlow) + firstFlow = flow; + } ++protoIter; } @@ -638,6 +621,8 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } checkBlockTypeAliases(scope, sorted); + + return firstFlow.value_or(ControlFlow::None); } LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted) @@ -646,7 +631,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (typealias->name == kParseNameError) + if (typealias->name == kParseNameError || typealias->name == "typeof") continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -656,11 +641,10 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std if (duplicateTypeAliases.contains({typealias->exported, name})) continue; - TypeId type = bindings[name].type; - if (get(follow(type))) + TypeId type = follow(bindings[name].type); + if (get(type)) { - Type* mty = asMutable(follow(type)); - mty->reassign(*errorRecoveryType(anyType)); + asMutable(type)->ty.emplace(errorRecoveryType(anyType)); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } @@ -712,20 +696,32 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex return predicate; } -void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { WithPredicate result = checkExpr(scope, *statement.condition); - ScopePtr ifScope = childScope(scope, statement.thenbody->location); - resolve(result.predicates, ifScope, true); - check(ifScope, *statement.thenbody); + ScopePtr thenScope = childScope(scope, statement.thenbody->location); + resolve(result.predicates, thenScope, true); + ScopePtr elseScope = childScope(scope, statement.elsebody ? statement.elsebody->location : statement.location); + resolve(result.predicates, elseScope, false); + + ControlFlow thencf = check(thenScope, *statement.thenbody); + ControlFlow elsecf = ControlFlow::None; if (statement.elsebody) - { - ScopePtr elseScope = childScope(scope, statement.elsebody->location); - resolve(result.predicates, elseScope, false); - check(elseScope, *statement.elsebody); - } + elsecf = check(elseScope, *statement.elsebody); + + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) + scope->inheritRefinements(thenScope); + + if (thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; } template @@ -745,22 +741,26 @@ ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Scope return canUnify_(subTy, superTy, scope, location); } -void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); resolve(result.predicates, whileScope, true); check(whileScope, *statement.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) { ScopePtr repScope = childScope(scope, statement.location); checkBlock(repScope, *statement.body); checkExpr(repScope, *statement.condition); + + return ControlFlow::None; } struct Demoter : Substitution @@ -782,7 +782,7 @@ struct Demoter : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return false; @@ -817,7 +817,7 @@ struct Demoter : Substitution } }; -void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; expectedTypes.reserve(return_.list.size); @@ -853,10 +853,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); - return; + return ControlFlow::Returns; } unify(retPack, scope->returnType, scope, return_.location, CountMismatch::Context::Return); + + return ControlFlow::Returns; } template @@ -888,7 +890,7 @@ ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Scope return tryUnify_(subTy, superTy, scope, location); } -void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; expectedTypes.reserve(assign.vars.size); @@ -958,7 +960,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) right = errorRecoveryType(scope); else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(tailPack)) + else if (get(tailPack)) { *asMutable(tailPack) = TypePack{{left}}; growingPack = getMutable(tailPack); @@ -988,9 +990,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } } } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assign) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assign) { AstExprBinary expr(assign.location, assign.op, assign.var, assign.value); @@ -1000,9 +1004,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); unify(result, left, scope, assign.location); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { // Important subtlety: A local variable is not in scope while its initializer is being evaluated. // For instance, you cannot do this: @@ -1114,15 +1120,24 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { AstExpr* require = *maybeRequire; - if (auto moduleInfo = resolver->resolveModuleInfo(currentModuleName, *require)) + if (auto moduleInfo = resolver->resolveModuleInfo(currentModule->name, *require)) { const Name name{local.vars.data[i]->name.value}; if (ModulePtr module = resolver->getModule(moduleInfo->name)) { scope->importedTypeBindings[name] = module->exportedTypeBindings; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->importedModules[name] = moduleInfo->name; + scope->importedModules[name] = moduleInfo->name; + + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == moduleInfo->name) + { + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, anyType}; + } + } } // In non-strict mode we force the module type on the variable, in strict mode it is already unified @@ -1140,9 +1155,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) for (const auto& [local, binding] : varBindings) scope->bindings[local] = binding; + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) { ScopePtr loopScope = childScope(scope, expr.location); @@ -1165,9 +1182,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) unify(checkExpr(loopScope, *expr.step).type, loopVarType, scope, expr.step->location); check(loopScope, *expr.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { ScopePtr loopScope = childScope(scope, forin.location); @@ -1212,7 +1231,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) callRetPack = checkExprPack(scope, *exprCall).type; callRetPack = follow(callRetPack); - if (get(callRetPack)) + if (get(callRetPack)) { iterTy = freshType(scope); unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); @@ -1263,19 +1282,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], scope, forin.location); } - else if (isNonstrictMode()) - { - for (TypeId var : varTypes) - unify(anyType, var, scope, forin.location); - } else { - TypeId varTy = errorRecoveryType(loopScope); - for (TypeId var : varTypes) - unify(varTy, var, scope, forin.location); - - reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); + unify(unknownType, var, scope, forin.location); } return check(loopScope, *forin.body); @@ -1356,9 +1366,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(retPack, varPack, scope, forin.location); check(loopScope, *forin.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function) +ControlFlow TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function) { if (auto exprName = function.name->as()) { @@ -1383,8 +1395,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco globalBindings[name] = oldBinding; else globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; - - return; } else if (auto name = function.name->as()) { @@ -1393,7 +1403,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; - return; } else if (auto name = function.name->as()) { @@ -1440,9 +1449,11 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) +ControlFlow TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) { Name name = function.name->name.value; @@ -1451,15 +1462,23 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) { Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. if (name == kParseNameError) - return; + return ControlFlow::None; + + if (name == "typeof") + { + reportError(typealias.location, GenericError{"Type aliases cannot be named typeof"}); + return ControlFlow::None; + } std::optional binding; if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) @@ -1472,7 +1491,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be // interesting. if (duplicateTypeAliases.find({typealias.exported, name})) - return; + return ControlFlow::None; // By now this alias must have been `prototype()`d first. if (!binding) @@ -1502,14 +1521,26 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // Additionally, we can't modify types that come from other modules if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) { - bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), - binding->typeParams.end(), [](auto&& itp, auto&& tp) { + bool sameTys = std::equal( + ttv->instantiatedTypeParams.begin(), + ttv->instantiatedTypeParams.end(), + binding->typeParams.begin(), + binding->typeParams.end(), + [](auto&& itp, auto&& tp) + { return itp == tp.ty; - }); - bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), binding->typePackParams.begin(), - binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + } + ); + bool sameTps = std::equal( + ttv->instantiatedTypePackParams.begin(), + ttv->instantiatedTypePackParams.end(), + binding->typePackParams.begin(), + binding->typePackParams.end(), + [](auto&& itpp, auto&& tpp) + { return itpp == tpp.tp; - }); + } + ); // Copy can be skipped if this is an identical alias if (!ttv->name || ttv->name != name || !sameTys || !sameTps) @@ -1551,8 +1582,29 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias TypeId& bindingType = bindingsMap[name].type; - if (unify(ty, bindingType, aliasScope, typealias.location)) - bindingType = ty; + unify(ty, bindingType, aliasScope, typealias.location); + + // It is possible for this unification to succeed but for + // `bindingType` still to be free For example, in + // `type T = T|T`, we generate a fresh free type `X`, and then + // unify `X` with `X|X`, which succeeds without binding `X` to + // anything, since `X <: X|X` + if (bindingType->ty.get_if()) + { + ty = errorRecoveryType(aliasScope); + unify(ty, bindingType, aliasScope, typealias.location); + reportError(TypeError{typealias.location, OccursCheckFailed{}}); + } + + bindingType = ty; + return ControlFlow::None; +} + +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeFunction& typefunction) +{ + reportError(TypeError{typefunction.location, GenericError{"This syntax is not supported"}}); + + return ControlFlow::None; } void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) @@ -1560,7 +1612,9 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (name == kParseNameError) + // Also, typeof is not a valid type alias name. We will report an error for + // this in check() + if (name == kParseNameError || name == "typeof") return; std::optional binding; @@ -1601,15 +1655,14 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; scope->typeAliasLocations[name] = typealias.location; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->typeAliasNameLocations[name] = typealias.nameLocation; + scope->typeAliasNameLocations[name] = typealias.nameLocation; } } } void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass.superName) { Name superName = Name(declaredClass.superName->value); @@ -1628,8 +1681,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de if (!get(follow(*superTy))) { - reportError(declaredClass.location, - GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); + reportError( + declaredClass.location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)} + ); incorrectClassDefinitions.insert(&declaredClass); return; } @@ -1637,7 +1692,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de Name className(declaredClass.name.value); - TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); + TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModule->name, declaredClass.location)); ClassType* ctv = getMutable(classTy); TypeId metaTy = addType(TableType{TableState::Sealed, scope->level}); @@ -1645,13 +1700,13 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { Name className(declaredClass.name.value); // Don't bother checking if the class definition was incorrect if (incorrectClassDefinitions.find(&declaredClass)) - return; + return ControlFlow::None; std::optional binding; if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) @@ -1667,6 +1722,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar if (!ctv->metatable) ice("No metatable for declared class"); + if (const auto& indexer = declaredClass.indexer) + ctv->indexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + TableType* metatable = getMutable(*ctv->metatable); for (const AstDeclaredClassProp& prop : declaredClass.props) { @@ -1685,16 +1743,26 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); ftv->hasSelf = true; + + FunctionDefinition defn; + + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; } } if (assignTo.count(propName) == 0) { - assignTo[propName] = {propTy}; + assignTo[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; } else { - TypeId currentTy = assignTo[propName].type; + Luau::Property& prop = assignTo[propName]; + TypeId currentTy = prop.type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -1704,13 +1772,15 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar options.push_back(propTy); TypeId newItv = addType(IntersectionType{std::move(options)}); - assignTo[propName] = {newItv}; + prop.readTy = newItv; + prop.writeTy = newItv; } else if (get(currentTy)) { TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); - assignTo[propName] = {intersection}; + prop.readTy = intersection; + prop.writeTy = intersection; } else { @@ -1718,9 +1788,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } } } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& global) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& global) { ScopePtr funScope = childFunctionScope(scope, global.location); @@ -1728,19 +1800,39 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo std::vector genericTys; genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + std::transform( + generics.begin(), + generics.end(), + std::back_inserter(genericTys), + [](auto&& el) + { + return el.ty; + } + ); std::vector genericTps; genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + std::transform( + genericPacks.begin(), + genericPacks.end(), + std::back_inserter(genericTps), + [](auto&& el) + { + return el.tp; + } + ); TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); + + FunctionDefinition defn; + + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = global.location; + defn.varargLocation = global.vararg ? std::make_optional(global.varargLocation) : std::nullopt; + defn.originalNameLocation = global.nameLocation; + + TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, defn}); FunctionType* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -1751,6 +1843,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->declaredGlobals[fnName] = fnType; currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; + + return ControlFlow::None; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) @@ -1874,7 +1968,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return WithPredicate{vtp->ty}; - else if (get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); @@ -1893,7 +1987,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (const FreeTypePack* ftp = get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { TypeId head = freshType(scope->level); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope->level)}}); @@ -1904,13 +1998,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; - else if (get(retPack)) - { - if (FFlag::LuauReturnAnyInsteadOfICE) - return {anyType, std::move(result.predicates)}; - else - ice("Unexpected abstract type pack!", expr.location); - } + else if (get(retPack)) + return {anyType, std::move(result.predicates)}; else ice("Unknown TypePack type!", expr.location); } @@ -1953,7 +2042,12 @@ std::optional TypeChecker::findMetatableEntry(TypeId type, std::string e } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, + TypeId type, + const Name& name, + const Location& location, + bool addErrors +) { size_t errorCount = currentModule->errors.size(); @@ -1966,7 +2060,12 @@ std::optional TypeChecker::getIndexTypeFromType( } std::optional TypeChecker::getIndexTypeFromTypeImpl( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, + TypeId type, + const Name& name, + const Location& location, + bool addErrors +) { type = follow(type); @@ -1985,7 +2084,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (TableType* tableType = getMutableTableType(type)) { if (auto it = tableType->props.find(name); it != tableType->props.end()) - return it->second.type; + return it->second.type(); else if (auto indexer = tableType->indexer) { // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. @@ -2013,7 +2112,21 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( { const Property* prop = lookupClassProp(cls, name); if (prop) - return prop->type; + return prop->type(); + + if (auto indexer = cls->indexer) + { + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location); + + if (errors.empty()) + return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; + } } else if (const UnionType* utv = get(type)) { @@ -2151,7 +2264,11 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } TypeId TypeChecker::checkExprTable( - const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType) + const ScopePtr& scope, + const AstExprTable& expr, + const std::vector>& fieldTypes, + std::optional expectedType +) { TableType::Props props; std::optional indexer; @@ -2203,9 +2320,9 @@ TypeId TypeChecker::checkExprTable( if (it != expectedTable->props.end()) { Property expectedProp = it->second; - ErrorVec errors = tryUnify(exprType, expectedProp.type, scope, k->location); + ErrorVec errors = tryUnify(exprType, expectedProp.type(), scope, k->location); if (errors.empty()) - exprType = expectedProp.type; + exprType = expectedProp.type(); } else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { @@ -2241,7 +2358,7 @@ TypeId TypeChecker::checkExprTable( TableState state = TableState::Unsealed; TableType table = TableType{std::move(props), indexer, scope->level, state}; - table.definitionModuleName = currentModuleName; + table.definitionModuleName = currentModule->name; table.definitionLocation = expr.location; return addType(table); } @@ -2299,7 +2416,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (expectedTable) { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) - expectedResultType = prop->second.type; + expectedResultType = prop->second.type(); else if (expectedIndexType && maybeString(*expectedIndexType)) expectedResultType = expectedIndexResultType; } @@ -2311,7 +2428,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (const TableType* ttv = get(follow(expectedOption))) { if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) - expectedResultTypes.push_back(prop->second.type); + expectedResultTypes.push_back(prop->second.type()); else if (ttv->indexer && maybeString(ttv->indexer->indexType)) expectedResultTypes.push_back(ttv->indexer->indexResultType); } @@ -2384,8 +2501,10 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return WithPredicate{retType}; } - reportError(expr.location, - GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); + reportError( + expr.location, + GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())} + ); return WithPredicate{errorRecoveryType(scope)}; } @@ -2452,14 +2571,14 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) return "__mul"; case AstExprBinary::Div: return "__div"; + case AstExprBinary::FloorDiv: + return "__idiv"; case AstExprBinary::Mod: return "__mod"; case AstExprBinary::Pow: return "__pow"; case AstExprBinary::Concat: return "__concat"; - case AstExprBinary::DivInt: - return "__idiv"; case AstExprBinary::MaxOf: return "__max"; case AstExprBinary::MinOf: @@ -2546,28 +2665,47 @@ static std::optional areEqComparable(NotNull arena, NotNull(t); }; if (isExempt(a) || isExempt(b)) return true; + NormalizationResult nr; + TypeId c = arena->addType(IntersectionType{{a, b}}); - const NormalizedType* n = normalizer->normalize(c); + std::shared_ptr n = normalizer->normalize(c); if (!n) return std::nullopt; - if (FFlag::LuauUninhabitedSubAnything2) - return normalizer->isInhabited(n); - else - return isInhabited_DEPRECATED(*n); + nr = normalizer->isInhabited(n.get()); + + switch (nr) + { + case NormalizationResult::HitLimits: + return std::nullopt; + case NormalizationResult::False: + return false; + case NormalizationResult::True: + return true; + } + + // n.b. msvc can never figure this stuff out. + LUAU_UNREACHABLE(); } TypeId TypeChecker::checkRelationalOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates +) { - auto stripNil = [this](TypeId ty, bool isOrOp = false) { + auto stripNil = [this](TypeId ty, bool isOrOp = false) + { ty = follow(ty); if (!isNonstrictMode() && !isOrOp) return ty; @@ -2620,7 +2758,7 @@ TypeId TypeChecker::checkRelationalOperation( if (lhsIsAny || rhsIsAny) return booleanType; - // Fallthrough here is intentional + [[fallthrough]]; } case AstExprBinary::CompareLt: case AstExprBinary::CompareGt: @@ -2631,7 +2769,7 @@ TypeId TypeChecker::checkRelationalOperation( if (get(lhsType) || get(rhsType)) return booleanType; - if (FFlag::LuauIntersectionTestForEquality && isEquality) + if (isEquality) { // Unless either type is free or any, an equality comparison is only // valid when the intersection of the two operands is non-empty. @@ -2648,7 +2786,8 @@ TypeId TypeChecker::checkRelationalOperation( if (!*eqTestResult) { reportError( - expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())}); + expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())} + ); return errorRecoveryType(booleanType); } } @@ -2672,19 +2811,29 @@ TypeId TypeChecker::checkRelationalOperation( { reportErrors(state.errors); - if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + // The original version of this check also produced this error when we had a union type. + // However, the old solver does not readily have the ability to discern if the union is comparable. + // This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type. + // The new solver has much more powerful logic for resolving relational operators, but for now, + // we need to be conservative in the old solver to deliver a reasonable developer experience. + if (!isEquality && state.errors.empty() && isBoolean(leftType)) { - reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), - toString(expr.op).c_str())}); - } + reportError( + expr.location, + GenericError{ + format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) + } + ); + } return booleanType; } std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType), builtinTypes); - std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType), builtinTypes); + std::optional stringNoMT = std::nullopt; // works around gcc false positive "maybe uninitialized" warnings + std::optional leftMetatable = isString(lhsType) ? stringNoMT : getMetatable(follow(lhsType), builtinTypes); + std::optional rightMetatable = isString(rhsType) ? stringNoMT : getMetatable(follow(rhsType), builtinTypes); if (leftMetatable != rightMetatable) { @@ -2722,8 +2871,14 @@ TypeId TypeChecker::checkRelationalOperation( if (!matches) { reportError( - expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + expr.location, + GenericError{format( + "Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), + toString(rhsType).c_str(), + toString(expr.op).c_str() + )} + ); return errorRecoveryType(booleanType); } } @@ -2754,7 +2909,8 @@ TypeId TypeChecker::checkRelationalOperation( TypeId actualFunctionType = addType(FunctionType(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); state.tryUnify( - instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); + instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true + ); state.log.commit(); @@ -2764,7 +2920,8 @@ TypeId TypeChecker::checkRelationalOperation( else if (needsMetamethod) { reportError( - expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); + expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())} + ); return errorRecoveryType(booleanType); } } @@ -2778,8 +2935,12 @@ TypeId TypeChecker::checkRelationalOperation( if (needsMetamethod) { - reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", - toString(lhsType).c_str(), toString(expr.op).c_str())}); + reportError( + expr.location, + GenericError{ + format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str()) + } + ); return errorRecoveryType(booleanType); } @@ -2791,7 +2952,7 @@ TypeId TypeChecker::checkRelationalOperation( { return lhsType; } - else if (FFlag::LuauTryhardAnd) + else { // If lhs is free, we can't tell which 'falsy' components it has, if any if (get(lhsType)) @@ -2802,6 +2963,13 @@ TypeId TypeChecker::checkRelationalOperation( if (notNever) { LUAU_ASSERT(oty); + + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; + return unionOfTypes(*oty, rhsType, scope, expr.location, false); } else @@ -2809,22 +2977,25 @@ TypeId TypeChecker::checkRelationalOperation( return rhsType; } } - else - { - return unionOfTypes(rhsType, booleanType, scope, expr.location, false); - } case AstExprBinary::Or: if (lhsIsAny) { return lhsType; } - else if (FFlag::LuauTryhardAnd) + else { auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy types if (notNever) { LUAU_ASSERT(oty); + + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; + return unionOfTypes(*oty, rhsType, scope, expr.location); } else @@ -2832,10 +3003,6 @@ TypeId TypeChecker::checkRelationalOperation( return rhsType; } } - else - { - return unionOfTypes(lhsType, rhsType, scope, expr.location); - } default: LUAU_ASSERT(0); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); @@ -2843,7 +3010,12 @@ TypeId TypeChecker::checkRelationalOperation( } TypeId TypeChecker::checkBinaryOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates +) { switch (expr.op) { @@ -2894,7 +3066,8 @@ TypeId TypeChecker::checkBinaryOperation( if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) { - auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { + auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId + { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); TypePackId retTypePack = freshTypePack(scope); @@ -2941,8 +3114,15 @@ TypeId TypeChecker::checkBinaryOperation( return checkMetatableCall(*fnt, rhsType, lhsType); } - reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), - toString(lhsType).c_str(), toString(rhsType).c_str())}); + reportError( + expr.location, + GenericError{format( + "Binary operator '%s' not supported by types '%s' and '%s'", + toString(expr.op).c_str(), + toString(lhsType).c_str(), + toString(rhsType).c_str() + )} + ); return errorRecoveryType(scope); } @@ -2957,9 +3137,9 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: - case AstExprBinary::DivInt: case AstExprBinary::MaxOf: case AstExprBinary::MinOf: case AstExprBinary::BinAnd: @@ -3005,22 +3185,13 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { - if (!FFlag::LuauTypecheckTypeguards) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - // For these, passing expectedType is worse than simply forcing them, because their implementation // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); - if (FFlag::LuauTypecheckTypeguards) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; PredicateVec predicates; @@ -3188,22 +3359,13 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) { - return it->second.type; + return it->second.type(); } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free)) + else if ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; - property.type = theType; - property.location = expr.indexLocation; - return theType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) - { - TypeId theType = freshType(scope); - Property& property = lhsTable->props[name]; - property.type = theType; + property.setType(theType); property.location = expr.indexLocation; return theType; } @@ -3236,14 +3398,24 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (const ClassType* lhsClass = get(lhs)) { - const Property* prop = lookupClassProp(lhsClass, name); - if (!prop) + if (const Property* prop = lookupClassProp(lhsClass, name)) { - reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return errorRecoveryType(scope); + return prop->type(); + } + + if (auto indexer = lhsClass->indexer) + { + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(stringType, indexer->indexType); + if (state.errors.empty()) + { + state.log.commit(); + return indexer->indexResultType; + } } - return prop->type; + reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); + return errorRecoveryType(scope); } else if (get(lhs)) { @@ -3285,17 +3457,46 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { if (const ClassType* exprClass = get(exprType)) { - const Property* prop = lookupClassProp(exprClass, value->value.data); - if (!prop) + if (const Property* prop = lookupClassProp(exprClass, value->value.data)) { - reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); + return prop->type(); + } + + if (auto indexer = exprClass->indexer) + { + unify(stringType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; + } + + reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); + return errorRecoveryType(scope); + } + else if (get(exprType)) + { + Name name = std::string(value->value.data, value->value.size); + + if (std::optional ty = getIndexTypeFromType(scope, exprType, name, expr.location, /* addErrors= */ false)) + return *ty; + + // If intersection has a table part, report that it cannot be extended just as a sealed table + if (isTableIntersection(exprType)) + { + reportError(TypeError{expr.location, CannotExtendTable{exprType, CannotExtendTable::Property, name}}); return errorRecoveryType(scope); } - return prop->type; } } - else if (FFlag::LuauAllowIndexClassParameters) + else { + if (const ClassType* exprClass = get(exprType)) + { + if (auto indexer = exprClass->indexer) + { + unify(indexType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; + } + } + if (const ClassType* exprClass = get(exprType)) { if (isNonstrictMode()) @@ -3305,75 +3506,206 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } } - TableType* exprTable = getMutableTableType(exprType); - - if (!exprTable) + if (FFlag::LuauAcceptIndexingTableUnionsIntersections) { - reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return errorRecoveryType(scope); - } + // We're going to have a whole vector. + std::vector tableTypes{}; + bool isUnion = true; - if (value) - { - const auto& it = exprTable->props.find(value->value.data); - if (it != exprTable->props.end()) + // We'd like for normalization eventually to deal with this sort of thing, but as a tactical affordance, we will + // attempt to deal with _one_ level of unions or intersections. + if (auto exprUnion = get(exprType)) { - return it->second.type; + tableTypes.reserve(exprUnion->options.size()); + + for (auto option : exprUnion) + { + TableType* optionTable = getMutableTableType(option); + + if (!optionTable) + { + // TODO: we could do better here and report `option` is not a table as reasoning for the error + reportError(TypeError{expr.expr->location, NotATable{exprType}}); + return errorRecoveryType(scope); + } + + tableTypes.push_back(optionTable); + } } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) + else if (auto exprIntersection = get(exprType)) { - TypeId resultType = freshType(scope); - Property& property = exprTable->props[value->value.data]; - property.type = resultType; - property.location = expr.index->location; - return resultType; + tableTypes.reserve(exprIntersection->parts.size()); + isUnion = false; + + for (auto part : exprIntersection) + { + TableType* partTable = getMutableTableType(part); + + if (!partTable) + { + // TODO: we could do better here and report `part` is not a table as reasoning for the error + reportError(TypeError{expr.expr->location, NotATable{exprType}}); + return errorRecoveryType(scope); + } + + tableTypes.push_back(partTable); + } } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if (auto exprTable = getMutableTableType(exprType)) { - TypeId resultType = freshType(scope); - Property& property = exprTable->props[value->value.data]; - property.type = resultType; - property.location = expr.index->location; - return resultType; + tableTypes.push_back(exprTable); + } + else + { + reportError(TypeError{expr.expr->location, NotATable{exprType}}); + return errorRecoveryType(scope); } - } - if (exprTable->indexer) - { - const TableIndexer& indexer = *exprTable->indexer; - unify(indexType, indexer.indexType, scope, expr.index->location); - return indexer.indexResultType; - } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) - { - TypeId resultType = freshType(exprTable->level); - exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; - return resultType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) - { - TypeId indexerType = freshType(exprTable->level); - unify(indexType, indexerType, scope, expr.location); - TypeId indexResultType = freshType(exprTable->level); + if (value) + { + DenseHashSet propTypes{{}}; + + for (auto table : tableTypes) + { + const auto& it = table->props.find(value->value.data); + if (it != table->props.end()) + { + propTypes.insert(it->second.type()); + } + else if ((ctx == ValueContext::LValue && table->state == TableState::Unsealed) || table->state == TableState::Free) + { + TypeId resultType = freshType(scope); + Property& property = table->props[value->value.data]; + property.setType(resultType); + property.location = expr.index->location; + propTypes.insert(resultType); + } + } + + if (propTypes.size() == 1) + return *propTypes.begin(); + + if (!propTypes.empty()) + { + if (isUnion) + { + std::vector options = reduceUnion({propTypes.begin(), propTypes.end()}); + + if (options.empty()) + return neverType; + + if (options.size() == 1) + return options[0]; - exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)}; - return indexResultType; + return addType(UnionType{options}); + } + + return addType(IntersectionType{{propTypes.begin(), propTypes.end()}}); + } + } + + DenseHashSet resultTypes{{}}; + + for (auto table : tableTypes) + { + if (table->indexer) + { + const TableIndexer& indexer = *table->indexer; + unify(indexType, indexer.indexType, scope, expr.index->location); + resultTypes.insert(indexer.indexResultType); + } + else if ((ctx == ValueContext::LValue && table->state == TableState::Unsealed) || table->state == TableState::Free) + { + TypeId indexerType = freshType(table->level); + unify(indexType, indexerType, scope, expr.location); + TypeId indexResultType = freshType(table->level); + + table->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)}; + resultTypes.insert(indexResultType); + } + else + { + /* + * If we use [] indexing to fetch a property from a sealed table that + * has no indexer, we have no idea if it will work so we just return any + * and hope for the best. + */ + + // if this is a union, it's going to be equivalent to `any` no matter what at this point, so we'll just call it done. + if (isUnion) + return anyType; + + resultTypes.insert(anyType); + } + } + + if (resultTypes.size() == 1) + return *resultTypes.begin(); + + if (isUnion) + { + std::vector options = reduceUnion({resultTypes.begin(), resultTypes.end()}); + + if (options.empty()) + return neverType; + + if (options.size() == 1) + return options[0]; + + return addType(UnionType{options}); + } + + return addType(IntersectionType{{resultTypes.begin(), resultTypes.end()}}); } else { - /* - * If we use [] indexing to fetch a property from a sealed table that - * has no indexer, we have no idea if it will work so we just return any - * and hope for the best. - */ - if (FFlag::LuauDontExtendUnsealedRValueTables) - return anyType; + TableType* exprTable = getMutableTableType(exprType); + if (!exprTable) + { + reportError(TypeError{expr.expr->location, NotATable{exprType}}); + return errorRecoveryType(scope); + } + + if (value) + { + const auto& it = exprTable->props.find(value->value.data); + if (it != exprTable->props.end()) + { + return it->second.type(); + } + else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) + { + TypeId resultType = freshType(scope); + Property& property = exprTable->props[value->value.data]; + property.setType(resultType); + property.location = expr.index->location; + return resultType; + } + } + + if (exprTable->indexer) + { + const TableIndexer& indexer = *exprTable->indexer; + unify(indexType, indexer.indexType, scope, expr.index->location); + return indexer.indexResultType; + } + else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) + { + TypeId indexerType = freshType(exprTable->level); + unify(indexType, indexerType, scope, expr.location); + TypeId indexResultType = freshType(exprTable->level); + + exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)}; + return indexResultType; + } else { - TypeId resultType = freshType(scope); - return resultType; + /* + * If we use [] indexing to fetch a property from a sealed table that + * has no indexer, we have no idea if it will work so we just return any + * and hope for the best. + */ + return anyType; } } } @@ -3382,25 +3714,26 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex // Primarily about detecting duplicates. TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { - auto freshTy = [&]() { + auto freshTy = [&]() + { return freshType(level); }; if (auto globalName = funName.as()) { - const ScopePtr& globalScope = currentModule->getModuleScope(); + const ScopePtr& moduleScope = currentModule->getModuleScope(); Symbol name = globalName->name; - if (globalScope->bindings.count(name)) + if (moduleScope->bindings.count(name)) { if (isNonstrictMode()) - return globalScope->bindings[name].typeId; + return moduleScope->bindings[name].typeId; return errorRecoveryType(scope); } else { TypeId ty = freshTy(); - globalScope->bindings[name] = {ty, funName.location}; + moduleScope->bindings[name] = {ty, funName.location}; return ty; } } @@ -3430,13 +3763,12 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T /* hgy29@gideros: this leads to exception when reparsing the same file for update XXX(fixed) */ if (ttv->props.count(name)) - return ttv->props[name].type; + return ttv->props[name].type(); Property& property = ttv->props[name]; - - property.type = freshTy(); + property.setType(freshTy()); property.location = indexName->indexLocation; - return property.type; + return property.type(); } else if (funName.is()) return errorRecoveryType(scope); @@ -3457,8 +3789,14 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` // to get type `(X) -> X`, then we quantify the free types to get the final // generic type `(a) -> a`. -std::pair TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalName, std::optional selfType, std::optional expectedType) +std::pair TypeChecker::checkFunctionSignature( + const ScopePtr& scope, + int subLevel, + const AstExprFunction& expr, + std::optional originalName, + std::optional selfType, + std::optional expectedType +) { ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); @@ -3551,25 +3889,11 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& funScope->returnType = retPack; - if (FFlag::DebugLuauSharedSelf) - { - if (expr.self) - { - // TODO: generic self types: CLI-39906 - TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope)); - funScope->bindings[expr.self] = {selfTy, expr.self->location}; - argTypes.push_back(selfTy); - } - } - else + if (expr.self) { - if (expr.self) - { - // TODO: generic self types: CLI-39906 - TypeId selfType = anyIfNonstrict(freshType(funScope)); - funScope->bindings[expr.self] = {selfType, expr.self->location}; - argTypes.push_back(selfType); - } + TypeId selfType = anyIfNonstrict(freshType(funScope)); + funScope->bindings[expr.self] = {selfType, expr.self->location}; + argTypes.push_back(selfType); } // Prepare expected argument type iterators if we have an expected function type @@ -3622,7 +3946,7 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& TypePackId argPack = addTypePack(TypePackVar(TypePack{argTypes, funScope->varargPack})); FunctionDefinition defn; - defn.definitionModuleName = currentModuleName; + defn.definitionModuleName = currentModule->name; defn.definitionLocation = expr.location; defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); @@ -3758,8 +4082,14 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope } } -void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId argPack, TypePackId paramPack, - const std::vector& argLocations) +void TypeChecker::checkArgumentList( + const ScopePtr& scope, + const AstExpr& funName, + Unifier& state, + TypePackId argPack, + TypePackId paramPack, + const std::vector& argLocations +) { /* Important terminology refresher: * A function requires parameters. @@ -3771,7 +4101,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam size_t paramIndex = 0; - auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() { + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() + { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; if (!argLocations.empty()) @@ -3783,8 +4114,10 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam namePath = *path; auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); - state.reportError(TypeError{location, - CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); + state.reportError(TypeError{ + location, + CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath} + }); }; while (true) @@ -3801,7 +4134,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argTail) { - if (state.log.getMutable(state.log.follow(*argTail))) + if (state.log.getMutable(state.log.follow(*argTail))) { if (paramTail) state.tryUnify(*paramTail, *argTail); @@ -3816,7 +4149,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam else if (paramTail) { // argTail is definitely empty - if (state.log.getMutable(state.log.follow(*paramTail))) + if (state.log.getMutable(state.log.follow(*paramTail))) state.log.replace(*paramTail, TypePackVar(TypePack{{}})); } @@ -3891,7 +4224,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam namePath = *path; state.reportError(TypeError{ - funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath} + }); return; } ++paramIter; @@ -3931,7 +4265,9 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argIndex < argLocations.size()) location = argLocations[argIndex]; - unify(*argIter, vtp->ty, scope, location); + state.location = location; + state.tryUnify(*argIter, vtp->ty); + ++argIter; ++argIndex; } @@ -4033,7 +4369,8 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope // We break this function up into a lambda here to limit our stack footprint. // The vectors used by this function aren't allocated until the lambda is actually called. - auto the_rest = [&]() -> WithPredicate { + auto the_rest = [&]() -> WithPredicate + { // checkExpr will log the pre-instantiated type of the function. // That's not nearly as interesting as the instantiated type, which will include details about how // generic functions are being instantiated for this particular callsite. @@ -4076,7 +4413,8 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope fn = follow(fn); if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors + )) return *ret; } @@ -4103,7 +4441,8 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st { std::vector> expectedTypes; - auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) + { if (index == expectedTypes.size()) { expectedTypes.push_back(ty); @@ -4162,9 +4501,19 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st * If this was an optional, callers would have to pay the stack cost for the result. This is problematic * for functions that need to support recursion up to 600 levels deep. */ -std::unique_ptr> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, - TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) +std::unique_ptr> TypeChecker::checkCallOverload( + const ScopePtr& scope, + const AstExprCall& expr, + TypeId fn, + TypePackId retPack, + TypePackId argPack, + TypePack* args, + const std::vector* argLocations, + const WithPredicate& argListResult, + std::vector& overloadsThatMatchArgCount, + std::vector& overloadsThatDont, + std::vector& errors +) { LUAU_ASSERT(argLocations); @@ -4278,7 +4627,12 @@ std::unique_ptr> TypeChecker::checkCallOverload(const else overloadsThatDont.push_back(fn); - errors.emplace_back(std::move(state.errors), args->head, ftv); + errors.push_back(OverloadErrorEntry{ + std::move(state.log), + std::move(state.errors), + args->head, + ftv, + }); } else { @@ -4293,12 +4647,17 @@ std::unique_ptr> TypeChecker::checkCallOverload(const return nullptr; } -bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, - const std::vector& errors) +bool TypeChecker::handleSelfCallMismatch( + const ScopePtr& scope, + const AstExprCall& expr, + TypePack* args, + const std::vector& argLocations, + const std::vector& errors +) { // No overloads succeeded: Scan for one that would have worked had the user // used a.b() rather than a:b() or vice versa. - for (const auto& [_, argVec, ftv] : errors) + for (const auto& e : errors) { // Did you write foo:bar() when you should have written foo.bar()? if (expr.self) @@ -4309,7 +4668,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4324,7 +4683,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return true; } } - else if (ftv->hasSelf) + else if (e.fnTy->hasSelf) { // Did you write foo.bar() when you should have written foo:bar()? if (AstExprIndexName* indexName = expr.func->as()) @@ -4340,7 +4699,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4361,13 +4720,22 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, - const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, - const std::vector& errors) +void TypeChecker::reportOverloadResolutionError( + const ScopePtr& scope, + const AstExprCall& expr, + TypePackId retPack, + TypePackId argPack, + const std::vector& argLocations, + const std::vector& overloads, + const std::vector& overloadsThatMatchArgCount, + std::vector& errors +) { if (overloads.size() == 1) { - reportErrors(std::get<0>(errors.front())); + errors.front().log.commit(); + + reportErrors(errors.front().errors); return; } @@ -4388,12 +4756,20 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast const FunctionType* ftv = get(overload); - auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) { - return ftv == std::get<2>(e); - }); + auto error = std::find_if( + errors.begin(), + errors.end(), + [ftv](const OverloadErrorEntry& e) + { + return ftv == e.fnTy; + } + ); LUAU_ASSERT(error != errors.end()); - reportErrors(std::get<0>(*error)); + + error->log.commit(); + + reportErrors(error->errors); // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) @@ -4434,14 +4810,21 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast return; } -WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, - bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) +WithPredicate TypeChecker::checkExprList( + const ScopePtr& scope, + const Location& location, + const AstArray& exprs, + bool substituteFreeForNil, + const std::vector& instantiateGenerics, + const std::vector>& expectedTypes +) { bool uninhabitable = false; TypePackId pack = addTypePack(TypePack{}); PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? - auto insert = [&predicates](PredicateVec& vec) { + auto insert = [&predicates](PredicateVec& vec) + { for (Predicate& c : vec) predicates.push_back(std::move(c)); }; @@ -4565,11 +4948,9 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module } // Types of requires that transitively refer to current module have to be replaced with 'any' - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == humanReadableName) + if (!path.empty() && path.front() == moduleInfo.name) return anyType; } @@ -4580,14 +4961,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) - reportError(TypeError{location, UnknownRequire{humanReadableName}}); + reportError(TypeError{location, UnknownRequire{resolver->getHumanReadableModuleName(moduleInfo.name)}}); return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + reportError(location, IllegalRequire{module->humanReadableName, "Module is not a ModuleScript. It cannot be required."}); return errorRecoveryType(scope); } @@ -4599,7 +4980,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module std::optional moduleType = first(modulePack); if (!moduleType) { - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + reportError(location, IllegalRequire{module->humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); return errorRecoveryType(scope); } @@ -4682,17 +5063,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c { // First try unifying with the original uninstantiated type // but if that fails, try the instantiated one. - Unifier child = state.makeChildUnifier(); - child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); - if (!child.errors.empty()) + std::unique_ptr child = state.makeChildUnifier(); + child->tryUnify(subTy, superTy, /*isFunctionCall*/ false); + if (!child->errors.empty()) { - TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); + TypeId instantiated = instantiate(scope, subTy, state.location, &child->log); if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - state.log.concat(std::move(child.log)); + state.log.concat(std::move(child->log)); - state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); + state.errors.insert(state.errors.end(), child->errors.begin(), child->errors.end()); } else { @@ -4701,7 +5082,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c } else { - state.log.concat(std::move(child.log)); + state.log.concat(std::move(child->log)); } } } @@ -4710,20 +5091,10 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location { ty = follow(ty); - if (FFlag::DebugLuauSharedSelf) - { - if (auto ftv = get(ty)) - Luau::quantify(ty, scope->level); - else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) - Luau::quantify(ty, scope->level); - } - else - { - const FunctionType* ftv = get(ty); + const FunctionType* ftv = get(ty); - if (ftv) - Luau::quantify(ty, scope->level); - } + if (ftv) + Luau::quantify(ty, scope->level); return ty; } @@ -4733,15 +5104,18 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat ty = follow(ty); const FunctionType* ftv = get(ty); - if (ftv && ftv->hasNoGenerics) + if (ftv && ftv->hasNoFreeOrGenericTypes) return ty; - Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; + std::optional instantiated; + + reusableInstantiation.resetState(log, ¤tModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr); if (instantiationChildLimit) - instantiation.childLimit = *instantiationChildLimit; + reusableInstantiation.childLimit = *instantiationChildLimit; + + instantiated = reusableInstantiation.substitute(ty); - std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; else @@ -4814,7 +5188,7 @@ void TypeChecker::reportError(const TypeError& error) if (currentModule->mode == Mode::NoCheck) return; currentModule->errors.push_back(error); - currentModule->errors.back().moduleName = currentModuleName; + currentModule->errors.back().moduleName = currentModule->name; } void TypeChecker::reportError(const Location& location, TypeErrorData errorData) @@ -4828,24 +5202,40 @@ void TypeChecker::reportErrors(const ErrorVec& errors) reportError(err); } -void TypeChecker::ice(const std::string& message, const Location& location) +LUAU_NOINLINE void TypeChecker::ice(const std::string& message, const Location& location) { iceHandler->ice(message, location); } -void TypeChecker::ice(const std::string& message) +LUAU_NOINLINE void TypeChecker::ice(const std::string& message) { iceHandler->ice(message); } +LUAU_NOINLINE void TypeChecker::throwTimeLimitError() +{ + throw TimeLimitError(iceHandler->moduleName); +} + +LUAU_NOINLINE void TypeChecker::throwUserCancelError() +{ + throw UserCancelError(iceHandler->moduleName); +} + void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) { // Remove errors with names that were generated by recovery from a parse error - errVec.erase(std::remove_if(errVec.begin(), errVec.end(), - [](auto& err) { - return containsParseErrorName(err); - }), - errVec.end()); + errVec.erase( + std::remove_if( + errVec.begin(), + errVec.end(), + [](auto& err) + { + return containsParseErrorName(err); + } + ), + errVec.end() + ); for (auto& err : errVec) { @@ -4859,7 +5249,8 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d std::string_view sv(utk->key); std::set candidates; - auto accumulate = [&](const TableType::Props& props) { + auto accumulate = [&](const TableType::Props& props) + { for (const auto& [name, ty] : props) { if (sv != name && equalsLower(sv, name)) @@ -4913,30 +5304,35 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) { - Luau::merge(l, r, [this](TypeId a, TypeId b) { - // TODO: normalize(UnionType{{a, b}}) - std::unordered_set set; + Luau::merge( + l, + r, + [this](TypeId a, TypeId b) + { + // TODO: normalize(UnionType{{a, b}}) + std::unordered_set set; - if (auto utv = get(follow(a))) - set.insert(begin(utv), end(utv)); - else - set.insert(a); + if (auto utv = get(follow(a))) + set.insert(begin(utv), end(utv)); + else + set.insert(a); - if (auto utv = get(follow(b))) - set.insert(begin(utv), end(utv)); - else - set.insert(b); + if (auto utv = get(follow(b))) + set.insert(begin(utv), end(utv)); + else + set.insert(b); - std::vector options(set.begin(), set.end()); - if (set.size() == 1) - return options[0]; - return addType(UnionType{std::move(options)}); - }); + std::vector options(set.begin(), set.end()); + if (set.size() == 1) + return options[0]; + return addType(UnionType{std::move(options)}); + } + ); } Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{NotNull{&normalizer}, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant}; + return Unifier{NotNull{&normalizer}, NotNull{scope.get()}, location, Variance::Covariant}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -4982,7 +5378,8 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) { - return [this, sense, emptySetTy](TypeId ty) -> std::optional { + return [this, sense, emptySetTy](TypeId ty) -> std::optional + { // any/error/free gets a special pass unconditionally because they can't be decided. if (get(ty) || get(ty) || get(ty)) return ty; @@ -5124,12 +5521,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno return tf->type; bool parameterCountErrorReported = false; - bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); - bool hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); + bool hasDefaultTypes = std::any_of( + tf->typeParams.begin(), + tf->typeParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); + bool hasDefaultPacks = std::any_of( + tf->typePackParams.begin(), + tf->typePackParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); if (!lit->hasParameterList) { @@ -5252,7 +5659,8 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno { if (!parameterCountErrorReported) reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}} + ); // Pad the types out with error recovery types while (typeParams.size() < tf->typeParams.size()) @@ -5261,13 +5669,26 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno typePackParams.push_back(errorRecoveryTypePack(scope)); } - bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { - return itp == tp.ty; - }); + bool sameTys = std::equal( + typeParams.begin(), + typeParams.end(), + tf->typeParams.begin(), + tf->typeParams.end(), + [](auto&& itp, auto&& tp) + { + return itp == tp.ty; + } + ); bool sameTps = std::equal( - typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + typePackParams.begin(), + typePackParams.end(), + tf->typePackParams.begin(), + tf->typePackParams.end(), + [](auto&& itpp, auto&& tpp) + { return itpp == tpp.tp; - }); + } + ); // If the generic parameters and the type arguments are the same, we are about to // perform an identity substitution, which we can just short-circuit. @@ -5282,13 +5703,31 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno std::optional tableIndexer; for (const auto& prop : table->props) - props[prop.name.value] = {resolveType(scope, *prop.type)}; + { + if (prop.access == AstTableAccess::Read) + reportError(prop.accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); + else if (prop.access == AstTableAccess::Write) + reportError(prop.accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); + else if (prop.access == AstTableAccess::ReadWrite) + props[prop.name.value] = {resolveType(scope, *prop.type), /* deprecated: */ false, {}, std::nullopt, {}, std::nullopt, prop.location}; + else + ice("Unexpected property access " + std::to_string(int(prop.access))); + } if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + { + if (indexer->access == AstTableAccess::Read) + reportError(indexer->accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); + else if (indexer->access == AstTableAccess::Write) + reportError(indexer->accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); + else if (indexer->access == AstTableAccess::ReadWrite) + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + else + ice("Unexpected property access " + std::to_string(int(indexer->access))); + } TableType ttv{props, tableIndexer, scope->level, TableState::Sealed}; - ttv.definitionModuleName = currentModuleName; + ttv.definitionModuleName = currentModule->name; ttv.definitionLocation = annotation.location; return addType(std::move(ttv)); } @@ -5304,15 +5743,27 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno std::vector genericTys; genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + std::transform( + generics.begin(), + generics.end(), + std::back_inserter(genericTys), + [](auto&& el) + { + return el.ty; + } + ); std::vector genericTps; genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + std::transform( + genericPacks.begin(), + genericPacks.end(), + std::back_inserter(genericTps), + [](auto&& el) + { + return el.tp; + } + ); TypeId fnType = addType(FunctionType{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); @@ -5433,8 +5884,13 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack return result; } -TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, - const std::vector& typePackParams, const Location& location) +TypeId TypeChecker::instantiateTypeFun( + const ScopePtr& scope, + const TypeFun& tf, + const std::vector& typeParams, + const std::vector& typePackParams, + const Location& location +) { if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; @@ -5462,7 +5918,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; TypeId target = follow(instantiated); - bool needsClone = follow(tf.type) == target; + const TableType* tfTable = getTableType(tf.type); + bool needsClone = follow(tf.type) == target || (tfTable != nullptr && tfTable == getTableType(target)); bool shouldMutate = getTableType(tf.type); TableType* ttv = getMutableTableType(target); @@ -5490,15 +5947,21 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; - ttv->definitionModuleName = currentModuleName; + ttv->definitionModuleName = currentModule->name; ttv->definitionLocation = location; } return instantiated; } -GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames, bool useCache) +GenericTypeDefinitions TypeChecker::createGenericTypes( + const ScopePtr& scope, + std::optional levelOpt, + const AstNode& node, + const AstArray& genericNames, + const AstArray& genericPackNames, + bool useCache +) { LUAU_ASSERT(scope->parent); @@ -5533,7 +5996,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } else { - g = addType(Unifiable::Generic{level, n}); + g = addType(GenericType{level, n}); } generics.push_back({g, defaultValue}); @@ -5561,7 +6024,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); + cached = addTypePack(TypePackVar{GenericTypePack{level, n}}); genericPacks.push_back({cached, defaultValue}); scope->privateTypePackBindings[n] = cached; @@ -5626,7 +6089,8 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const } } - auto intoType = [this](const std::unordered_set& s) -> std::optional { + auto intoType = [this](const std::unordered_set& s) -> std::optional + { if (s.empty()) return std::nullopt; @@ -5813,7 +6277,8 @@ void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const Sc void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense) { - auto predicate = [&](TypeId option) -> std::optional { + auto predicate = [&](TypeId option) -> std::optional + { // This by itself is not truly enough to determine that A is stronger than B or vice versa. bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty(); bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty(); @@ -5876,8 +6341,10 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r return; } - auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) { - TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional { + auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) + { + TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional + { if (sense && get(ty)) return mapsTo.value_or(ty); @@ -5904,24 +6371,35 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r return refine(isBoolean, booleanType); else if (typeguardP.kind == "thread") return refine(isThread, threadType); + else if (typeguardP.kind == "buffer") + return refine(isBuffer, bufferType); else if (typeguardP.kind == "table") { - return refine([](TypeId ty) -> bool { - return isTableIntersection(ty) || get(ty) || get(ty); - }); + return refine( + [](TypeId ty) -> bool + { + return isTableIntersection(ty) || get(ty) || get(ty); + } + ); } else if (typeguardP.kind == "function") { - return refine([](TypeId ty) -> bool { - return isOverloadedFunction(ty) || get(ty); - }); + return refine( + [](TypeId ty) -> bool + { + return isOverloadedFunction(ty) || get(ty); + } + ); } else if (typeguardP.kind == "userdata") { // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. - return refine([](TypeId ty) -> bool { - return get(ty); - }); + return refine( + [](TypeId ty) -> bool + { + return get(ty); + } + ); } if (!typeguardP.isTypeof) @@ -5934,17 +6412,13 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r TypeId type = follow(typeFun->type); // You cannot refine to the top class type. - if (FFlag::LuauNegatedClassTypes) + if (type == builtinTypes->classType) { - if (type == builtinTypes->classType) - { - return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - } + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); } // We're only interested in the root class of any classes. - if (auto ctv = get(type); - !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent != builtinTypes->classType) : (ctv->parent != std::nullopt))) + if (auto ctv = get(type); !ctv || (ctv->parent != builtinTypes->classType && !hasTag(type, kTypeofRootTag))) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. @@ -5955,7 +6429,8 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. - auto options = [](TypeId ty) -> std::vector { + auto options = [](TypeId ty) -> std::vector + { if (auto utv = get(follow(ty))) return std::vector(begin(utv), end(utv)); return {ty}; @@ -5966,7 +6441,8 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - auto predicate = [&](TypeId option) -> std::optional { + auto predicate = [&](TypeId option) -> std::optional + { if (!sense && isNil(eqP.type)) return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; diff --git a/Analysis/src/TypeOrPack.cpp b/Analysis/src/TypeOrPack.cpp new file mode 100644 index 000000000..86652141d --- /dev/null +++ b/Analysis/src/TypeOrPack.cpp @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeOrPack.h" +#include "Luau/Common.h" + +namespace Luau +{ + +const void* ptr(TypeOrPack tyOrTp) +{ + if (auto ty = get(tyOrTp)) + return static_cast(*ty); + else if (auto tp = get(tyOrTp)) + return static_cast(*tp); + else + LUAU_UNREACHABLE(); +} + +TypeOrPack follow(TypeOrPack tyOrTp) +{ + if (auto ty = get(tyOrTp)) + return follow(*ty); + else if (auto tp = get(tyOrTp)) + return follow(*tp); + else + LUAU_UNREACHABLE(); +} + +} // namespace Luau diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index ccea604ff..7e11d462e 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -6,9 +6,74 @@ #include +LUAU_FASTFLAG(LuauSolverV2); + namespace Luau { +FreeTypePack::FreeTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedTypePack::BlockedTypePack() : index(++nextIndex) { @@ -160,8 +225,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId rhsTail = *rhsIter.tail(); { - const Unifiable::Free* lf = get_if(&lhsTail->ty); - const Unifiable::Free* rf = get_if(&rhsTail->ty); + const FreeTypePack* lf = get_if(&lhsTail->ty); + const FreeTypePack* rf = get_if(&rhsTail->ty); if (lf && rf) return lf->index == rf->index; } @@ -174,8 +239,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) } { - const Unifiable::Generic* lg = get_if(&lhsTail->ty); - const Unifiable::Generic* rg = get_if(&rhsTail->ty); + const GenericTypePack* lg = get_if(&lhsTail->ty); + const GenericTypePack* rg = get_if(&rhsTail->ty); if (lg && rg) return lg->index == rg->index; } @@ -192,16 +257,26 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - return follow(tp, [](TypePackId t) { - return t; - }); + return follow( + tp, + nullptr, + [](const void*, TypePackId t) + { + return t; + } + ); } -TypePackId follow(TypePackId tp, std::function mapper) +TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId)) { - auto advance = [&mapper](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(mapper(ty))) + auto advance = [context, mapper](TypePackId ty) -> std::optional + { + TypePackId mapped = mapper(context, ty); + + if (const Unifiable::Bound* btv = get>(mapped)) return btv->boundTo; + else if (const TypePack* tp = get(mapped); tp && tp->head.empty()) + return tp->tail; else return std::nullopt; }; @@ -212,6 +287,9 @@ TypePackId follow(TypePackId tp, std::function mapper) else return tp; + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + while (true) { auto a1 = advance(tp); @@ -393,4 +471,11 @@ bool containsNever(TypePackId tp) return false; } +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceTypePack(TypePackVar* ty, TypePackId& tyArg) +{ + LUAU_ASSERT(ty != follow(tyArg)); + return &ty->ty.emplace(tyArg); +} + } // namespace Luau diff --git a/Analysis/src/TypePath.cpp b/Analysis/src/TypePath.cpp new file mode 100644 index 000000000..855ac3034 --- /dev/null +++ b/Analysis/src/TypePath.cpp @@ -0,0 +1,722 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypePath.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/Type.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePack.h" +#include "Luau/TypeOrPack.h" + +#include +#include +#include +#include + +LUAU_FASTFLAG(LuauSolverV2); + +// Maximum number of steps to follow when traversing a path. May not always +// equate to the number of components in a path, depending on the traversal +// logic. +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypePathMaximumTraverseSteps, 100); + +namespace Luau +{ + +namespace TypePath +{ + +Property::Property(std::string name) + : name(std::move(name)) +{ + LUAU_ASSERT(!FFlag::LuauSolverV2); +} + +Property Property::read(std::string name) +{ + return Property(std::move(name), true); +} + +Property Property::write(std::string name) +{ + return Property(std::move(name), false); +} + +bool Property::operator==(const Property& other) const +{ + return name == other.name && isRead == other.isRead; +} + +bool Index::operator==(const Index& other) const +{ + return index == other.index; +} + +bool Reduction::operator==(const Reduction& other) const +{ + return resultType == other.resultType; +} + +Path Path::append(const Path& suffix) const +{ + std::vector joined(components); + joined.reserve(suffix.components.size()); + joined.insert(joined.end(), suffix.components.begin(), suffix.components.end()); + return Path(std::move(joined)); +} + +Path Path::push(Component component) const +{ + std::vector joined(components); + joined.push_back(component); + return Path(std::move(joined)); +} + +Path Path::push_front(Component component) const +{ + std::vector joined{}; + joined.reserve(components.size() + 1); + joined.push_back(std::move(component)); + joined.insert(joined.end(), components.begin(), components.end()); + return Path(std::move(joined)); +} + +Path Path::pop() const +{ + if (empty()) + return kEmpty; + + std::vector popped(components); + popped.pop_back(); + return Path(std::move(popped)); +} + +std::optional Path::last() const +{ + if (empty()) + return std::nullopt; + + return components.back(); +} + +bool Path::empty() const +{ + return components.empty(); +} + +bool Path::operator==(const Path& other) const +{ + return components == other.components; +} + +size_t PathHash::operator()(const Property& prop) const +{ + return std::hash()(prop.name) ^ static_cast(prop.isRead); +} + +size_t PathHash::operator()(const Index& idx) const +{ + return idx.index; +} + +size_t PathHash::operator()(const TypeField& field) const +{ + return static_cast(field); +} + +size_t PathHash::operator()(const PackField& field) const +{ + return static_cast(field); +} + +size_t PathHash::operator()(const Reduction& reduction) const +{ + return std::hash()(reduction.resultType); +} + +size_t PathHash::operator()(const Component& component) const +{ + return visit(*this, component); +} + +size_t PathHash::operator()(const Path& path) const +{ + size_t hash = 0; + + for (const Component& component : path.components) + hash ^= (*this)(component); + + return hash; +} + +Path PathBuilder::build() +{ + return Path(std::move(components)); +} + +PathBuilder& PathBuilder::readProp(std::string name) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + components.push_back(Property{std::move(name), true}); + return *this; +} + +PathBuilder& PathBuilder::writeProp(std::string name) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + components.push_back(Property{std::move(name), false}); + return *this; +} + +PathBuilder& PathBuilder::prop(std::string name) +{ + LUAU_ASSERT(!FFlag::LuauSolverV2); + components.push_back(Property{std::move(name)}); + return *this; +} + +PathBuilder& PathBuilder::index(size_t i) +{ + components.push_back(Index{i}); + return *this; +} + +PathBuilder& PathBuilder::mt() +{ + components.push_back(TypeField::Metatable); + return *this; +} + +PathBuilder& PathBuilder::lb() +{ + components.push_back(TypeField::LowerBound); + return *this; +} + +PathBuilder& PathBuilder::ub() +{ + components.push_back(TypeField::UpperBound); + return *this; +} + +PathBuilder& PathBuilder::indexKey() +{ + components.push_back(TypeField::IndexLookup); + return *this; +} + +PathBuilder& PathBuilder::indexValue() +{ + components.push_back(TypeField::IndexResult); + return *this; +} + +PathBuilder& PathBuilder::negated() +{ + components.push_back(TypeField::Negated); + return *this; +} + +PathBuilder& PathBuilder::variadic() +{ + components.push_back(TypeField::Variadic); + return *this; +} + +PathBuilder& PathBuilder::args() +{ + components.push_back(PackField::Arguments); + return *this; +} + +PathBuilder& PathBuilder::rets() +{ + components.push_back(PackField::Returns); + return *this; +} + +PathBuilder& PathBuilder::tail() +{ + components.push_back(PackField::Tail); + return *this; +} + +} // namespace TypePath + +namespace +{ + +struct TraversalState +{ + TraversalState(TypeId root, NotNull builtinTypes) + : current(root) + , builtinTypes(builtinTypes) + { + } + TraversalState(TypePackId root, NotNull builtinTypes) + : current(root) + , builtinTypes(builtinTypes) + { + } + + TypeOrPack current; + NotNull builtinTypes; + int steps = 0; + + void updateCurrent(TypeId ty) + { + LUAU_ASSERT(ty); + current = follow(ty); + } + + void updateCurrent(TypePackId tp) + { + LUAU_ASSERT(tp); + current = follow(tp); + } + + bool tooLong() + { + return ++steps > DFInt::LuauTypePathMaximumTraverseSteps; + } + + bool checkInvariants() + { + return tooLong(); + } + + bool traverse(const TypePath::Property& property) + { + auto currentType = get(current); + if (!currentType) + return false; + + if (checkInvariants()) + return false; + + const Property* prop = nullptr; + + if (auto t = get(*currentType)) + { + auto it = t->props.find(property.name); + if (it != t->props.end()) + { + prop = &it->second; + } + } + else if (auto c = get(*currentType)) + { + prop = lookupClassProp(c, property.name); + } + // For a metatable type, the table takes priority; check that before + // falling through to the metatable entry below. + else if (auto m = get(*currentType)) + { + TypeOrPack pinned = current; + updateCurrent(m->table); + + if (traverse(property)) + return true; + + // Restore the old current type if we didn't traverse the metatable + // successfully; we'll use the next branch to address this. + current = pinned; + } + + if (!prop) + { + if (auto m = getMetatable(*currentType, builtinTypes)) + { + // Weird: rather than use findMetatableEntry, which requires a lot + // of stuff that we don't have and don't want to pull in, we use the + // path traversal logic to grab __index and then re-enter the lookup + // logic there. + updateCurrent(*m); + + if (!traverse(TypePath::Property::read("__index"))) + return false; + + return traverse(property); + } + } + + if (prop) + { + std::optional maybeType; + if (FFlag::LuauSolverV2) + maybeType = property.isRead ? prop->readTy : prop->writeTy; + else + maybeType = prop->type(); + + if (maybeType) + { + updateCurrent(*maybeType); + return true; + } + } + + return false; + } + + bool traverse(const TypePath::Index& index) + { + if (checkInvariants()) + return false; + + if (auto currentType = get(current)) + { + if (auto u = get(*currentType)) + { + auto it = begin(u); + std::advance(it, index.index); + if (it != end(u)) + { + updateCurrent(*it); + return true; + } + } + else if (auto i = get(*currentType)) + { + auto it = begin(i); + std::advance(it, index.index); + if (it != end(i)) + { + updateCurrent(*it); + return true; + } + } + } + else + { + auto currentPack = get(current); + LUAU_ASSERT(currentPack); + if (get(*currentPack)) + { + auto it = begin(*currentPack); + + for (size_t i = 0; i < index.index && it != end(*currentPack); ++i) + ++it; + + if (it != end(*currentPack)) + { + updateCurrent(*it); + return true; + } + } + } + + return false; + } + + bool traverse(TypePath::TypeField field) + { + if (checkInvariants()) + return false; + + switch (field) + { + case TypePath::TypeField::Table: + if (auto mt = get(current)) + { + updateCurrent(mt->table); + return true; + } + + return false; + case TypePath::TypeField::Metatable: + if (auto currentType = get(current)) + { + if (std::optional mt = getMetatable(*currentType, builtinTypes)) + { + updateCurrent(*mt); + return true; + } + } + + return false; + case TypePath::TypeField::LowerBound: + case TypePath::TypeField::UpperBound: + if (auto ft = get(current)) + { + updateCurrent(field == TypePath::TypeField::LowerBound ? ft->lowerBound : ft->upperBound); + return true; + } + + return false; + case TypePath::TypeField::IndexLookup: + case TypePath::TypeField::IndexResult: + { + const TableIndexer* indexer = nullptr; + + if (auto tt = get(current); tt && tt->indexer) + indexer = &(*tt->indexer); + else if (auto mt = get(current)) + { + if (auto mtTab = get(follow(mt->table)); mtTab && mtTab->indexer) + indexer = &(*mtTab->indexer); + else if (auto mtMt = get(follow(mt->metatable)); mtMt && mtMt->indexer) + indexer = &(*mtMt->indexer); + } + // Note: we don't appear to walk the class hierarchy for indexers + else if (auto ct = get(current); ct && ct->indexer) + indexer = &(*ct->indexer); + + if (indexer) + { + updateCurrent(field == TypePath::TypeField::IndexLookup ? indexer->indexType : indexer->indexResultType); + return true; + } + + return false; + } + case TypePath::TypeField::Negated: + if (auto nt = get(current)) + { + updateCurrent(nt->ty); + return true; + } + + return false; + case TypePath::TypeField::Variadic: + if (auto vtp = get(current)) + { + updateCurrent(vtp->ty); + return true; + } + + return false; + } + + return false; + } + + bool traverse(TypePath::Reduction reduction) + { + if (checkInvariants()) + return false; + updateCurrent(reduction.resultType); + return true; + } + + bool traverse(TypePath::PackField field) + { + if (checkInvariants()) + return false; + + switch (field) + { + case TypePath::PackField::Arguments: + case TypePath::PackField::Returns: + if (auto ft = get(current)) + { + updateCurrent(field == TypePath::PackField::Arguments ? ft->argTypes : ft->retTypes); + return true; + } + + return false; + case TypePath::PackField::Tail: + if (auto currentPack = get(current)) + { + auto it = begin(*currentPack); + while (it != end(*currentPack)) + ++it; + + if (auto tail = it.tail()) + { + updateCurrent(*tail); + return true; + } + } + + return false; + } + + return false; + } +}; + +} // namespace + +std::string toString(const TypePath::Path& path, bool prefixDot) +{ + std::stringstream result; + bool first = true; + + auto strComponent = [&](auto&& c) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + result << '['; + if (FFlag::LuauSolverV2) + { + if (c.isRead) + result << "read "; + else + result << "write "; + } + + result << '"' << c.name << '"' << ']'; + } + else if constexpr (std::is_same_v) + { + result << '[' << std::to_string(c.index) << ']'; + } + else if constexpr (std::is_same_v) + { + if (!first || prefixDot) + result << '.'; + + switch (c) + { + case TypePath::TypeField::Table: + result << "table"; + break; + case TypePath::TypeField::Metatable: + result << "metatable"; + break; + case TypePath::TypeField::LowerBound: + result << "lowerBound"; + break; + case TypePath::TypeField::UpperBound: + result << "upperBound"; + break; + case TypePath::TypeField::IndexLookup: + result << "indexer"; + break; + case TypePath::TypeField::IndexResult: + result << "indexResult"; + break; + case TypePath::TypeField::Negated: + result << "negated"; + break; + case TypePath::TypeField::Variadic: + result << "variadic"; + break; + } + + result << "()"; + } + else if constexpr (std::is_same_v) + { + if (!first || prefixDot) + result << '.'; + + switch (c) + { + case TypePath::PackField::Arguments: + result << "arguments"; + break; + case TypePath::PackField::Returns: + result << "returns"; + break; + case TypePath::PackField::Tail: + result << "tail"; + break; + } + result << "()"; + } + else if constexpr (std::is_same_v) + { + // We need to rework the TypePath system to make subtyping failures easier to understand + // https://roblox.atlassian.net/browse/CLI-104422 + result << "~~>"; + } + else + { + static_assert(always_false_v, "Unhandled Component variant"); + } + + first = false; + }; + + for (const TypePath::Component& component : path.components) + Luau::visit(strComponent, component); + + return result.str(); +} + +static bool traverse(TraversalState& state, const Path& path) +{ + auto step = [&state](auto&& c) + { + return state.traverse(c); + }; + + for (const TypePath::Component& component : path.components) + { + bool stepSuccess = visit(step, component); + if (!stepSuccess) + return false; + } + + return true; +} + +std::optional traverse(TypeId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + return state.current; + else + return std::nullopt; +} + +std::optional traverse(TypePackId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + return state.current; + else + return std::nullopt; +} + +std::optional traverseForType(TypeId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +std::optional traverseForType(TypePackId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +std::optional traverseForPack(TypeId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +std::optional traverseForPack(TypePackId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp deleted file mode 100644 index abafa9fbc..000000000 --- a/Analysis/src/TypeReduction.cpp +++ /dev/null @@ -1,1162 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeReduction.h" - -#include "Luau/Common.h" -#include "Luau/Error.h" -#include "Luau/RecursionCounter.h" -#include "Luau/VisitType.h" - -#include -#include - -LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) -LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 400) -LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) - -namespace Luau -{ - -namespace detail -{ -bool TypeReductionMemoization::isIrreducible(TypeId ty) -{ - ty = follow(ty); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = types.find(ty); edge && edge->irreducible) - return true; - else if (get(ty) || get(ty) || get(ty)) - return false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - return false; - else - return true; -} - -bool TypeReductionMemoization::isIrreducible(TypePackId tp) -{ - tp = follow(tp); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = typePacks.find(tp); edge && edge->irreducible) - return true; - else if (get(tp) || get(tp)) - return false; - else if (auto vtp = get(tp)) - return isIrreducible(vtp->ty); - else - return true; -} - -TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) -{ - ty = follow(ty); - reducedTy = follow(reducedTy); - - // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. - // We don't need to recurse much further than that, because we already record the irreducibility from - // the bottom up. - bool irreducible = isIrreducible(reducedTy); - if (auto it = get(reducedTy)) - { - for (TypeId part : it) - irreducible &= isIrreducible(part); - } - else if (auto ut = get(reducedTy)) - { - for (TypeId option : ut) - irreducible &= isIrreducible(option); - } - else if (auto tt = get(reducedTy)) - { - for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type); - - if (tt->indexer) - { - irreducible &= isIrreducible(tt->indexer->indexType); - irreducible &= isIrreducible(tt->indexer->indexResultType); - } - - for (auto ta : tt->instantiatedTypeParams) - irreducible &= isIrreducible(ta); - - for (auto tpa : tt->instantiatedTypePackParams) - irreducible &= isIrreducible(tpa); - } - else if (auto mt = get(reducedTy)) - { - irreducible &= isIrreducible(mt->table); - irreducible &= isIrreducible(mt->metatable); - } - else if (auto ft = get(reducedTy)) - { - irreducible &= isIrreducible(ft->argTypes); - irreducible &= isIrreducible(ft->retTypes); - } - else if (auto nt = get(reducedTy)) - irreducible &= isIrreducible(nt->ty); - - types[ty] = {reducedTy, irreducible}; - types[reducedTy] = {reducedTy, irreducible}; - return reducedTy; -} - -TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp) -{ - tp = follow(tp); - reducedTp = follow(reducedTp); - - bool irreducible = isIrreducible(reducedTp); - TypePackIterator it = begin(tp); - while (it != end(tp)) - { - irreducible &= isIrreducible(*it); - ++it; - } - - if (it.tail()) - irreducible &= isIrreducible(*it.tail()); - - typePacks[tp] = {reducedTp, irreducible}; - typePacks[reducedTp] = {reducedTp, irreducible}; - return reducedTp; -} - -std::optional> TypeReductionMemoization::memoizedof(TypeId ty) const -{ - auto fetchContext = [this](TypeId ty) -> std::optional> { - if (auto edge = types.find(ty)) - return *edge; - else - return std::nullopt; - }; - - TypeId currentTy = ty; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTy)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTy) - return edge; - else - currentTy = edge->type; - } - - return lastEdge; -} - -std::optional> TypeReductionMemoization::memoizedof(TypePackId tp) const -{ - auto fetchContext = [this](TypePackId tp) -> std::optional> { - if (auto edge = typePacks.find(tp)) - return *edge; - else - return std::nullopt; - }; - - TypePackId currentTp = tp; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTp)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTp) - return edge; - else - currentTp = edge->type; - } - - return lastEdge; -} -} // namespace detail - -namespace -{ - -template -std::pair get2(const Thing& one, const Thing& two) -{ - const A* a = get(one); - const B* b = get(two); - return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); -} - -struct TypeReducer -{ - NotNull arena; - NotNull builtinTypes; - NotNull handle; - NotNull memoization; - DenseHashSet* cyclics; - - int depth = 0; - - TypeId reduce(TypeId ty); - TypePackId reduce(TypePackId tp); - - std::optional intersectionType(TypeId left, TypeId right); - std::optional unionType(TypeId left, TypeId right); - TypeId tableType(TypeId ty); - TypeId functionType(TypeId ty); - TypeId negationType(TypeId ty); - - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); - using UnaryFold = TypeId (TypeReducer::*)(TypeId); - - template - LUAU_NOINLINE std::pair copy(TypeId ty, const T* t) - { - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - return {edge->type, getMutable(edge->type)}; - - // We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will - // potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references - // without attempting to recursively reduce it, causing copies of copies of copies of... - TypeId copiedTy = arena->addType(*t); - memoization->types[ty] = {copiedTy, true}; - memoization->types[copiedTy] = {copiedTy, true}; - return {copiedTy, getMutable(copiedTy)}; - } - - template - void foldl_impl(Iter it, Iter endIt, BinaryFold f, std::vector* result, bool* didReduce) - { - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - while (it != endIt) - { - TypeId right = reduce(*it); - *didReduce |= right != follow(*it); - - // We're hitting a case where the `currentTy` returned a type that's the same as `T`. - // e.g. `(string?) & ~(false | nil)` became `(string?) & (~false & ~nil)` but the current iterator we're consuming doesn't know this. - // We will need to recurse and traverse that first. - if (auto t = get(right)) - { - foldl_impl(begin(t), end(t), f, result, didReduce); - ++it; - continue; - } - - bool replaced = false; - auto resultIt = result->begin(); - while (resultIt != result->end()) - { - TypeId left = *resultIt; - if (left == right) - { - replaced = true; - ++resultIt; - continue; - } - - std::optional reduced = (this->*f)(left, right); - if (reduced) - { - *resultIt = *reduced; - ++resultIt; - replaced = true; - } - else - { - ++resultIt; - continue; - } - } - - if (!replaced) - result->push_back(right); - - *didReduce |= replaced; - ++it; - } - } - - template - TypeId flatten(std::vector&& types) - { - if (types.size() == 1) - return types[0]; - else - return arena->addType(T{std::move(types)}); - } - - template - TypeId foldl(Iter it, Iter endIt, std::optional ty, BinaryFold f) - { - std::vector result; - bool didReduce = false; - foldl_impl(it, endIt, f, &result, &didReduce); - - // If we've done any reduction, then we'll need to reduce it again, e.g. - // `"a" | "b" | string` is reduced into `string | string`, which is then reduced into `string`. - if (!didReduce) - return ty ? *ty : flatten(std::move(result)); - else - return reduce(flatten(std::move(result))); - } - - template - TypeId apply(BinaryFold f, TypeId left, TypeId right) - { - std::vector types{left, right}; - return foldl(begin(types), end(types), std::nullopt, f); - } - - template - TypeId distribute(TypeIterator it, TypeIterator endIt, BinaryFold f, TypeId ty) - { - std::vector result; - while (it != endIt) - { - result.push_back(apply(f, *it, ty)); - ++it; - } - return flatten(std::move(result)); - } -}; - -TypeId TypeReducer::reduce(TypeId ty) -{ - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = edge->type; - } - else if (cyclics->contains(ty)) - return ty; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - TypeId result = nullptr; - if (auto i = get(ty)) - result = foldl(begin(i), end(i), ty, &TypeReducer::intersectionType); - else if (auto u = get(ty)) - result = foldl(begin(u), end(u), ty, &TypeReducer::unionType); - else if (get(ty) || get(ty)) - result = tableType(ty); - else if (get(ty)) - result = functionType(ty); - else if (get(ty)) - result = negationType(ty); - else - result = ty; - - return memoization->memoize(ty, result); -} - -TypePackId TypeReducer::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (auto edge = memoization->memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (cyclics->contains(tp)) - return tp; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - bool didReduce = false; - TypePackIterator it = begin(tp); - - std::vector head; - while (it != end(tp)) - { - TypeId reducedTy = reduce(*it); - head.push_back(reducedTy); - didReduce |= follow(*it) != follow(reducedTy); - ++it; - } - - std::optional tail = it.tail(); - if (tail) - { - if (auto vtp = get(follow(*it.tail()))) - { - TypeId reducedTy = reduce(vtp->ty); - if (follow(vtp->ty) != follow(reducedTy)) - { - tail = arena->addTypePack(VariadicTypePack{reducedTy, vtp->hidden}); - didReduce = true; - } - } - } - - if (!didReduce) - return memoization->memoize(tp, tp); - else if (head.empty() && tail) - return memoization->memoize(tp, *tail); - else - return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); -} - -std::optional TypeReducer::intersectionType(TypeId left, TypeId right) -{ - LUAU_ASSERT(!get(left)); - LUAU_ASSERT(!get(right)); - - if (get(left)) - return left; // never & T ~ never - else if (get(right)) - return right; // T & never ~ never - else if (get(left)) - return right; // unknown & T ~ T - else if (get(right)) - return left; // T & unknown ~ T - else if (get(left)) - return right; // any & T ~ T - else if (get(right)) - return left; // T & any ~ T - else if (get(left)) - return std::nullopt; // 'a & T ~ 'a & T - else if (get(right)) - return std::nullopt; // T & 'a ~ T & 'a - else if (get(left)) - return std::nullopt; // G & T ~ G & T - else if (get(right)) - return std::nullopt; // T & G ~ T & G - else if (get(left)) - return std::nullopt; // error & T ~ error & T - else if (get(right)) - return std::nullopt; // T & error ~ T & error - else if (get(left)) - return std::nullopt; // *blocked* & T ~ *blocked* & T - else if (get(right)) - return std::nullopt; // T & *blocked* ~ T & *blocked* - else if (get(left)) - return std::nullopt; // *pending* & T ~ *pending* & T - else if (get(right)) - return std::nullopt; // T & *pending* ~ T & *pending* - else if (auto ut = get(left)) - return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) - else if (get(right)) - return intersectionType(right, left); // T & (A | B) ~ (A | B) & T - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 & P2 ~ P1 iff P1 == P2 - else - return builtinTypes->neverType; // P1 & P2 ~ never iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return right; // string & "A" ~ "A" - else if (p->type == PrimitiveType::Boolean && get(s)) - return right; // boolean & true ~ true - else - return builtinTypes->neverType; // string & true ~ never - } - else if (auto [s, p] = get2(left, right); s && p) - return intersectionType(right, left); // S & P ~ P & S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return right; // function & () -> () ~ () -> () - else - return builtinTypes->neverType; // string & () -> () ~ never - } - else if (auto [f, p] = get2(left, right); f && p) - return intersectionType(right, left); // () -> () & P ~ P & () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // {} & P ~ P & {} - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // M & P ~ P & M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" & "a" ~ "a" - else - return builtinTypes->neverType; // "a" & "b" ~ never - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return left; // Derived & Base ~ Derived - else if (isSubclass(c2, c1)) - return right; // Base & Derived ~ Derived - else - return builtinTypes->neverType; // Base & Unrelated ~ never - } - else if (auto [f1, f2] = get2(left, right); f1 && f2) - return std::nullopt; // TODO - else if (auto [t1, t2] = get2(left, right); t1 && t2) - { - if (t1->state == TableState::Free || t2->state == TableState::Free) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - else if (t1->state == TableState::Generic || t2->state == TableState::Generic) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - - if (cyclics->contains(left)) - return std::nullopt; // (t1 where t1 = { p: t1 }) & {} ~ t1 & {} - else if (cyclics->contains(right)) - return std::nullopt; // {} & (t1 where t1 = { p: t1 }) ~ {} & t1 - - TypeId resultTy = arena->addType(TableType{}); - TableType* table = getMutable(resultTy); - table->state = t1->state == TableState::Sealed || t2->state == TableState::Sealed ? TableState::Sealed : TableState::Unsealed; - - for (const auto& [name, prop] : t1->props) - { - // TODO: when t1 has properties, we should also intersect that with the indexer in t2 if it exists, - // even if we have the corresponding property in the other one. - if (auto other = t2->props.find(name); other != t2->props.end()) - { - TypeId propTy = apply(&TypeReducer::intersectionType, prop.type, other->second.type); - if (get(propTy)) - return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never - else - table->props[name] = {propTy}; // { p : string } & { p : ~"a" } ~ { p : string & ~"a" } - } - else - table->props[name] = prop; // { p : string } & {} ~ { p : string } - } - - for (const auto& [name, prop] : t2->props) - { - // TODO: And vice versa, t2 properties against t1 indexer if it exists, - // even if we have the corresponding property in the other one. - if (!t1->props.count(name)) - table->props[name] = {reduce(prop.type)}; // {} & { p : string & string } ~ { p : string } - } - - if (t1->indexer && t2->indexer) - { - TypeId keyTy = apply(&TypeReducer::intersectionType, t1->indexer->indexType, t2->indexer->indexType); - if (get(keyTy)) - return std::nullopt; // { [string]: _ } & { [number]: _ } ~ { [string]: _ } & { [number]: _ } - - TypeId valueTy = apply(&TypeReducer::intersectionType, t1->indexer->indexResultType, t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [string]: number } & { [string]: string } ~ { [string]: never } - } - else if (t1->indexer) - { - TypeId keyTy = reduce(t1->indexer->indexType); - TypeId valueTy = reduce(t1->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [number]: boolean } & { p : string } ~ { p : string, [number]: boolean } - } - else if (t2->indexer) - { - TypeId keyTy = reduce(t2->indexer->indexType); - TypeId valueTy = reduce(t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { p : string } & { [number]: boolean } ~ { p : string, [number]: boolean } - } - - return resultTy; - } - else if (auto [mt, tt] = get2(left, right); mt && tt) - return std::nullopt; // TODO - else if (auto [tt, mt] = get2(left, right); tt && mt) - return intersectionType(right, left); // T & M ~ M & T - else if (auto [m1, m2] = get2(left, right); m1 && m2) - return std::nullopt; // TODO - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 & ~P2 ~ ~P1 iff P1 == P2 - else - return std::nullopt; // ~P1 & ~P2 ~ ~P1 & ~P2 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" & ~"A" ~ ~"A" - else - return std::nullopt; // ~"A" & ~"B" ~ ~"A" & ~"B" - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return right; // ~"A" & ~string ~ ~string - else if (get(ns) && np->type == PrimitiveType::Boolean) - return right; // ~false & ~boolean ~ ~boolean - else - return std::nullopt; // ~"A" | ~P ~ ~"A" & ~P - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return intersectionType(right, left); // ~P & ~S ~ ~S & ~P - else - return std::nullopt; // ~T & ~U ~ ~T & ~U - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->neverType; // ~P1 & P2 ~ never iff P1 == P2 - else - return right; // ~P1 & P2 ~ P2 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->neverType; // ~"A" & "A" ~ never - else - return right; // ~"A" & "B" ~ "B" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return std::nullopt; // ~"A" & string ~ ~"A" & string - else if (get(ns) && p->type == PrimitiveType::Boolean) - { - // Because booleans contain a fixed amount of values (2), we can do something cooler with this one. - const BooleanSingleton* b = get(ns); - return arena->addType(SingletonType{BooleanSingleton{!b->value}}); // ~false & boolean ~ true - } - else - return right; // ~"A" & number ~ number - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return builtinTypes->neverType; // ~string & "A" ~ never - else if (np->type == PrimitiveType::Boolean && get(s)) - return builtinTypes->neverType; // ~boolean & true ~ never - else - return right; // ~P & "A" ~ "A" - } - else if (auto [np, f] = get2(nlTy, right); np && f) - { - if (np->type == PrimitiveType::Function) - return builtinTypes->neverType; // ~function & () -> () ~ never - else - return right; // ~string & () -> () ~ () -> () - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return builtinTypes->neverType; // ~Base & Derived ~ never - else if (isSubclass(nc, c)) - return std::nullopt; // ~Derived & Base ~ ~Derived & Base - else - return right; // ~Base & Unrelated ~ Unrelated - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else - return right; // ~T & U ~ U - } - else if (get(right)) - return intersectionType(right, left); // T & ~U ~ ~U & T - else - return builtinTypes->neverType; // for all T and U except the ones handled above, T & U ~ never -} - -std::optional TypeReducer::unionType(TypeId left, TypeId right) -{ - LUAU_ASSERT(!get(left)); - LUAU_ASSERT(!get(right)); - - if (get(left)) - return right; // never | T ~ T - else if (get(right)) - return left; // T | never ~ T - else if (get(left)) - return left; // unknown | T ~ unknown - else if (get(right)) - return right; // T | unknown ~ unknown - else if (get(left)) - return left; // any | T ~ any - else if (get(right)) - return right; // T | any ~ any - else if (get(left)) - return std::nullopt; // error | T ~ error | T - else if (get(right)) - return std::nullopt; // T | error ~ T | error - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 | P2 ~ P1 iff P1 == P2 - else - return std::nullopt; // P1 | P2 ~ P1 | P2 iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return left; // string | "A" ~ string - else if (p->type == PrimitiveType::Boolean && get(s)) - return left; // boolean | true ~ boolean - else - return std::nullopt; // string | true ~ string | true - } - else if (auto [s, p] = get2(left, right); s && p) - return unionType(right, left); // S | P ~ P | S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return left; // function | () -> () ~ function - else - return std::nullopt; // P | () -> () ~ P | () -> () - } - else if (auto [f, p] = get2(left, right); f && p) - return unionType(right, left); // () -> () | P ~ P | () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // {} | P ~ P | {} - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // M | P ~ P | M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" | "a" ~ "a" - else - return std::nullopt; // "a" | "b" ~ "a" | "b" - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return right; // Derived | Base ~ Base - else if (isSubclass(c2, c1)) - return left; // Base | Derived ~ Base - else - return std::nullopt; // Base | Unrelated ~ Base | Unrelated - } - else if (auto [nt, it] = get2(left, right); nt && it) - return reduce(distribute(begin(it), end(it), &TypeReducer::unionType, left)); // ~T | (A & B) ~ (~T | A) & (~T | B) - else if (auto [it, nt] = get2(left, right); it && nt) - return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 | ~P2 ~ ~P1 iff P1 == P2 - else - return builtinTypes->unknownType; // ~P1 | ~P2 ~ ~P1 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" | ~"A" ~ ~"A" - else - return builtinTypes->unknownType; // ~"A" | ~"B" ~ unknown - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return left; // ~"A" | ~string ~ ~"A" - else if (get(ns) && np->type == PrimitiveType::Boolean) - return left; // ~false | ~boolean ~ ~false - else - return builtinTypes->unknownType; // ~"A" | ~P ~ unknown - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return unionType(right, left); // ~P | ~S ~ ~S | ~P - else - return std::nullopt; // TODO! - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->unknownType; // ~P1 | P2 ~ unknown iff P1 == P2 - else - return left; // ~P1 | P2 ~ ~P1 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->unknownType; // ~"A" | "A" ~ unknown - else - return left; // ~"A" | "B" ~ ~"A" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return builtinTypes->unknownType; // ~"A" | string ~ unknown - else if (get(ns) && p->type == PrimitiveType::Boolean) - return builtinTypes->unknownType; // ~false | boolean ~ unknown - else - return left; // ~"A" | T ~ ~"A" - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return std::nullopt; // ~string | "A" ~ ~string | "A" - else if (np->type == PrimitiveType::Boolean && get(s)) - { - const BooleanSingleton* b = get(s); - return negationType(arena->addType(SingletonType{BooleanSingleton{!b->value}})); // ~boolean | false ~ ~true - } - else - return left; // ~P | "A" ~ ~P - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return std::nullopt; // ~Base | Derived ~ ~Base | Derived - else if (isSubclass(nc, c)) - return builtinTypes->unknownType; // ~Derived | Base ~ unknown - else - return left; // ~Base | Unrelated ~ ~Base - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | {} ~ ~P | {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | M ~ ~P | M - } - else - return std::nullopt; // TODO - } - else if (get(right)) - return unionType(right, left); // T | ~U ~ ~U | T - else - return std::nullopt; // for all T and U except the ones handled above, T | U ~ T | U -} - -TypeId TypeReducer::tableType(TypeId ty) -{ - if (auto mt = get(ty)) - { - auto [copiedTy, copied] = copy(ty, mt); - copied->table = reduce(mt->table); - copied->metatable = reduce(mt->metatable); - return copiedTy; - } - else if (auto tt = get(ty)) - { - // Because of `typeof()`, we need to preserve pointer identity of free/unsealed tables so that - // all mutations that occurs on this will be applied without leaking the implementation details. - // As a result, we'll just use the type instead of cloning it if it's free/unsealed. - // - // We could choose to do in-place reductions here, but to be on the safer side, I propose that we do not. - if (tt->state == TableState::Free || tt->state == TableState::Unsealed) - return ty; - - auto [copiedTy, copied] = copy(ty, tt); - - for (auto& [name, prop] : copied->props) - { - TypeId propTy = reduce(prop.type); - if (get(propTy)) - return builtinTypes->neverType; - else - prop.type = propTy; - } - - if (copied->indexer) - { - TypeId keyTy = reduce(copied->indexer->indexType); - TypeId valueTy = reduce(copied->indexer->indexResultType); - copied->indexer = TableIndexer{keyTy, valueTy}; - } - - for (TypeId& ty : copied->instantiatedTypeParams) - ty = reduce(ty); - - for (TypePackId& tp : copied->instantiatedTypePackParams) - tp = reduce(tp); - - return copiedTy; - } - else - handle->ice("TypeReducer::tableType expects a TableType or MetatableType"); -} - -TypeId TypeReducer::functionType(TypeId ty) -{ - const FunctionType* f = get(ty); - if (!f) - handle->ice("TypeReducer::functionType expects a FunctionType"); - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - auto [copiedTy, copied] = copy(ty, f); - copied->argTypes = reduce(f->argTypes); - copied->retTypes = reduce(f->retTypes); - return copiedTy; -} - -TypeId TypeReducer::negationType(TypeId ty) -{ - const NegationType* n = get(ty); - if (!n) - return arena->addType(NegationType{ty}); - - TypeId negatedTy = follow(n->ty); - - if (auto nn = get(negatedTy)) - return nn->ty; // ~~T ~ T - else if (get(negatedTy)) - return builtinTypes->unknownType; // ~never ~ unknown - else if (get(negatedTy)) - return builtinTypes->neverType; // ~unknown ~ never - else if (get(negatedTy)) - return builtinTypes->anyType; // ~any ~ any - else if (auto ni = get(negatedTy)) - { - std::vector options; - for (TypeId part : ni) - options.push_back(negationType(arena->addType(NegationType{part}))); - return reduce(flatten(std::move(options))); // ~(T & U) ~ (~T | ~U) - } - else if (auto nu = get(negatedTy)) - { - std::vector parts; - for (TypeId option : nu) - parts.push_back(negationType(arena->addType(NegationType{option}))); - return reduce(flatten(std::move(parts))); // ~(T | U) ~ (~T & ~U) - } - else - return ty; // for all T except the ones handled above, ~T ~ ~T -} - -struct MarkCycles : TypeVisitor -{ - DenseHashSet cyclics{nullptr}; - - void cycle(TypeId ty) override - { - cyclics.insert(follow(ty)); - } - - void cycle(TypePackId tp) override - { - cyclics.insert(follow(tp)); - } - - bool visit(TypeId ty) override - { - return !cyclics.find(follow(ty)); - } - - bool visit(TypePackId tp) override - { - return !cyclics.find(follow(tp)); - } -}; -} // namespace - -TypeReduction::TypeReduction( - NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts) - : arena(arena) - , builtinTypes(builtinTypes) - , handle(handle) - , options(opts) -{ -} - -std::optional TypeReduction::reduce(TypeId ty) -{ - ty = follow(ty); - - if (FFlag::DebugLuauDontReduceTypes) - return ty; - else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) - return ty; - else if (auto edge = memoization.memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = edge->type; - } - else if (hasExceededCartesianProductLimit(ty)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(ty); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(ty); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (FFlag::DebugLuauDontReduceTypes) - return tp; - else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) - return tp; - else if (auto edge = memoization.memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (hasExceededCartesianProductLimit(tp)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(tp); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(tp); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(const TypeFun& fun) -{ - if (FFlag::DebugLuauDontReduceTypes) - return fun; - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - if (auto reducedTy = reduce(fun.type)) - return TypeFun{fun.typeParams, fun.typePackParams, *reducedTy}; - - return std::nullopt; -} - -size_t TypeReduction::cartesianProductSize(TypeId ty) const -{ - ty = follow(ty); - - auto it = get(follow(ty)); - if (!it) - return 1; - - return std::accumulate(begin(it), end(it), size_t(1), [](size_t acc, TypeId ty) { - if (auto ut = get(ty)) - return acc * std::distance(begin(ut), end(ut)); - else if (get(ty)) - return acc * 0; - else - return acc * 1; - }); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypeId ty) const -{ - return cartesianProductSize(ty) >= size_t(FInt::LuauTypeReductionCartesianProductLimit); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const -{ - TypePackIterator it = begin(tp); - - while (it != end(tp)) - { - if (hasExceededCartesianProductLimit(*it)) - return true; - - ++it; - } - - if (auto tail = it.tail()) - { - if (auto vtp = get(follow(*tail))) - { - if (hasExceededCartesianProductLimit(vtp->ty)) - return true; - } - } - - return false; -} - -} // namespace Luau diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index e5029e587..1ed1b9e0a 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Common.h" #include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/ToString.h" @@ -8,11 +9,96 @@ #include +LUAU_FASTFLAG(LuauSolverV2); + namespace Luau { +bool inConditional(const TypeContext& context) +{ + return context == TypeContext::Condition; +} + +bool occursCheck(TypeId needle, TypeId haystack) +{ + LUAU_ASSERT(get(needle) || get(needle)); + haystack = follow(haystack); + + auto checkHaystack = [needle](TypeId haystack) + { + return occursCheck(needle, haystack); + }; + + if (needle == haystack) + return true; + else if (auto ut = get(haystack)) + return std::any_of(begin(ut), end(ut), checkHaystack); + else if (auto it = get(haystack)) + return std::any_of(begin(it), end(it), checkHaystack); + + return false; +} + +// FIXME: Property is quite large. +// +// Returning it on the stack like this isn't great. We'd like to just return a +// const Property*, but we mint a property of type any if the subject type is +// any. +std::optional findTableProperty(NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) +{ + if (get(ty)) + return Property::rw(ty); + + if (const TableType* tableType = getTableType(ty)) + { + const auto& it = tableType->props.find(name); + if (it != tableType->props.end()) + return it->second; + } + + std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); + int count = 0; + while (mtIndex) + { + TypeId index = follow(*mtIndex); + + if (count >= 100) + return std::nullopt; + + ++count; + + if (const auto& itt = getTableType(index)) + { + const auto& fit = itt->props.find(name); + if (fit != itt->props.end()) + return fit->second.type(); + } + else if (const auto& itf = get(index)) + { + std::optional r = first(follow(itf->retTypes)); + if (!r) + return builtinTypes->nilType; + else + return *r; + } + else if (get(index)) + return builtinTypes->anyType; + else + errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); + + mtIndex = findMetatableEntry(builtinTypes, errors, *mtIndex, "__index", location); + } + + return std::nullopt; +} + std::optional findMetatableEntry( - NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) + NotNull builtinTypes, + ErrorVec& errors, + TypeId type, + const std::string& entry, + Location location +) { type = follow(type); @@ -34,13 +120,30 @@ std::optional findMetatableEntry( auto it = mtt->props.find(entry); if (it != mtt->props.end()) - return it->second.type; + return it->second.type(); else return std::nullopt; } std::optional findTablePropertyRespectingMeta( - NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + Location location +) +{ + return findTablePropertyRespectingMeta(builtinTypes, errors, ty, name, ValueContext::RValue, location); +} + +std::optional findTablePropertyRespectingMeta( + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + ValueContext context, + Location location +) { if (get(ty)) return ty; @@ -49,7 +152,20 @@ std::optional findTablePropertyRespectingMeta( { const auto& it = tableType->props.find(name); if (it != tableType->props.end()) - return it->second.type; + { + if (FFlag::LuauSolverV2) + { + switch (context) + { + case ValueContext::RValue: + return it->second.readTy; + case ValueContext::LValue: + return it->second.writeTy; + } + } + else + return it->second.type(); + } } std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); @@ -67,7 +183,7 @@ std::optional findTablePropertyRespectingMeta( { const auto& fit = itt->props.find(name); if (fit != itt->props.end()) - return fit->second.type; + return fit->second.type(); } else if (const auto& itf = get(index)) { @@ -118,7 +234,12 @@ std::pair> getParameterExtents(const TxnLog* log, } TypePack extendTypePack( - TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides) + TypeArena& arena, + NotNull builtinTypes, + TypePackId pack, + size_t length, + std::vector> overrides +) { TypePack result; @@ -180,6 +301,8 @@ TypePack extendTypePack( TypePack newPack; newPack.tail = arena.freshTypePack(ftp->scope); + if (FFlag::LuauSolverV2) + result.tail = newPack.tail; size_t overridesIndex = 0; while (result.head.size() < length) { @@ -190,7 +313,13 @@ TypePack extendTypePack( } else { - t = arena.freshType(ftp->scope); + if (FFlag::LuauSolverV2) + { + FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType}; + t = arena.addType(ft); + } + else + t = arena.freshType(ftp->scope); } newPack.head.push_back(t); @@ -295,4 +424,123 @@ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty) return follow(ty); } +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + std::shared_ptr normType = normalizer->normalize(ty); + + if (!normType) + return ErrorSuppression::NormalizationFailed; + + return (normType->shouldSuppressErrors()) ? ErrorSuppression::Suppress : ErrorSuppression::DoNotSuppress; +} + +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp) +{ + auto [tys, tail] = flatten(tp); + + // check the head, one type at a time + for (TypeId ty : tys) + { + auto result = shouldSuppressErrors(normalizer, ty); + if (result != ErrorSuppression::DoNotSuppress) + return result; + } + + // check the tail if we have one and it's finite + if (tail && tp != tail && finite(*tail)) + return shouldSuppressErrors(normalizer, *tail); + + return ErrorSuppression::DoNotSuppress; +} + +// This is a useful helper because it is often the case that we are looking at specifically a pair of types that might suppress. +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty1, TypeId ty2) +{ + auto result = shouldSuppressErrors(normalizer, ty1); + + // if ty1 is do not suppress, ty2 determines our overall behavior + if (result == ErrorSuppression::DoNotSuppress) + return shouldSuppressErrors(normalizer, ty2); + + // otherwise, ty1 is either suppress or normalization failure which are both the appropriate overarching result + return result; +} + +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp1, TypePackId tp2) +{ + auto result = shouldSuppressErrors(normalizer, tp1); + + // if tp1 is do not suppress, tp2 determines our overall behavior + if (result == ErrorSuppression::DoNotSuppress) + return shouldSuppressErrors(normalizer, tp2); + + // otherwise, tp1 is either suppress or normalization failure which are both the appropriate overarching result + return result; +} + +bool isLiteral(const AstExpr* expr) +{ + return ( + expr->is() || expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() + ); +} +/** + * Visitor which, given an expression and a mapping from expression to TypeId, + * determines if there are any literal expressions that contain blocked types. + * This is used for bi-directional inference: we want to "apply" a type from + * a function argument or a type annotation to a literal. + */ +class BlockedTypeInLiteralVisitor : public AstVisitor +{ +public: + explicit BlockedTypeInLiteralVisitor(NotNull> astTypes, NotNull> toBlock) + : astTypes_{astTypes} + , toBlock_{toBlock} + { + } + bool visit(AstNode*) override + { + return false; + } + + bool visit(AstExpr* e) override + { + auto ty = astTypes_->find(e); + if (ty && (get(follow(*ty)) != nullptr)) + { + toBlock_->push_back(*ty); + } + return isLiteral(e) || e->is(); + } + +private: + NotNull> astTypes_; + NotNull> toBlock_; +}; + +std::vector findBlockedTypesIn(AstExprTable* expr, NotNull> astTypes) +{ + std::vector toBlock; + BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}}; + expr->visit(&v); + return toBlock; +} + +std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull> astTypes) +{ + std::vector toBlock; + BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}}; + for (auto arg: expr->args) + { + if (isLiteral(arg) || arg->is()) + { + arg->visit(&v); + } + } + return toBlock; +} + + } // namespace Luau diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index 4dc26219c..a2f49afbd 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -10,7 +10,7 @@ #ifndef NOMINMAX #define NOMINMAX #endif -#include +#include const size_t kPageSize = 4096; #else diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index 9db8f7f00..2ceb97aae 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -8,71 +8,11 @@ namespace Unifiable static int nextIndex = 0; -Free::Free(TypeLevel level) - : index(++nextIndex) - , level(level) -{ -} - -Free::Free(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Free::Free(Scope* scope, TypeLevel level) - : index(++nextIndex) - , level(level) - , scope(scope) -{ -} - -int Free::DEPRECATED_nextIndex = 0; - -Generic::Generic() - : index(++nextIndex) - , name("g" + std::to_string(index)) -{ -} - -Generic::Generic(TypeLevel level) - : index(++nextIndex) - , level(level) - , name("g" + std::to_string(index)) +int freshIndex() { + return ++nextIndex; } -Generic::Generic(const Name& name) - : index(++nextIndex) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Generic::Generic(TypeLevel level, const Name& name) - : index(++nextIndex) - , level(level) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope, const Name& name) - : index(++nextIndex) - , scope(scope) - , name(name) - , explicitName(true) -{ -} - -int Generic::DEPRECATED_nextIndex = 0; - Error::Error() : index(++nextIndex) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index aba642714..b1e16c25a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,15 +18,10 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) -LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) -LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) -LUAU_FASTFLAGVARIABLE(LuauTableUnifyInstantiationFix, false) -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauNegatedFunctionTypes) -LUAU_FASTFLAG(LuauNegatedClassTypes) -LUAU_FASTFLAG(LuauNegatedTableTypes) +LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) +LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart, false) namespace Luau { @@ -52,7 +47,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor template void promote(TID ty, T* t) { - if (FFlag::DebugLuauDeferredConstraintResolution && !t) + if (useScopes && !t) return; LUAU_ASSERT(t); @@ -108,7 +103,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (!log.is(ty)) return true; promote(ty, log.getMutable(ty)); @@ -126,7 +121,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (!log.is(ty)) return true; promote(ty, log.getMutable(ty)); @@ -191,6 +186,18 @@ struct SkipCacheForType final : TypeOnceVisitor return false; } + bool visit(TypeId, const BlockedType&) override + { + result = true; + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + result = true; + return false; + } + bool visit(TypeId ty, const TableType&) override { // Types from other modules don't contain mutable elements and are ok to cache @@ -258,6 +265,12 @@ struct SkipCacheForType final : TypeOnceVisitor return false; } + bool visit(TypePackId tp, const BlockedTypePack&) override + { + result = true; + return false; + } + const DenseHashMap& skipCacheForType; const TypeArena* typeArena = nullptr; bool result = false; @@ -296,7 +309,7 @@ TypePackId Widen::clean(TypePackId) bool Widen::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return !log->is(ty); @@ -314,7 +327,8 @@ TypePackId Widen::operator()(TypePackId tp) std::optional hasUnificationTooComplex(const ErrorVec& errors) { - auto isUnificationTooComplex = [](const TypeError& te) { + auto isUnificationTooComplex = [](const TypeError& te) + { return nullptr != get(te); }; @@ -325,6 +339,20 @@ std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +std::optional hasCountMismatch(const ErrorVec& errors) +{ + auto isCountMismatch = [](const TypeError& te) + { + return nullptr != get(te); + }; + + auto it = std::find_if(errors.begin(), errors.end(), isCountMismatch); + if (it == errors.end()) + return std::nullopt; + else + return *it; +} + // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { @@ -332,7 +360,7 @@ static std::optional> getTableMatchT { for (auto&& [name, prop] : ttv->props) { - if (auto sing = get(follow(prop.type))) + if (auto sing = get(follow(prop.type()))) return {{name, sing}}; } } @@ -340,7 +368,6 @@ static std::optional> getTableMatchT return std::nullopt; } -// TODO: Inline and clip with FFlag::DebugLuauDeferredConstraintResolution template static bool subsumes(bool useScopes, TY_A* left, TY_B* right) { @@ -364,11 +391,10 @@ TypeMismatch::Context Unifier::mismatchContext() } } -Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) +Unifier::Unifier(NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) : types(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) - , mode(mode) , scope(scope) , log(parentLog) , location(location) @@ -376,16 +402,31 @@ Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope , sharedState(*normalizer->sharedState) { LUAU_ASSERT(sharedState.iceHandler); + + // Unifier is not usable when this flag is enabled! Please consider using Subtyping instead. + LUAU_ASSERT(!FFlag::LuauSolverV2); } -void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection, const LiteralProperties* literalProperties) { sharedState.counters.iterationCount = 0; - tryUnify_(subTy, superTy, isFunctionCall, isIntersection); + tryUnify_(subTy, superTy, isFunctionCall, isIntersection, literalProperties); +} + +static bool isBlocked(const TxnLog& log, TypeId ty) +{ + ty = log.follow(ty); + return get(ty) || get(ty); +} + +static bool isBlocked(const TxnLog& log, TypePackId tp) +{ + tp = log.follow(tp); + return get(tp); } -void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection, const LiteralProperties* literalProperties) { RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -403,21 +444,37 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (isBlocked(log, subTy) && isBlocked(log, superTy)) + { + blockedTypes.push_back(subTy); + blockedTypes.push_back(superTy); + } + else if (isBlocked(log, subTy)) + blockedTypes.push_back(subTy); + else if (isBlocked(log, superTy)) + blockedTypes.push_back(superTy); + + if (log.get(superTy)) + ice("Unexpected TypeFunctionInstanceType superTy"); + + if (log.get(subTy)) + ice("Unexpected TypeFunctionInstanceType subTy"); + auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); - if (superFree && subFree && subsumes(useScopes, superFree, subFree)) + if (superFree && subFree && subsumes(useNewSolver, superFree, subFree)) { - if (!occursCheck(subTy, superTy)) + if (!occursCheck(subTy, superTy, /* reversed = */ false)) log.replace(subTy, BoundType(superTy)); return; } else if (superFree && subFree) { - if (!occursCheck(superTy, subTy)) + if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - if (subsumes(useScopes, superFree, subFree)) + if (subsumes(useNewSolver, superFree, subFree)) { log.changeLevel(subTy, superFree->level); } @@ -431,16 +488,16 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // Unification can't change the level of a generic. auto subGeneric = log.getMutable(subTy); - if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) + if (subGeneric && !subsumes(useNewSolver, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic subtype escaping scope"}); return; } - if (!occursCheck(superTy, subTy)) + if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - promoteTypeLevels(log, types, superFree->level, superFree->scope, useScopes, subTy); + promoteTypeLevels(log, types, superFree->level, superFree->scope, useNewSolver, subTy); Widen widen{types, builtinTypes}; log.replace(superTy, BoundType(widen(subTy))); @@ -457,36 +514,85 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto superGeneric = log.getMutable(superTy); - if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) + if (superGeneric && !subsumes(useNewSolver, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic supertype escaping scope"}); return; } - if (!occursCheck(subTy, superTy)) + if (!occursCheck(subTy, superTy, /* reversed = */ false)) { - promoteTypeLevels(log, types, subFree->level, subFree->scope, useScopes, superTy); + promoteTypeLevels(log, types, subFree->level, subFree->scope, useNewSolver, superTy); log.replace(subTy, BoundType(superTy)); } return; } + if (hideousFixMeGenericsAreActuallyFree) + { + auto superGeneric = log.getMutable(superTy); + auto subGeneric = log.getMutable(subTy); + + if (superGeneric && subGeneric && subsumes(useNewSolver, superGeneric, subGeneric)) + { + if (!occursCheck(subTy, superTy, /* reversed = */ false)) + log.replace(subTy, BoundType(superTy)); + + return; + } + else if (superGeneric && subGeneric) + { + if (!occursCheck(superTy, subTy, /* reversed = */ true)) + log.replace(superTy, BoundType(subTy)); + + return; + } + else if (superGeneric) + { + if (!occursCheck(superTy, subTy, /* reversed = */ true)) + { + Widen widen{types, builtinTypes}; + log.replace(superTy, BoundType(widen(subTy))); + } + + return; + } + else if (subGeneric) + { + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (log.get(superTy)) + return; + + if (!occursCheck(subTy, superTy, /* reversed = */ false)) + log.replace(subTy, BoundType(superTy)); + + return; + } + } + if (log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->anyType); - if (log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->errorType); + if (log.get(subTy)) + { + if (normalize) + { + // TODO: there are probably cheaper ways to check if any <: T. + std::shared_ptr superNorm = normalizer->normalize(superTy); - if (log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->unknownType); + if (!superNorm) + return reportError(location, NormalizationTooComplex{}); - if (log.get(subTy)) + if (!log.get(superNorm->tops)) + failure = true; + } + else + failure = true; return tryUnifyWithAny(superTy, builtinTypes->anyType); - - if (log.get(subTy)) - return tryUnifyWithAny(superTy, builtinTypes->errorType); + } if (log.get(subTy)) return tryUnifyWithAny(superTy, builtinTypes->neverType); @@ -519,12 +625,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (log.getMutable(subTy) && log.getMutable(superTy)) - { - blockedTypes.push_back(subTy); - blockedTypes.push_back(superTy); - } - else if (const UnionType* subUnion = log.getMutable(subTy)) + if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } @@ -540,19 +641,47 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } + else if (log.get(subTy)) + { + tryUnifyWithAny(superTy, builtinTypes->unknownType); + failure = true; + } + else if (log.get(subTy) && log.get(superTy)) + { + // error <: error + } + else if (log.get(superTy)) + { + tryUnifyWithAny(subTy, builtinTypes->errorType); + failure = true; + } + else if (log.get(subTy)) + { + tryUnifyWithAny(superTy, builtinTypes->errorType); + failure = true; + } + else if (log.get(superTy)) + { + // At this point, all the supertypes of `error` have been handled, + // and if `error unknownType); + } + else if (log.get(superTy)) + { + tryUnifyWithAny(subTy, builtinTypes->unknownType); + } else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); - else if (auto ptv = get(superTy); - FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveType::Function && get(subTy)) + else if (auto ptv = get(superTy); ptv && ptv->type == PrimitiveType::Function && get(subTy)) { // Ok. Do nothing. forall functions F, F <: function } - else if (FFlag::LuauNegatedTableTypes && isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) + else if (isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) { // Ok, do nothing: forall tables T, T <: table } @@ -567,7 +696,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) { - tryUnifyTables(subTy, superTy, isIntersection); + tryUnifyTables(subTy, superTy, isIntersection, literalProperties); } else if (log.get(superTy) && (log.get(subTy) || log.get(subTy))) { @@ -595,10 +724,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.get(superTy) || log.get(subTy)) tryUnifyNegations(subTy, superTy); - else if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) + // If the normalizer hits resource limits, we can't show it's uninhabited, so, we should error. + else if (checkInhabited && normalizer->isInhabited(subTy) == NormalizationResult::False) { } - else reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); @@ -612,6 +741,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { // A | B <: T if and only if A <: T and B <: T bool failed = false; + bool errorsSuppressed = true; std::optional unificationTooComplex; std::optional firstFailedOption; @@ -619,65 +749,29 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ for (TypeId type : subUnion->options) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(type, superTy); - if (FFlag::DebugLuauDeferredConstraintResolution) - logs.push_back(std::move(innerState.log)); + if (useNewSolver) + logs.push_back(std::move(innerState->log)); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) unificationTooComplex = e; - else if (!innerState.errors.empty()) + else if (innerState->failure) { + // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. + if (innerState->errors.empty()) + logs.push_back(std::move(innerState->log)); // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (!firstFailedOption && !isNil(type)) - firstFailedOption = {innerState.errors.front()}; + else if (!firstFailedOption && !isNil(type)) + firstFailedOption = {innerState->errors.front()}; failed = true; + errorsSuppressed &= innerState->errors.empty(); } } - if (FFlag::DebugLuauDeferredConstraintResolution) - log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); - else - { - // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - auto tryBind = [this, subTy](TypeId superOption) { - superOption = log.follow(superOption); - - // just skip if the superOption is not free-ish. - auto ttv = log.getMutable(superOption); - if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) - return; - - // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype - // test is successful. - if (auto subUnion = get(subTy)) - { - if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) - return; - } - - // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. - // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (log.haveSeen(subTy, superOption)) - { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); - } - }; - - if (auto superUnion = log.getMutable(superTy)) - { - for (TypeId ty : superUnion) - tryBind(ty); - } - else - tryBind(superTy); - } + log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -685,8 +779,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { if (firstFailedOption) reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption, mismatchContext()}); - else + else if (!errorsSuppressed) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); + failure = true; } } @@ -694,6 +789,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { // T <: A | B if T <: A or T <: B bool found = false; + bool errorsSuppressed = false; std::optional unificationTooComplex; size_t failedOptionCount = 0; @@ -730,6 +826,21 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } + if (!foundHeuristic) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (subTy == type) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + if (!foundHeuristic && cacheEnabled) { auto& cache = sharedState.cachedUnify; @@ -751,22 +862,26 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp for (size_t i = 0; i < uv->options.size(); ++i) { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; - Unifier innerState = makeChildUnifier(); - innerState.normalize = false; - innerState.tryUnify_(subTy, type, isFunctionCall); + std::unique_ptr innerState = makeChildUnifier(); + innerState->normalize = false; + innerState->tryUnify_(subTy, type, isFunctionCall); - if (innerState.errors.empty()) + if (!innerState->failure) { found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) - logs.push_back(std::move(innerState.log)); + if (useNewSolver) + logs.push_back(std::move(innerState->log)); else { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (innerState->errors.empty()) + { + errorsSuppressed = true; + } + else if (auto e = hasUnificationTooComplex(innerState->errors)) { unificationTooComplex = e; } @@ -775,11 +890,11 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp failedOptionCount++; if (!failedOption) - failedOption = {innerState.errors.front()}; + failedOption = {innerState->errors.front()}; } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); if (unificationTooComplex) @@ -791,10 +906,35 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // It is possible that T <: A | B even though T normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + std::unique_ptr innerState = makeChildUnifier(); + + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - reportError(location, UnificationTooComplex{}); + return reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + innerState->tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption + ); + else + innerState->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + + if (!innerState->failure) + log.concat(std::move(innerState->log)); + else if (errorsSuppressed || innerState->errors.empty()) + failure = true; + else + reportError(std::move(innerState->errors.front())); + } + else if (!found && normalize) + { + // It is possible that T <: A | B even though T subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + reportError(location, NormalizationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else @@ -802,9 +942,12 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } else if (!found) { - if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + if (errorsSuppressed) + failure = true; + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) reportError( - location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()}); + location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()} + ); else reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible", mismatchContext()}); } @@ -820,24 +963,25 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I // T <: A & B if and only if T <: A and T <: B for (TypeId type : uv->parts) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) unificationTooComplex = e; - else if (!innerState.errors.empty()) + else if (!innerState->errors.empty()) { if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; + firstFailedOption = {innerState->errors.front()}; } - if (FFlag::DebugLuauDeferredConstraintResolution) - logs.push_back(std::move(innerState.log)); + if (useNewSolver) + logs.push_back(std::move(innerState->log)); else - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concat(combineLogsIntoIntersection(std::move(logs))); if (unificationTooComplex) @@ -866,6 +1010,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { // A & B <: T if A <: T or B <: T bool found = false; + bool errorsSuppressed = false; std::optional unificationTooComplex; size_t startIndex = 0; @@ -886,7 +1031,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* } } - if (FFlag::DebugLuauDeferredConstraintResolution && normalize) + if (useNewSolver && normalize) { // Sometimes a negation type is inside one of the types, e.g. { p: number } & { p: ~number }. NegationTypeFinder finder; @@ -897,12 +1042,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // It is possible that A & B <: T even though A normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(location, UnificationTooComplex{}); + reportError(location, NormalizationTooComplex{}); return; } @@ -913,29 +1058,36 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* for (size_t i = 0; i < uv->parts.size(); ++i) { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; - Unifier innerState = makeChildUnifier(); - innerState.normalize = false; - innerState.tryUnify_(type, superTy, isFunctionCall); + std::unique_ptr innerState = makeChildUnifier(); + innerState->normalize = false; + innerState->tryUnify_(type, superTy, isFunctionCall); - if (innerState.errors.empty()) + // TODO: This sets errorSuppressed to true if any of the parts is error-suppressing, + // in paricular any & T is error-suppressing. Really, errorSuppressed should be true if + // all of the parts are error-suppressing, but that fails to typecheck lua-apps. + if (innerState->errors.empty()) { found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) - logs.push_back(std::move(innerState.log)); + errorsSuppressed = innerState->failure; + if (useNewSolver || innerState->failure) + logs.push_back(std::move(innerState->log)); else { - log.concat(std::move(innerState.log)); + errorsSuppressed = false; + log.concat(std::move(innerState->log)); break; } } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { unificationTooComplex = e; } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concat(combineLogsIntoIntersection(std::move(logs))); + else if (errorsSuppressed) + log.concat(std::move(logs.front())); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -944,30 +1096,51 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // It is possible that A & B <: T even though A normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(location, UnificationTooComplex{}); + reportError(location, NormalizationTooComplex{}); } else if (!found) { reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible", mismatchContext()}); } + else if (errorsSuppressed) + failure = true; } void Unifier::tryUnifyNormalizedTypes( - TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) -{ - if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) + TypeId subTy, + TypeId superTy, + const NormalizedType& subNorm, + const NormalizedType& superNorm, + std::string reason, + std::optional error +) +{ + if (get(superNorm.tops)) return; - else if (get(subNorm.tops)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + else if (get(subNorm.tops)) + { + failure = true; + return; + } if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + { + failure = true; + return; + } + + if (get(superNorm.tops)) + return; + + if (get(subNorm.tops)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.booleans)) { @@ -995,81 +1168,59 @@ void Unifier::tryUnifyNormalizedTypes( if (!get(superNorm.errors)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (FFlag::LuauNegatedClassTypes) + for (const auto& [subClass, _] : subNorm.classes.classes) { - for (const auto& [subClass, _] : subNorm.classes.classes) + bool found = false; + const ClassType* subCtv = get(subClass); + LUAU_ASSERT(subCtv); + + for (const auto& [superClass, superNegations] : superNorm.classes.classes) { - bool found = false; - const ClassType* subCtv = get(subClass); - LUAU_ASSERT(subCtv); + const ClassType* superCtv = get(superClass); + LUAU_ASSERT(superCtv); - for (const auto& [superClass, superNegations] : superNorm.classes.classes) + if (isSubclass(subCtv, superCtv)) { - const ClassType* superCtv = get(superClass); - LUAU_ASSERT(superCtv); + found = true; - if (isSubclass(subCtv, superCtv)) + for (TypeId negation : superNegations) { - found = true; + const ClassType* negationCtv = get(negation); + LUAU_ASSERT(negationCtv); - for (TypeId negation : superNegations) + if (isSubclass(subCtv, negationCtv)) { - const ClassType* negationCtv = get(negation); - LUAU_ASSERT(negationCtv); - - if (isSubclass(subCtv, negationCtv)) - { - found = false; - break; - } - } - - if (found) - break; - } - } - - if (FFlag::DebugLuauDeferredConstraintResolution) - { - for (TypeId superTable : superNorm.tables) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify(subClass, superTable); - - if (innerState.errors.empty()) - { - found = true; - log.concat(std::move(innerState.log)); + found = false; break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - return reportError(*e); } - } - if (!found) - { - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + if (found) + break; } } - } - else - { - for (TypeId subClass : subNorm.DEPRECATED_classes) + + if (useNewSolver) { - bool found = false; - const ClassType* subCtv = get(subClass); - for (TypeId superClass : superNorm.DEPRECATED_classes) + for (TypeId superTable : superNorm.tables) { - const ClassType* superCtv = get(superClass); - if (isSubclass(subCtv, superCtv)) + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify(subClass, superTable); + + if (innerState->errors.empty()) { found = true; + log.concat(std::move(innerState->log)); break; } + else if (auto e = hasUnificationTooComplex(innerState->errors)) + return reportError(*e); } - if (!found) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + + if (!found) + { + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } } @@ -1078,26 +1229,23 @@ void Unifier::tryUnifyNormalizedTypes( bool found = false; for (TypeId superTable : superNorm.tables) { - if (FFlag::LuauNegatedTableTypes && isPrim(superTable, PrimitiveType::Table)) + if (isPrim(superTable, PrimitiveType::Table)) { found = true; break; } - Unifier innerState = makeChildUnifier(); - if (get(superTable)) - innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); - else if (get(subTable)) - innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); - else - innerState.tryUnifyTables(subTable, superTable); - if (innerState.errors.empty()) + std::unique_ptr innerState = makeChildUnifier(); + + innerState->tryUnify(subTable, superTable); + + if (innerState->errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) return reportError(*e); } if (!found) @@ -1108,17 +1256,17 @@ void Unifier::tryUnifyNormalizedTypes( { if (superNorm.functions.isNever()) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - for (TypeId superFun : *superNorm.functions.parts) + for (TypeId superFun : superNorm.functions.parts) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); const FunctionType* superFtv = get(superFun); if (!superFtv) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); - innerState.tryUnify_(tgt, superFtv->retTypes); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - else if (auto e = hasUnificationTooComplex(innerState.errors)) + TypePackId tgt = innerState->tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); + innerState->tryUnify_(tgt, superFtv->retTypes); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + else if (auto e = hasUnificationTooComplex(innerState->errors)) return reportError(*e); else return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); @@ -1147,7 +1295,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized std::optional result; const FunctionType* firstFun = nullptr; - for (TypeId overload : *overloads.parts) + for (TypeId overload : overloads.parts) { if (const FunctionType* ftv = get(overload)) { @@ -1156,17 +1304,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized { if (!firstFun) firstFun = ftv; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(args, ftv->argTypes); - if (innerState.errors.empty()) + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(args, ftv->argTypes); + if (innerState->errors.empty()) { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); if (result) { - innerState.log.clear(); - innerState.tryUnify_(*result, ftv->retTypes); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + innerState->log.clear(); + innerState->tryUnify_(*result, ftv->retTypes); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); // Annoyingly, since we don't support intersection of generic type packs, // the intersection may fail. We rather arbitrarily use the first matching overload // in that case. @@ -1176,7 +1324,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized else result = ftv->retTypes; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { reportError(*e); return builtinTypes->errorRecoveryTypePack(args); @@ -1214,7 +1362,8 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) if (subTyInfo && *subTyInfo) return false; - auto skipCacheFor = [this](TypeId ty) { + auto skipCacheFor = [this](TypeId ty) + { SkipCacheForType visitor{sharedState.skipCacheForType, types}; visitor.traverse(ty); @@ -1283,6 +1432,15 @@ struct WeirdIter return pack != nullptr && index < pack->head.size(); } + std::optional tail() const + { + if (!pack) + return packId; + + LUAU_ASSERT(index == pack->head.size()); + return pack->tail; + } + bool advance() { if (!pack) @@ -1306,7 +1464,7 @@ struct WeirdIter bool canGrow() const { - return nullptr != log.getMutable(packId); + return nullptr != log.getMutable(packId); } void grow(TypePackId newTail) @@ -1314,10 +1472,10 @@ struct WeirdIter LUAU_ASSERT(canGrow()); LUAU_ASSERT(log.getMutable(newTail)); - auto freePack = log.getMutable(packId); + auto freePack = log.getMutable(packId); level = freePack->level; - if (FFlag::LuauMaintainScopesInUnifier && freePack->scope != nullptr) + if (freePack->scope != nullptr) scope = freePack->scope; log.replace(packId, BoundTypePack(newTail)); packId = newTail; @@ -1344,20 +1502,26 @@ struct WeirdIter } }; +void Unifier::enableNewSolver() +{ + useNewSolver = true; + log.useScopes = true; +} + ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { - Unifier s = makeChildUnifier(); - s.tryUnify_(subTy, superTy); + std::unique_ptr s = makeChildUnifier(); + s->tryUnify_(subTy, superTy); - return s.errors; + return s->errors; } ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) { - Unifier s = makeChildUnifier(); - s.tryUnify_(subTy, superTy, isFunctionCall); + std::unique_ptr s = makeChildUnifier(); + s->tryUnify_(subTy, superTy, isFunctionCall); - return s.errors; + return s->errors; } void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) @@ -1408,17 +1572,46 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (log.haveSeen(superTp, subTp)) return; - if (log.getMutable(superTp)) + if (isBlocked(log, subTp) && isBlocked(log, superTp)) + { + blockedTypePacks.push_back(subTp); + blockedTypePacks.push_back(superTp); + } + else if (isBlocked(log, subTp)) + blockedTypePacks.push_back(subTp); + else if (isBlocked(log, superTp)) + blockedTypePacks.push_back(superTp); + + if (auto superFree = log.getMutable(superTp)) + { + if (!occursCheck(superTp, subTp, /* reversed = */ true)) + { + Widen widen{types, builtinTypes}; + if (useNewSolver) + promoteTypeLevels(log, types, superFree->level, superFree->scope, /*useScopes*/ true, subTp); + log.replace(superTp, Unifiable::Bound(widen(subTp))); + } + } + else if (auto subFree = log.getMutable(subTp)) { - if (!occursCheck(superTp, subTp)) + if (!occursCheck(subTp, superTp, /* reversed = */ false)) + { + if (useNewSolver) + promoteTypeLevels(log, types, subFree->level, subFree->scope, /*useScopes*/ true, superTp); + log.replace(subTp, Unifiable::Bound(superTp)); + } + } + else if (hideousFixMeGenericsAreActuallyFree && log.getMutable(superTp)) + { + if (!occursCheck(superTp, subTp, /* reversed = */ true)) { Widen widen{types, builtinTypes}; log.replace(superTp, Unifiable::Bound(widen(subTp))); } } - else if (log.getMutable(subTp)) + else if (hideousFixMeGenericsAreActuallyFree && log.getMutable(subTp)) { - if (!occursCheck(subTp, superTp)) + if (!occursCheck(subTp, superTp, /* reversed = */ false)) { log.replace(subTp, Unifiable::Bound(superTp)); } @@ -1447,14 +1640,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal auto superIter = WeirdIter(superTp, log); auto subIter = WeirdIter(subTp, log); - if (FFlag::LuauMaintainScopesInUnifier) - { - superIter.scope = scope.get(); - subIter.scope = scope.get(); - } + superIter.scope = scope.get(); + subIter.scope = scope.get(); - auto mkFreshType = [this](Scope* scope, TypeLevel level) { - return types->freshType(scope, level); + auto mkFreshType = [this](Scope* scope, TypeLevel level) + { + if (FFlag::LuauSolverV2) + return freshType(NotNull{types}, builtinTypes, scope); + else + return types->freshType(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); @@ -1493,28 +1687,74 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - } - else if (lFreeTail) - { - tryUnify_(emptyTp, *superTpv->tail); - } - else if (rFreeTail) + if (useNewSolver) { - tryUnify_(emptyTp, *subTpv->tail); + if (subIter.tail() && superIter.tail()) + tryUnify_(*subIter.tail(), *superIter.tail()); + else if (subIter.tail()) + { + const TypePackId subTail = log.follow(*subIter.tail()); + + if (log.get(subTail)) + tryUnify_(subTail, emptyTp); + else if (log.get(subTail)) + reportError(location, TypePackMismatch{subTail, emptyTp}); + else if (log.get(subTail) || log.get(subTail)) + { + // Nothing. This is ok. + } + else + { + ice("Unexpected subtype tail pack " + toString(subTail), location); + } + } + else if (superIter.tail()) + { + const TypePackId superTail = log.follow(*superIter.tail()); + + if (log.get(superTail)) + tryUnify_(emptyTp, superTail); + else if (log.get(superTail)) + reportError(location, TypePackMismatch{emptyTp, superTail}); + else if (log.get(superTail) || log.get(superTail)) + { + // Nothing. This is ok. + } + else + { + ice("Unexpected supertype tail pack " + toString(superTail), location); + } + } + else + { + // Nothing. This is ok. + } } - else if (subTpv->tail && superTpv->tail) + else { - if (log.getMutable(superIter.packId)) - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - else if (log.getMutable(subIter.packId)) - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - else + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + { tryUnify_(*subTpv->tail, *superTpv->tail); + } + else if (lFreeTail) + { + tryUnify_(emptyTp, *superTpv->tail); + } + else if (rFreeTail) + { + tryUnify_(emptyTp, *subTpv->tail); + } + else if (subTpv->tail && superTpv->tail) + { + if (log.getMutable(superIter.packId)) + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + else if (log.getMutable(subIter.packId)) + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + else + tryUnify_(*subTpv->tail, *superTpv->tail); + } } break; @@ -1644,9 +1884,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal // generic methods in tables to be marked read-only. if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { - Instantiation instantiation{&log, types, scope->level, scope}; + std::unique_ptr instantiation = std::make_unique(&log, types, builtinTypes, scope->level, scope); - std::optional instantiated = instantiation.substitute(subTy); + std::optional instantiated = instantiation->substitute(subTy); if (instantiated.has_value()) { subFunction = log.getMutable(*instantiated); @@ -1690,38 +1930,54 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (!isFunctionCall) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); + innerState->ctx = CountMismatch::Arg; + innerState->tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - bool reported = !innerState.errors.empty(); + bool reported = !innerState->errors.empty(); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), mismatchContext()}); - else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + else if (!innerState->errors.empty() && innerState->firstPackErrorPos) + reportError( + location, + TypeMismatch{ + superTy, + subTy, + format("Argument #%d type is not compatible.", *innerState->firstPackErrorPos), + innerState->errors.front(), + mismatchContext() + } + ); + else if (!innerState->errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()}); - innerState.ctx = CountMismatch::FunctionResult; - innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); + innerState->ctx = CountMismatch::FunctionResult; + innerState->tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) { - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), mismatchContext()}); - else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + else if (!innerState->errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState->errors.front(), mismatchContext()}); + else if (!innerState->errors.empty() && innerState->firstPackErrorPos) + reportError( + location, + TypeMismatch{ + superTy, + subTy, + format("Return #%d type is not compatible.", *innerState->firstPackErrorPos), + innerState->errors.front(), + mismatchContext() + } + ); + else if (!innerState->errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()}); } - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); } else { @@ -1771,7 +2027,7 @@ struct Resetter } // namespace -void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) +void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, const LiteralProperties* literalProperties) { if (isPrim(log.follow(subTy), PrimitiveType::Table)) subTy = builtinTypes->emptyTableType; @@ -1782,7 +2038,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TypeId activeSubTy = subTy; TableType* superTable = log.getMutable(superTy); TableType* subTable = log.getMutable(subTy); - TableType* instantiatedSubTable = subTable; // TODO: remove with FFlagLuauTableUnifyInstantiationFix if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1794,21 +2049,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { if (variance == Covariant && subTable->state == TableState::Generic && superTable->state != TableState::Generic) { - Instantiation instantiation{&log, types, subTable->level, scope}; + Instantiation instantiation{&log, types, builtinTypes, subTable->level, scope}; std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) { - if (FFlag::LuauTableUnifyInstantiationFix) - { - activeSubTy = *instantiated; - subTable = log.getMutable(activeSubTy); - } - else - { - subTable = log.getMutable(*instantiated); - instantiatedSubTable = subTable; - } + activeSubTy = *instantiated; + subTable = log.getMutable(activeSubTy); if (!subTable) ice("instantiation made a table type into a non-table type in tryUnifyTables"); @@ -1827,7 +2074,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type)) + if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type())) missingProperties.push_back(propName); } @@ -1865,32 +2112,36 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { // TODO: read-only properties don't need invariance Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(r->second.type, prop.type); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(r->second.type(), prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexResultType, prop.type); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subTable->indexer->indexResultType, prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } - else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) + else if (subTable->state == TableState::Unsealed && isOptional(prop.type())) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) @@ -1910,26 +2161,37 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; - TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(activeSubTy) : activeSubTy; + TypeId superTyNew = log.follow(superTy); + TypeId subTyNew = log.follow(activeSubTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) { - // If one of the types stopped being a table altogether, we need to restart from the top - if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { return tryUnify(subTy, superTy, false, isIntersection); + } } // Otherwise, restart only the table unification TableType* newSuperTable = log.getMutable(superTyNew); TableType* newSubTable = log.getMutable(subTyNew); - if (superTable != newSuperTable || (subTable != newSubTable && (FFlag::LuauTableUnifyInstantiationFix || subTable != instantiatedSubTable))) + if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } + + return; } } @@ -1945,15 +2207,23 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTable->indexer->indexResultType, prop.type); + std::unique_ptr innerState = makeChildUnifier(); + if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) + innerState->tryUnify_(prop.type(), superTable->indexer->indexResultType); + else + { + // Incredibly, the old solver depends on this bug somehow. + innerState->tryUnify_(superTable->indexer->indexResultType, prop.type()); + } - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (superTable->state == TableState::Unsealed) { @@ -1961,7 +2231,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: file a JIRA // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; - clone.type = deeplyOptional(clone.type); + clone.setType(deeplyOptional(clone.type())); PendingType* pendingSuper = log.queue(superTy); TableType* pendingSuperTtv = getMutable(pendingSuper); @@ -1981,14 +2251,22 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else extraProperties.push_back(name); - TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; - TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(activeSubTy) : activeSubTy; + TypeId superTyNew = log.follow(superTy); + TypeId subTyNew = log.follow(activeSubTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) { - // If one of the types stopped being a table altogether, we need to restart from the top - if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { return tryUnify(subTy, superTy, false, isIntersection); + } } // Recursive unification can change the txn log, and invalidate the old @@ -1997,12 +2275,15 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TableType* newSuperTable = log.getMutable(superTyNew); TableType* newSubTable = log.getMutable(subTyNew); - if (superTable != newSuperTable || (subTable != newSubTable && (FFlag::LuauTableUnifyInstantiationFix || subTable != instantiatedSubTable))) + if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } + + return; } } @@ -2013,21 +2294,22 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) Resetter resetter{&variance}; variance = Invariant; - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + innerState->tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); - bool reported = !innerState.errors.empty(); + bool reported = !innerState->errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, "[indexer key]", superTy, subTy); - innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + innerState->tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); if (!reported) - checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, "[indexer value]", superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (superTable->indexer) { @@ -2050,19 +2332,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // Changing the indexer can invalidate the table pointers. - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - superTable = log.getMutable(log.follow(superTy)); - subTable = log.getMutable(log.follow(activeSubTy)); + superTable = log.getMutable(log.follow(superTy)); + subTable = log.getMutable(log.follow(activeSubTy)); - if (!superTable || !subTable) - return; - } - else - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(activeSubTy); - } + if (!superTable || !subTable) + return; if (!missingProperties.empty()) { @@ -2102,7 +2376,8 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) TypeId osubTy = subTy; TypeId osuperTy = superTy; - if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) + // If the normalizer hits resource limits, we can't show it's uninhabited, so, we should continue. + if (checkInhabited && normalizer->isInhabited(subTy) == NormalizationResult::False) return; if (reversed) @@ -2113,7 +2388,8 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (!superTable || superTable->state != TableState::Free) return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); - auto fail = [&](std::optional e) { + auto fail = [&](std::optional e) + { std::string reason = "The former's metatable does not satisfy the requirements."; if (e) reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e, mismatchContext()}); @@ -2131,38 +2407,32 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { - TypeId ty = it->second.type; - Unifier child = makeChildUnifier(); - child.tryUnify_(ty, superTy); + TypeId ty = it->second.type(); + std::unique_ptr child = makeChildUnifier(); + child->tryUnify_(ty, superTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table - // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed - // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check - TypeId newSuperTy = child.log.follow(superTy); + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed + // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check + TypeId newSuperTy = child->log.follow(superTy); - if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) - { - log.replace(superTy, BoundType{subTy}); - return; - } + if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) + { + log.replace(superTy, BoundType{subTy}); + return; } - if (auto e = hasUnificationTooComplex(child.errors)) + if (auto e = hasUnificationTooComplex(child->errors)) reportError(*e); - else if (!child.errors.empty()) - fail(child.errors.front()); + else if (!child->errors.empty()) + fail(child->errors.front()); - log.concat(std::move(child.log)); + log.concat(std::move(child->log)); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table - // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype - if (child.errors.empty()) - log.replace(superTy, BoundType{subTy}); - } + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype + if (child->errors.empty()) + log.replace(superTy, BoundType{subTy}); return; } @@ -2189,7 +2459,7 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see result = types->addType(*ttv); TableType* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) - prop.type = deeplyOptional(prop.type, seen); + prop.setType(deeplyOptional(prop.type(), seen)); return types->addType(UnionType{{builtinTypes->nilType, result}}); } else @@ -2206,17 +2476,19 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (const MetatableType* subMetatable = log.getMutable(subTy)) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subMetatable->table, superMetatable->table); - innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subMetatable->table, superMetatable->table); + innerState->tryUnify_(subMetatable->metatable, superMetatable->metatable); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty()) + else if (!innerState->errors.empty()) reportError( - location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}); + location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()} + ); - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (TableType* subTable = log.getMutable(subTy)) { @@ -2224,16 +2496,16 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { case TableState::Free: { - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); bool missingProperty = false; for (const auto& [propName, prop] : subTable->props) { if (std::optional mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) { - innerState.tryUnify(prop.type, *mtPropTy); + innerState->tryUnify(prop.type(), *mtPropTy); } else { @@ -2248,15 +2520,18 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) // TODO: Unify indexers. } - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty()) - reportError(TypeError{location, - TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}}); + else if (!innerState->errors.empty()) + reportError(TypeError{ + location, + TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()} + }); else if (!missingProperty) { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); log.bindTable(subTy, superTy); + failure |= innerState->failure; } } else @@ -2289,7 +2564,8 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (reversed) std::swap(superTy, subTy); - auto fail = [&]() { + auto fail = [&]() + { if (!reversed) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); else @@ -2342,14 +2618,15 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) } else { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(classProp->type, prop.type); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(classProp->type(), prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + checkChildUnifierTypeMismatch(innerState->errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) + if (innerState->errors.empty()) { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else { @@ -2379,15 +2656,15 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) if (!log.get(subTy) && !log.get(superTy)) ice("tryUnifyNegations superTy or subTy must be a negation type"); - const NormalizedType* subNorm = normalizer->normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - return reportError(location, UnificationTooComplex{}); + return reportError(location, NormalizationTooComplex{}); // T state = makeChildUnifier(); + state->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, ""); + if (state->errors.empty()) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } @@ -2401,9 +2678,9 @@ static void queueTypePack(std::vector& queue, DenseHashSet& break; seenTypePacks.insert(a); - if (state.log.getMutable(a)) + if (state.log.getMutable(a)) { - state.log.replace(a, Unifiable::Bound{anyTypePack}); + state.log.replace(a, BoundTypePack{anyTypePack}); } else if (auto tp = state.log.getMutable(a)) { @@ -2419,13 +2696,14 @@ static void queueTypePack(std::vector& queue, DenseHashSet& void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { const VariadicTypePack* superVariadic = log.getMutable(superTp); + const TypeId variadicTy = follow(superVariadic->ty); if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); if (const VariadicTypePack* subVariadic = log.get(subTp)) { - tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); + tryUnify_(reversed ? variadicTy : subVariadic->ty, reversed ? subVariadic->ty : variadicTy); } else if (log.get(subTp)) { @@ -2436,24 +2714,32 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever while (subIter != subEnd) { - tryUnify_(reversed ? superVariadic->ty : *subIter, reversed ? *subIter : superVariadic->ty); + tryUnify_(reversed ? variadicTy : *subIter, reversed ? *subIter : variadicTy); ++subIter; } if (std::optional maybeTail = subIter.tail()) { TypePackId tail = follow(*maybeTail); - if (get(tail)) + + if (isBlocked(log, tail)) + { + blockedTypePacks.push_back(tail); + } + else if (get(tail)) { log.replace(tail, BoundTypePack(superTp)); } else if (const VariadicTypePack* vtp = get(tail)) { - tryUnify_(vtp->ty, superVariadic->ty); + tryUnify_(vtp->ty, variadicTy); } - else if (get(tail)) + else if (get(tail)) { - reportError(location, GenericError{"Cannot unify variadic and generic packs"}); + if (!hideousFixMeGenericsAreActuallyFree) + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); + else + log.replace(tail, BoundTypePack{superTp}); } else if (get(tail)) { @@ -2465,14 +2751,25 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } } } + else if (get(variadicTy) && log.get(subTp)) + { + // Nothing to do. This is ok. + } else { reportError(location, GenericError{"Failed to unify variadic packs"}); } } -static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, - const TypeArena* typeArena, TypeId anyType, TypePackId anyTypePack) +static void tryUnifyWithAny( + std::vector& queue, + Unifier& state, + DenseHashSet& seen, + DenseHashSet& seenTypePacks, + const TypeArena* typeArena, + TypeId anyType, + TypePackId anyTypePack +) { while (!queue.empty()) { @@ -2501,7 +2798,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas else if (auto table = state.log.getMutable(ty)) { for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); if (table->indexer) { @@ -2569,8 +2866,8 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - TxnLog result; + LUAU_ASSERT(useNewSolver); + TxnLog result(useNewSolver); for (TxnLog& log : logs) result.concatAsIntersections(std::move(log), NotNull{types}); return result; @@ -2578,18 +2875,48 @@ TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) TxnLog Unifier::combineLogsIntoUnion(std::vector logs) { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - TxnLog result; + TxnLog result(useNewSolver); for (TxnLog& log : logs) result.concatAsUnion(std::move(log), NotNull{types}); return result; } -bool Unifier::occursCheck(TypeId needle, TypeId haystack) +bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed) { sharedState.tempSeenTy.clear(); - return occursCheck(sharedState.tempSeenTy, needle, haystack); + bool occurs = occursCheck(sharedState.tempSeenTy, needle, haystack); + + if (occurs) + { + std::unique_ptr innerState = makeChildUnifier(); + if (const UnionType* ut = get(haystack)) + { + if (reversed) + innerState->tryUnifyUnionWithType(haystack, ut, needle); + else + innerState->tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false); + } + else if (const IntersectionType* it = get(haystack)) + { + if (reversed) + innerState->tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false); + else + innerState->tryUnifyTypeWithIntersection(needle, haystack, it); + } + else + { + innerState->failure = true; + } + + if (innerState->failure) + { + reportError(location, OccursCheckFailed{}); + log.replace(needle, BoundType{builtinTypes->errorRecoveryType()}); + } + } + + return occurs; } bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) @@ -2598,7 +2925,8 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays bool occurrence = false; - auto check = [&](TypeId tv) { + auto check = [&](TypeId tv) + { if (occursCheck(seen, needle, tv)) occurrence = true; }; @@ -2611,21 +2939,16 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle) && !(hideousFixMeGenericsAreActuallyFree && log.is(needle))) ice("Expected needle to be free"); if (needle == haystack) - { - reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryType()); - return true; - } - if (log.getMutable(haystack)) + if (log.getMutable(haystack) || (hideousFixMeGenericsAreActuallyFree && log.is(haystack))) return false; else if (auto a = log.getMutable(haystack)) { @@ -2641,11 +2964,19 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays return occurrence; } -bool Unifier::occursCheck(TypePackId needle, TypePackId haystack) +bool Unifier::occursCheck(TypePackId needle, TypePackId haystack, bool reversed) { sharedState.tempSeenTp.clear(); - return occursCheck(sharedState.tempSeenTp, needle, haystack); + bool occurs = occursCheck(sharedState.tempSeenTp, needle, haystack); + + if (occurs) + { + reportError(location, OccursCheckFailed{}); + log.replace(needle, BoundTypePack{builtinTypes->errorRecoveryTypePack()}); + } + + return occurs; } bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) @@ -2658,10 +2989,10 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle) && !(hideousFixMeGenericsAreActuallyFree && log.is(needle))) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -2669,12 +3000,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ while (!log.getMutable(haystack)) { if (needle == haystack) - { - reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryTypePack()); - return true; - } if (auto a = get(haystack); a && a->tail) { @@ -2688,12 +3014,15 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ return false; } -Unifier Unifier::makeChildUnifier() +std::unique_ptr Unifier::makeChildUnifier() { - Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; - u.normalize = normalize; - u.checkInhabited = checkInhabited; - u.useScopes = useScopes; + std::unique_ptr u = std::make_unique(normalizer, scope, location, variance, &log); + u->normalize = normalize; + u->checkInhabited = checkInhabited; + + if (useNewSolver) + u->enableNewSolver(); + return u; } @@ -2704,6 +3033,7 @@ Unifier Unifier::makeChildUnifier() void Unifier::reportError(Location location, TypeErrorData data) { errors.emplace_back(std::move(location), std::move(data)); + failure = true; } // A utility function that appends the given error to the unifier's error log. @@ -2714,12 +3044,7 @@ void Unifier::reportError(Location location, TypeErrorData data) void Unifier::reportError(TypeError err) { errors.push_back(std::move(err)); -} - - -bool Unifier::isNonstrictMode() const -{ - return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); + failure = true; } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) @@ -2735,8 +3060,10 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError(TypeError{location, - TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}}); + reportError(TypeError{ + location, + TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()} + }); } void Unifier::ice(const std::string& message, const Location& location) diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp new file mode 100644 index 000000000..5ea11ad06 --- /dev/null +++ b/Analysis/src/Unifier2.cpp @@ -0,0 +1,928 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Unifier2.h" + +#include "Luau/Instantiation.h" +#include "Luau/Scope.h" +#include "Luau/Simplify.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/VisitType.h" + +#include +#include + +LUAU_FASTINT(LuauTypeInferRecursionLimit) + +namespace Luau +{ + +static bool areCompatible(TypeId left, TypeId right) +{ + auto p = get2(follow(left), follow(right)); + if (!p) + return true; + + const TableType* leftTable = p.first; + LUAU_ASSERT(leftTable); + const TableType* rightTable = p.second; + LUAU_ASSERT(rightTable); + + const auto missingPropIsCompatible = [](const Property& leftProp, const TableType* rightTable) + { + // Two tables may be compatible even if their shapes aren't exactly the + // same if the extra property is optional, free (and therefore + // potentially optional), or if the right table has an indexer. Or if + // the right table is free (and therefore potentially has an indexer or + // a compatible property) + + LUAU_ASSERT(leftProp.isReadOnly() || leftProp.isShared()); + + const TypeId leftType = follow(leftProp.isReadOnly() ? *leftProp.readTy : leftProp.type()); + + if (isOptional(leftType) || get(leftType) || rightTable->state == TableState::Free || rightTable->indexer.has_value()) + return true; + + return false; + }; + + for (const auto& [name, leftProp] : leftTable->props) + { + auto it = rightTable->props.find(name); + if (it == rightTable->props.end()) + { + if (!missingPropIsCompatible(leftProp, rightTable)) + return false; + } + } + + for (const auto& [name, rightProp] : rightTable->props) + { + auto it = leftTable->props.find(name); + if (it == leftTable->props.end()) + { + if (!missingPropIsCompatible(rightProp, leftTable)) + return false; + } + } + + return true; +} + +// returns `true` if `ty` is irressolvable and should be added to `incompleteSubtypes`. +static bool isIrresolvable(TypeId ty) +{ + return get(ty) || get(ty); +} + +// returns `true` if `tp` is irressolvable and should be added to `incompleteSubtypes`. +static bool isIrresolvable(TypePackId tp) +{ + return get(tp) || get(tp); +} + +Unifier2::Unifier2(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull ice) + : arena(arena) + , builtinTypes(builtinTypes) + , scope(scope) + , ice(ice) + , limits(TypeCheckLimits{}) // TODO: typecheck limits in unifier2 + , recursionLimit(FInt::LuauTypeInferRecursionLimit) + , uninhabitedTypeFunctions(nullptr) +{ +} + +Unifier2::Unifier2( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull ice, + DenseHashSet* uninhabitedTypeFunctions +) + : arena(arena) + , builtinTypes(builtinTypes) + , scope(scope) + , ice(ice) + , limits(TypeCheckLimits{}) // TODO: typecheck limits in unifier2 + , recursionLimit(FInt::LuauTypeInferRecursionLimit) + , uninhabitedTypeFunctions(uninhabitedTypeFunctions) +{ +} + +bool Unifier2::unify(TypeId subTy, TypeId superTy) +{ + subTy = follow(subTy); + superTy = follow(superTy); + + if (auto subGen = genericSubstitutions.find(subTy)) + return unify(*subGen, superTy); + + if (auto superGen = genericSubstitutions.find(superTy)) + return unify(subTy, *superGen); + + if (seenTypePairings.contains({subTy, superTy})) + return true; + seenTypePairings.insert({subTy, superTy}); + + if (subTy == superTy) + return true; + + // We have potentially done some unifications while dispatching either `SubtypeConstraint` or `PackSubtypeConstraint`, + // so rather than implementing backtracking or traversing the entire type graph multiple times, we could push + // additional constraints as we discover blocked types along with their proper bounds. + // + // But we exclude these two subtyping patterns, they are tautological: + // - never <: *blocked* + // - *blocked* <: unknown + if ((isIrresolvable(subTy) || isIrresolvable(superTy)) && !get(subTy) && !get(superTy)) + { + if (uninhabitedTypeFunctions && (uninhabitedTypeFunctions->contains(subTy) || uninhabitedTypeFunctions->contains(superTy))) + return true; + + incompleteSubtypes.push_back(SubtypeConstraint{subTy, superTy}); + return true; + } + + FreeType* subFree = getMutable(subTy); + FreeType* superFree = getMutable(superTy); + + if (superFree) + { + superFree->lowerBound = mkUnion(superFree->lowerBound, subTy); + } + + if (subFree) + { + return unifyFreeWithType(subTy, superTy); + } + + if (subFree || superFree) + return true; + + auto subFn = get(subTy); + auto superFn = get(superTy); + if (subFn && superFn) + return unify(subTy, superFn); + + auto subUnion = get(subTy); + auto superUnion = get(superTy); + if (subUnion) + return unify(subUnion, superTy); + else if (superUnion) + return unify(subTy, superUnion); + + auto subIntersection = get(subTy); + auto superIntersection = get(superTy); + if (subIntersection) + return unify(subIntersection, superTy); + else if (superIntersection) + return unify(subTy, superIntersection); + + auto subNever = get(subTy); + auto superNever = get(superTy); + if (subNever && superNever) + return true; + else if (subNever && superFn) + { + // If `never` is the subtype, then we can propagate that inward. + bool argResult = unify(superFn->argTypes, builtinTypes->neverTypePack); + bool retResult = unify(builtinTypes->neverTypePack, superFn->retTypes); + return argResult && retResult; + } + else if (subFn && superNever) + { + // If `never` is the supertype, then we can propagate that inward. + bool argResult = unify(builtinTypes->neverTypePack, subFn->argTypes); + bool retResult = unify(subFn->retTypes, builtinTypes->neverTypePack); + return argResult && retResult; + } + + auto subAny = get(subTy); + auto superAny = get(superTy); + + auto subTable = getMutable(subTy); + auto superTable = get(superTy); + + if (subAny && superAny) + return true; + else if (subAny && superFn) + return unify(subAny, superFn); + else if (subFn && superAny) + return unify(subFn, superAny); + else if (subAny && superTable) + return unify(subAny, superTable); + else if (subTable && superAny) + return unify(subTable, superAny); + + if (subTable && superTable) + { + // `boundTo` works like a bound type, and therefore we'd replace it + // with the `boundTo` and try unification again. + // + // However, these pointers should have been chased already by follow(). + LUAU_ASSERT(!subTable->boundTo); + LUAU_ASSERT(!superTable->boundTo); + + return unify(subTable, superTable); + } + + auto subMetatable = get(subTy); + auto superMetatable = get(superTy); + if (subMetatable && superMetatable) + return unify(subMetatable, superMetatable); + else if (subMetatable) // if we only have one metatable, unify with the inner table + return unify(subMetatable->table, superTy); + else if (superMetatable) // if we only have one metatable, unify with the inner table + return unify(subTy, superMetatable->table); + + auto [subNegation, superNegation] = get2(subTy, superTy); + if (subNegation && superNegation) + return unify(subNegation->ty, superNegation->ty); + + // The unification failed, but we're not doing type checking. + return true; +} + +// If superTy is a function and subTy already has a +// potentially-compatible function in its upper bound, we assume that +// the function is not overloaded and attempt to combine superTy into +// subTy's existing function bound. +bool Unifier2::unifyFreeWithType(TypeId subTy, TypeId superTy) +{ + FreeType* subFree = getMutable(subTy); + LUAU_ASSERT(subFree); + + auto doDefault = [&]() + { + subFree->upperBound = mkIntersection(subFree->upperBound, superTy); + expandedFreeTypes[subTy].push_back(superTy); + return true; + }; + + TypeId upperBound = follow(subFree->upperBound); + + if (get(upperBound)) + return unify(subFree->upperBound, superTy); + + const FunctionType* superFunction = get(superTy); + if (!superFunction) + return doDefault(); + + const auto [superArgHead, superArgTail] = flatten(superFunction->argTypes); + if (superArgTail) + return doDefault(); + + const IntersectionType* upperBoundIntersection = get(subFree->upperBound); + if (!upperBoundIntersection) + return doDefault(); + + bool ok = true; + bool foundOne = false; + + for (TypeId part : upperBoundIntersection->parts) + { + const FunctionType* ft = get(follow(part)); + if (!ft) + continue; + + const auto [subArgHead, subArgTail] = flatten(ft->argTypes); + + if (!subArgTail && subArgHead.size() == superArgHead.size()) + { + foundOne = true; + ok &= unify(part, superTy); + } + } + + if (foundOne) + return ok; + else + return doDefault(); +} + +bool Unifier2::unify(TypeId subTy, const FunctionType* superFn) +{ + const FunctionType* subFn = get(subTy); + + bool shouldInstantiate = + (superFn->generics.empty() && !subFn->generics.empty()) || (superFn->genericPacks.empty() && !subFn->genericPacks.empty()); + + if (shouldInstantiate) + { + for (auto generic : subFn->generics) + genericSubstitutions[generic] = freshType(arena, builtinTypes, scope); + + for (auto genericPack : subFn->genericPacks) + genericPackSubstitutions[genericPack] = arena->freshTypePack(scope); + } + + bool argResult = unify(superFn->argTypes, subFn->argTypes); + bool retResult = unify(subFn->retTypes, superFn->retTypes); + return argResult && retResult; +} + +bool Unifier2::unify(const UnionType* subUnion, TypeId superTy) +{ + bool result = true; + + // if the occurs check fails for any option, it fails overall + for (auto subOption : subUnion->options) + { + if (areCompatible(subOption, superTy)) + result &= unify(subOption, superTy); + } + + return result; +} + +bool Unifier2::unify(TypeId subTy, const UnionType* superUnion) +{ + bool result = true; + + // if the occurs check fails for any option, it fails overall + for (auto superOption : superUnion->options) + { + if (areCompatible(subTy, superOption)) + result &= unify(subTy, superOption); + } + + return result; +} + +bool Unifier2::unify(const IntersectionType* subIntersection, TypeId superTy) +{ + bool result = true; + + // if the occurs check fails for any part, it fails overall + for (auto subPart : subIntersection->parts) + result &= unify(subPart, superTy); + + return result; +} + +bool Unifier2::unify(TypeId subTy, const IntersectionType* superIntersection) +{ + bool result = true; + + // if the occurs check fails for any part, it fails overall + for (auto superPart : superIntersection->parts) + result &= unify(subTy, superPart); + + return result; +} + +bool Unifier2::unify(TableType* subTable, const TableType* superTable) +{ + bool result = true; + + // It suffices to only check one direction of properties since we'll only ever have work to do during unification + // if the property is present in both table types. + for (const auto& [propName, subProp] : subTable->props) + { + auto superPropOpt = superTable->props.find(propName); + + if (superPropOpt != superTable->props.end()) + { + const Property& superProp = superPropOpt->second; + + if (subProp.isReadOnly() && superProp.isReadOnly()) + result &= unify(*subProp.readTy, *superPropOpt->second.readTy); + else if (subProp.isReadOnly()) + result &= unify(*subProp.readTy, superProp.type()); + else if (superProp.isReadOnly()) + result &= unify(subProp.type(), *superProp.readTy); + else + { + result &= unify(subProp.type(), superProp.type()); + result &= unify(superProp.type(), subProp.type()); + } + } + } + + auto subTypeParamsIter = subTable->instantiatedTypeParams.begin(); + auto superTypeParamsIter = superTable->instantiatedTypeParams.begin(); + + while (subTypeParamsIter != subTable->instantiatedTypeParams.end() && superTypeParamsIter != superTable->instantiatedTypeParams.end()) + { + result &= unify(*subTypeParamsIter, *superTypeParamsIter); + + subTypeParamsIter++; + superTypeParamsIter++; + } + + auto subTypePackParamsIter = subTable->instantiatedTypePackParams.begin(); + auto superTypePackParamsIter = superTable->instantiatedTypePackParams.begin(); + + while (subTypePackParamsIter != subTable->instantiatedTypePackParams.end() && + superTypePackParamsIter != superTable->instantiatedTypePackParams.end()) + { + result &= unify(*subTypePackParamsIter, *superTypePackParamsIter); + + subTypePackParamsIter++; + superTypePackParamsIter++; + } + + if (subTable->selfTy && superTable->selfTy) + result &= unify(*subTable->selfTy, *superTable->selfTy); + + if (subTable->indexer && superTable->indexer) + { + result &= unify(subTable->indexer->indexType, superTable->indexer->indexType); + result &= unify(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + } + + if (!subTable->indexer && subTable->state == TableState::Unsealed && superTable->indexer) + { + /* + * Unsealed tables are always created from literal table expressions. We + * can't be completely certain whether such a table has an indexer just + * by the content of the expression itself, so we need to be a bit more + * flexible here. + * + * If we are trying to reconcile an unsealed table with a table that has + * an indexer, we therefore conclude that the unsealed table has the + * same indexer. + */ + + TypeId indexType = superTable->indexer->indexType; + if (TypeId* subst = genericSubstitutions.find(indexType)) + indexType = *subst; + + TypeId indexResultType = superTable->indexer->indexResultType; + if (TypeId* subst = genericSubstitutions.find(indexResultType)) + indexResultType = *subst; + + subTable->indexer = TableIndexer{indexType, indexResultType}; + } + + return result; +} + +bool Unifier2::unify(const MetatableType* subMetatable, const MetatableType* superMetatable) +{ + return unify(subMetatable->metatable, superMetatable->metatable) && unify(subMetatable->table, superMetatable->table); +} + +bool Unifier2::unify(const AnyType* subAny, const FunctionType* superFn) +{ + // If `any` is the subtype, then we can propagate that inward. + bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack); + bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes); + return argResult && retResult; +} + +bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny) +{ + // If `any` is the supertype, then we can propagate that inward. + bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes); + bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack); + return argResult && retResult; +} + +bool Unifier2::unify(const AnyType* subAny, const TableType* superTable) +{ + for (const auto& [propName, prop] : superTable->props) + { + if (prop.readTy) + unify(builtinTypes->anyType, *prop.readTy); + + if (prop.writeTy) + unify(*prop.writeTy, builtinTypes->anyType); + } + + if (superTable->indexer) + { + unify(builtinTypes->anyType, superTable->indexer->indexType); + unify(builtinTypes->anyType, superTable->indexer->indexResultType); + } + + return true; +} + +bool Unifier2::unify(const TableType* subTable, const AnyType* superAny) +{ + for (const auto& [propName, prop] : subTable->props) + { + if (prop.readTy) + unify(*prop.readTy, builtinTypes->anyType); + + if (prop.writeTy) + unify(builtinTypes->anyType, *prop.writeTy); + } + + if (subTable->indexer) + { + unify(subTable->indexer->indexType, builtinTypes->anyType); + unify(subTable->indexer->indexResultType, builtinTypes->anyType); + } + + return true; +} + +// FIXME? This should probably return an ErrorVec or an optional +// rather than a boolean to signal an occurs check failure. +bool Unifier2::unify(TypePackId subTp, TypePackId superTp) +{ + subTp = follow(subTp); + superTp = follow(superTp); + + if (auto subGen = genericPackSubstitutions.find(subTp)) + return unify(*subGen, superTp); + + if (auto superGen = genericPackSubstitutions.find(superTp)) + return unify(subTp, *superGen); + + if (seenTypePackPairings.contains({subTp, superTp})) + return true; + seenTypePackPairings.insert({subTp, superTp}); + + if (subTp == superTp) + return true; + + if (isIrresolvable(subTp) || isIrresolvable(superTp)) + { + if (uninhabitedTypeFunctions && (uninhabitedTypeFunctions->contains(subTp) || uninhabitedTypeFunctions->contains(superTp))) + return true; + + incompleteSubtypes.push_back(PackSubtypeConstraint{subTp, superTp}); + return true; + } + + const FreeTypePack* subFree = get(subTp); + const FreeTypePack* superFree = get(superTp); + + if (subFree) + { + DenseHashSet seen{nullptr}; + if (OccursCheckResult::Fail == occursCheck(seen, subTp, superTp)) + { + emplaceTypePack(asMutable(subTp), builtinTypes->errorTypePack); + return false; + } + + emplaceTypePack(asMutable(subTp), superTp); + return true; + } + + if (superFree) + { + DenseHashSet seen{nullptr}; + if (OccursCheckResult::Fail == occursCheck(seen, superTp, subTp)) + { + emplaceTypePack(asMutable(superTp), builtinTypes->errorTypePack); + return false; + } + + emplaceTypePack(asMutable(superTp), subTp); + return true; + } + + size_t maxLength = std::max(flatten(subTp).first.size(), flatten(superTp).first.size()); + + auto [subTypes, subTail] = extendTypePack(*arena, builtinTypes, subTp, maxLength); + auto [superTypes, superTail] = extendTypePack(*arena, builtinTypes, superTp, maxLength); + + // right-pad the subpack with nils if `superPack` is larger since that's what a function call does + if (subTypes.size() < maxLength) + { + for (size_t i = 0; i <= maxLength - subTypes.size(); i++) + subTypes.push_back(builtinTypes->nilType); + } + + if (subTypes.size() < maxLength || superTypes.size() < maxLength) + return true; + + for (size_t i = 0; i < maxLength; ++i) + unify(subTypes[i], superTypes[i]); + + if (subTail && superTail) + { + TypePackId followedSubTail = follow(*subTail); + TypePackId followedSuperTail = follow(*superTail); + + if (get(followedSubTail) || get(followedSuperTail)) + return unify(followedSubTail, followedSuperTail); + } + else if (subTail) + { + TypePackId followedSubTail = follow(*subTail); + if (get(followedSubTail)) + emplaceTypePack(asMutable(followedSubTail), builtinTypes->emptyTypePack); + } + else if (superTail) + { + TypePackId followedSuperTail = follow(*superTail); + if (get(followedSuperTail)) + emplaceTypePack(asMutable(followedSuperTail), builtinTypes->emptyTypePack); + } + + return true; +} + +struct FreeTypeSearcher : TypeVisitor +{ + NotNull scope; + + explicit FreeTypeSearcher(NotNull scope) + : TypeVisitor(/*skipBoundTypes*/ true) + , scope(scope) + { + } + + enum Polarity + { + Positive, + Negative, + Both, + }; + + Polarity polarity = Positive; + + void flip() + { + switch (polarity) + { + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; + case Both: + break; + } + } + + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + + // The keys in these maps are either TypeIds or TypePackIds. It's safe to + // mix them because we only use these pointers as unique keys. We never + // indirect them. + DenseHashMap negativeTypes{0}; + DenseHashMap positiveTypes{0}; + + bool visit(TypeId ty) override + { + if (seenWithPolarity(ty)) + return false; + + LUAU_ASSERT(ty); + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (seenWithPolarity(ty)) + return false; + + if (!subsumes(scope, ft.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + + return true; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (seenWithPolarity(ty)) + return false; + + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + { + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + } + + for (const auto& [_name, prop] : tt.props) + { + if (prop.isReadOnly()) + traverse(*prop.readTy); + else + { + LUAU_ASSERT(prop.isShared()); + + Polarity p = polarity; + polarity = Both; + traverse(prop.type()); + polarity = p; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } + + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (seenWithPolarity(ty)) + return false; + + flip(); + traverse(ft.argTypes); + flip(); + + traverse(ft.retTypes); + + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (seenWithPolarity(tp)) + return false; + + if (!subsumes(scope, ftp.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[tp]++; + break; + case Negative: + negativeTypes[tp]++; + break; + case Both: + positiveTypes[tp]++; + negativeTypes[tp]++; + break; + } + + return true; + } +}; + +TypeId Unifier2::mkUnion(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + return simplifyUnion(builtinTypes, arena, left, right).result; +} + +TypeId Unifier2::mkIntersection(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + return simplifyIntersection(builtinTypes, arena, left, right).result; +} + +OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) +{ + RecursionLimiter _ra(&recursionCount, recursionLimit); + + OccursCheckResult occurrence = OccursCheckResult::Pass; + + auto check = [&](TypeId ty) + { + if (occursCheck(seen, needle, ty) == OccursCheckResult::Fail) + occurrence = OccursCheckResult::Fail; + }; + + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return OccursCheckResult::Pass; + + seen.insert(haystack); + + if (get(needle)) + return OccursCheckResult::Pass; + + if (!get(needle)) + ice->ice("Expected needle to be free"); + + if (needle == haystack) + return OccursCheckResult::Fail; + + if (auto haystackFree = get(haystack)) + { + check(haystackFree->lowerBound); + check(haystackFree->upperBound); + } + else if (auto ut = get(haystack)) + { + for (TypeId ty : ut->options) + check(ty); + } + else if (auto it = get(haystack)) + { + for (TypeId ty : it->parts) + check(ty); + } + + return occurrence; +} + +OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) +{ + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return OccursCheckResult::Pass; + + seen.insert(haystack); + + if (getMutable(needle)) + return OccursCheckResult::Pass; + + if (!getMutable(needle)) + ice->ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&recursionCount, recursionLimit); + + while (!getMutable(haystack)) + { + if (needle == haystack) + return OccursCheckResult::Fail; + + if (auto a = get(haystack); a && a->tail) + { + haystack = follow(*a->tail); + continue; + } + + break; + } + + return OccursCheckResult::Pass; +} + +} // namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 783e37b7f..fdef14c07 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -3,11 +3,13 @@ #include "Luau/Location.h" +#include #include #include #include #include +#include namespace Luau { @@ -58,6 +60,8 @@ class AstStat; class AstStatBlock; class AstExpr; class AstTypePack; +class AstAttr; +class AstExprTable; struct AstLocal { @@ -90,10 +94,21 @@ struct AstArray { return data; } + const T* end() const { return data + size; } + + std::reverse_iterator rbegin() const + { + return std::make_reverse_iterator(end()); + } + + std::reverse_iterator rend() const + { + return std::make_reverse_iterator(begin()); + } }; struct AstTypeList @@ -159,6 +174,10 @@ class AstNode { return nullptr; } + virtual AstAttr* asAttr() + { + return nullptr; + } template bool is() const @@ -180,6 +199,29 @@ class AstNode Location location; }; +class AstAttr : public AstNode +{ +public: + LUAU_RTTI(AstAttr) + + enum Type + { + Checked, + Native, + }; + + AstAttr(const Location& location, Type type); + + AstAttr* asAttr() override + { + return this; + } + + void visit(AstVisitor* visitor) override; + + Type type; +}; + class AstExpr : public AstNode { public: @@ -248,6 +290,7 @@ class AstExprConstantBool : public AstExpr enum class ConstantNumberParseResult { Ok, + Imprecise, Malformed, BinOverflow, HexOverflow, @@ -271,11 +314,18 @@ class AstExprConstantString : public AstExpr public: LUAU_RTTI(AstExprConstantString) - AstExprConstantString(const Location& location, const AstArray& value); + enum QuoteStyle + { + Quoted, + Unquoted + }; + + AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle = Quoted); void visit(AstVisitor* visitor) override; AstArray value; + QuoteStyle quoteStyle = Quoted; }; class AstExprLocal : public AstExpr @@ -334,7 +384,13 @@ class AstExprIndexName : public AstExpr LUAU_RTTI(AstExprIndexName) AstExprIndexName( - const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op); + const Location& location, + AstExpr* expr, + const AstName& index, + const Location& indexLocation, + const Position& opPosition, + char op + ); void visit(AstVisitor* visitor) override; @@ -363,13 +419,28 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, - bool hasEnd = false, const std::optional& argLocation = std::nullopt); + AstExprFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + AstLocal* self, + const AstArray& args, + bool vararg, + const Location& varargLocation, + AstStatBlock* body, + size_t functionDepth, + const AstName& debugname, + const std::optional& returnAnnotation = {}, + AstTypePack* varargAnnotation = nullptr, + const std::optional& argLocation = std::nullopt + ); void visit(AstVisitor* visitor) override; + bool hasNativeAttribute() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstLocal* self; @@ -385,7 +456,6 @@ class AstExprFunction : public AstExpr AstName debugname; - bool hasEnd = false; std::optional argLocation; }; @@ -453,6 +523,7 @@ class AstExprBinary : public AstExpr Sub, Mul, Div, + FloorDiv, Mod, Pow, Concat, @@ -465,14 +536,15 @@ class AstExprBinary : public AstExpr And, Or, //GIDEROS ADDED - DivInt, MaxOf, MinOf, BinAnd, BinOr, BinXor, BinShiftR, - BinShiftL + BinShiftL, + + Op__Count }; AstExprBinary(const Location& location, Op op, AstExpr* left, AstExpr* right); @@ -536,11 +608,23 @@ class AstStatBlock : public AstStat public: LUAU_RTTI(AstStatBlock) - AstStatBlock(const Location& location, const AstArray& body); + AstStatBlock(const Location& location, const AstArray& body, bool hasEnd = true); void visit(AstVisitor* visitor) override; AstArray body; + + /* Indicates whether or not this block has been terminated in a + * syntactically valid way. + * + * This is usually but not always done with the 'end' keyword. AstStatIf + * and AstStatRepeat are the two main exceptions to this. + * + * The 'then' clause of an if statement can properly be closed by the + * keywords 'else' or 'elseif'. A 'repeat' loop's body is closed with the + * 'until' keyword. + */ + bool hasEnd = false; }; class AstStatIf : public AstStat @@ -548,8 +632,14 @@ class AstStatIf : public AstStat public: LUAU_RTTI(AstStatIf) - AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional& thenLocation, - const std::optional& elseLocation, bool hasEnd); + AstStatIf( + const Location& location, + AstExpr* condition, + AstStatBlock* thenbody, + AstStat* elsebody, + const std::optional& thenLocation, + const std::optional& elseLocation + ); void visit(AstVisitor* visitor) override; @@ -561,8 +651,6 @@ class AstStatIf : public AstStat // Active for 'elseif' as well std::optional elseLocation; - - bool hasEnd = false; }; class AstStatWhile : public AstStat @@ -570,7 +658,7 @@ class AstStatWhile : public AstStat public: LUAU_RTTI(AstStatWhile) - AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool hasEnd); + AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -579,8 +667,6 @@ class AstStatWhile : public AstStat bool hasDo = false; Location doLocation; - - bool hasEnd = false; }; class AstStatRepeat : public AstStat @@ -588,14 +674,14 @@ class AstStatRepeat : public AstStat public: LUAU_RTTI(AstStatRepeat) - AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasUntil); + AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool DEPRECATED_hasUntil); void visit(AstVisitor* visitor) override; AstExpr* condition; AstStatBlock* body; - bool hasUntil = false; + bool DEPRECATED_hasUntil = false; }; class AstStatBreak : public AstStat @@ -647,8 +733,12 @@ class AstStatLocal : public AstStat public: LUAU_RTTI(AstStatLocal) - AstStatLocal(const Location& location, const AstArray& vars, const AstArray& values, - const std::optional& equalsSignLocation); + AstStatLocal( + const Location& location, + const AstArray& vars, + const AstArray& values, + const std::optional& equalsSignLocation + ); void visit(AstVisitor* visitor) override; @@ -663,8 +753,16 @@ class AstStatFor : public AstStat public: LUAU_RTTI(AstStatFor) - AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool hasEnd); + AstStatFor( + const Location& location, + AstLocal* var, + AstExpr* from, + AstExpr* to, + AstExpr* step, + AstStatBlock* body, + bool hasDo, + const Location& doLocation + ); void visit(AstVisitor* visitor) override; @@ -676,8 +774,6 @@ class AstStatFor : public AstStat bool hasDo = false; Location doLocation; - - bool hasEnd = false; }; class AstStatForIn : public AstStat @@ -685,8 +781,16 @@ class AstStatForIn : public AstStat public: LUAU_RTTI(AstStatForIn) - AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, bool hasIn, - const Location& inLocation, bool hasDo, const Location& doLocation, bool hasEnd); + AstStatForIn( + const Location& location, + const AstArray& vars, + const AstArray& values, + AstStatBlock* body, + bool hasIn, + const Location& inLocation, + bool hasDo, + const Location& doLocation + ); void visit(AstVisitor* visitor) override; @@ -699,8 +803,6 @@ class AstStatForIn : public AstStat bool hasDo = false; Location doLocation; - - bool hasEnd = false; }; class AstStatAssign : public AstStat @@ -761,8 +863,15 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, const AstArray& generics, - const AstArray& genericPacks, AstType* type, bool exported); + AstStatTypeAlias( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + AstType* type, + bool exported + ); void visit(AstVisitor* visitor) override; @@ -774,16 +883,31 @@ class AstStatTypeAlias : public AstStat bool exported; }; +class AstStatTypeFunction : public AstStat +{ +public: + LUAU_RTTI(AstStatTypeFunction); + + AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); + + void visit(AstVisitor* visitor) override; + + AstName name; + Location nameLocation; + AstExprFunction* body; +}; + class AstStatDeclareGlobal : public AstStat { public: LUAU_RTTI(AstStatDeclareGlobal) - AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type); + AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type); void visit(AstVisitor* visitor) override; AstName name; + Location nameLocation; AstType* type; }; @@ -792,25 +916,74 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes); + AstStatDeclareFunction( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes + ); + + AstStatDeclareFunction( + const Location& location, + const AstArray& attributes, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes + ); + void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstName name; + Location nameLocation; AstArray generics; AstArray genericPacks; AstTypeList params; AstArray paramNames; + bool vararg = false; + Location varargLocation; AstTypeList retTypes; }; struct AstDeclaredClassProp { AstName name; + Location nameLocation; AstType* ty = nullptr; bool isMethod = false; + Location location; +}; + +enum class AstTableAccess +{ + Read = 0b01, + Write = 0b10, + ReadWrite = 0b11, +}; + +struct AstTableIndexer +{ + AstType* indexType; + AstType* resultType; + Location location; + + AstTableAccess access = AstTableAccess::ReadWrite; + std::optional accessLocation; }; class AstStatDeclareClass : public AstStat @@ -818,7 +991,13 @@ class AstStatDeclareClass : public AstStat public: LUAU_RTTI(AstStatDeclareClass) - AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props); + AstStatDeclareClass( + const Location& location, + const AstName& name, + std::optional superName, + const AstArray& props, + AstTableIndexer* indexer = nullptr + ); void visit(AstVisitor* visitor) override; @@ -826,6 +1005,7 @@ class AstStatDeclareClass : public AstStat std::optional superName; AstArray props; + AstTableIndexer* indexer; }; class AstType : public AstNode @@ -854,14 +1034,23 @@ class AstTypeReference : public AstType public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, bool hasParameterList = false, - const AstArray& parameters = {}); + AstTypeReference( + const Location& location, + std::optional prefix, + AstName name, + std::optional prefixLocation, + const Location& nameLocation, + bool hasParameterList = false, + const AstArray& parameters = {} + ); void visit(AstVisitor* visitor) override; bool hasParameterList; std::optional prefix; + std::optional prefixLocation; AstName name; + Location nameLocation; AstArray parameters; }; @@ -870,13 +1059,8 @@ struct AstTableProp AstName name; Location location; AstType* type; -}; - -struct AstTableIndexer -{ - AstType* indexType; - AstType* resultType; - Location location; + AstTableAccess access = AstTableAccess::ReadWrite; + std::optional accessLocation; }; class AstTypeTable : public AstType @@ -898,11 +1082,30 @@ class AstTypeFunction : public AstType public: LUAU_RTTI(AstTypeFunction) - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); + AstTypeFunction( + const Location& location, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes + ); + + AstTypeFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes + ); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstTypeList argTypes; @@ -1066,6 +1269,11 @@ class AstVisitor return true; } + virtual bool visit(class AstAttr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); @@ -1295,6 +1503,7 @@ class AstVisitor } }; +bool isLValue(const AstExpr*); AstName getIdentifier(AstExpr*); Location getLocation(const AstTypeList& typeList); diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index be317cf2f..fc77dacf1 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -62,6 +62,7 @@ struct Lexeme Dot3, SkinnyArrow, DoubleColon, + FloorDiv, InterpStringBegin, InterpStringMid, @@ -73,6 +74,7 @@ struct Lexeme SubAssign, MulAssign, DivAssign, + FloorDivAssign, ModAssign, PowAssign, ConcatAssign, @@ -93,11 +95,12 @@ struct Lexeme Comment, BlockComment, + Attribute, + BrokenString, BrokenComment, BrokenUnicode, BrokenInterpDoubleBrace, - Error, Reserved_BEGIN, @@ -127,8 +130,15 @@ struct Lexeme Type type; Location location; + + // Field declared here, before the union, to ensure that Lexeme size is 32 bytes. +private: + // length is used to extract a slice from the input buffer. + // This field is only valid for certain lexeme types which don't duplicate portions of input + // but instead store a pointer to a location in the input buffer and the length of lexeme. unsigned int length; +public: union { const char* data; // String, Number, Comment @@ -141,9 +151,13 @@ struct Lexeme Lexeme(const Location& location, Type type, const char* data, size_t size); Lexeme(const Location& location, Type type, const char* name); + unsigned int getLength() const; + std::string toString() const; }; +static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes."); + class AstNameTable { public: @@ -212,7 +226,9 @@ class Lexer Position position() const; + // consume() assumes current character is not a newline for performance; when that is not known, consumeAny() should be used instead. void consume(); + void consumeAny(); Lexeme readCommentBody(); diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index dbe36becb..3fc8921a5 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -8,7 +8,11 @@ struct Position { unsigned int line, column; - Position(unsigned int line, unsigned int column); + Position(unsigned int line, unsigned int column) + : line(line) + , column(column) + { + } bool operator==(const Position& rhs) const; bool operator!=(const Position& rhs) const; @@ -24,10 +28,29 @@ struct Location { Position begin, end; - Location(); - Location(const Position& begin, const Position& end); - Location(const Position& begin, unsigned int length); - Location(const Location& begin, const Location& end); + Location() + : begin(0, 0) + , end(0, 0) + { + } + + Location(const Position& begin, const Position& end) + : begin(begin) + , end(end) + { + } + + Location(const Position& begin, unsigned int length) + : begin(begin) + , end(begin.line, begin.column + length) + { + } + + Location(const Location& begin, const Location& end) + : begin(begin.begin) + , end(end.end) + { + } bool operator==(const Location& rhs) const; bool operator!=(const Location& rhs) const; diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index 89e79528b..804d16fca 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -1,6 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +#include + namespace Luau { @@ -12,12 +17,17 @@ enum class Mode Definition, // Type definition module, has special parsing rules }; +struct FragmentParseResumeSettings +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; +}; + struct ParseOptions { - bool allowTypeAnnotations = true; - bool supportContinueStatement = true; bool allowDeclarationSyntax = false; bool captureComments = false; + std::optional parseFragment = std::nullopt; }; } // namespace Luau diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 8b7eb73cf..5411379e3 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -55,7 +55,12 @@ class Parser { public: static ParseResult parse( - const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions()); + const char* buffer, + std::size_t bufferSize, + AstNameTable& names, + Allocator& allocator, + ParseOptions options = ParseOptions() + ); private: struct Name; @@ -82,8 +87,8 @@ class Parser // if exp then block {elseif exp then block} [else block] end | // for Name `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | - // function funcname funcbody | - // local function Name funcbody | + // [attributes] function funcname funcbody | + // [attributes] local function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* parseStat(); @@ -114,23 +119,40 @@ class Parser AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); // function funcname funcbody - AstStat* parseFunctionStat(); + LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); + + std::pair validateAttribute(const char* attributeName, const TempVector& attributes); + + // attribute ::= '@' NAME + void parseAttribute(TempVector& attribute); + + // attributes ::= {attribute} + AstArray parseAttributes(); + + // attributes local function Name funcbody + // attributes function funcname funcbody + // attributes `declare function' Name`(' [parlist] `)' [`:` Type] + // declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' + AstStat* parseAttributeStat(); // local function Name funcbody | // local namelist [`=' explist] - AstStat* parseLocal(); + AstStat* parseLocal(const AstArray& attributes); // return [explist] AstStat* parseReturn(); - // type Name `=' typeannotation + // type Name `=' Type AstStat* parseTypeAlias(const Location& start, bool exported); + // type function Name ... end + AstStat* parseTypeFunction(const Location& start, bool exported); + AstDeclaredClassProp parseDeclaredClassMethod(); - // `declare global' Name: typeannotation | - // `declare function' Name`(' [parlist] `)' [`:` TypeAnnotation] - AstStat* parseDeclaration(const Location& start); + // `declare global' Name: Type | + // `declare function' Name`(' [parlist] `)' [`:` Type] + AstStat* parseDeclaration(const Location& start, const AstArray& attributes); // varlist `=' explist AstStat* parseAssignment(AstExpr* initial); @@ -140,29 +162,34 @@ class Parser std::pair> prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args); - // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` TypeAnnotation] + // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); + bool hasself, + const Lexeme& matchFunction, + const AstName& debugname, + const Name* localName, + const AstArray& attributes + ); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); - // binding ::= Name [`:` TypeAnnotation] + // binding ::= Name [`:` Type] Binding parseBinding(); // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); - AstType* parseOptionalTypeAnnotation(); + AstType* parseOptionalType(); - // TypeList ::= TypeAnnotation [`,' TypeList] - // ReturnType ::= TypeAnnotation | `(' TypeList `)' - // TableProp ::= Name `:' TypeAnnotation - // TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation + // TypeList ::= Type [`,' TypeList] + // ReturnType ::= Type | `(' TypeList `)' + // TableProp ::= Name `:' Type + // TableIndexer ::= `[' Type `]' `:' Type // PropList ::= (TableProp | TableIndexer) [`,' PropList] - // TypeAnnotation + // Type // ::= Name // | `nil` // | `{' [PropList] `}' @@ -171,24 +198,32 @@ class Parser // Returns the variadic annotation, if it exists. AstTypePack* parseTypeList(TempVector& result, TempVector>& resultNames); - std::optional parseOptionalReturnTypeAnnotation(); - std::pair parseReturnTypeAnnotation(); + std::optional parseOptionalReturnType(); + std::pair parseReturnType(); - AstTableIndexer* parseTableIndexerAnnotation(); + AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); - AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); + AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); + AstType* parseFunctionTypeTail( + const Lexeme& begin, + const AstArray& attributes, + AstArray generics, + AstArray genericPacks, + AstArray params, + AstArray> paramNames, + AstTypePack* varargAnnotation + ); - AstType* parseTableTypeAnnotation(); - AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstType* parseTableType(bool inDeclarationContext = false); + AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); - AstTypeOrPack parseTypeOrPackAnnotation(); - AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); - AstType* parseTypeAnnotation(); + AstTypeOrPack parseTypeOrPack(); + AstType* parseType(bool inDeclarationContext = false); - AstTypePack* parseTypePackAnnotation(); - AstTypePack* parseVariadicArgumentAnnotation(); + AstTypePack* parseTypePack(); + AstTypePack* parseVariadicArgumentTypePack(); + + AstType* parseTypeSuffix(AstType* type, const Location& begin); static std::optional parseUnaryOp(const Lexeme& l); static std::optional parseBinaryOp(const Lexeme& l); @@ -215,10 +250,10 @@ class Parser // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } AstExpr* parsePrimaryExpr(bool asStatement); - // asexp -> simpleexp [`::' typeAnnotation] + // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); - // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp + // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* parseSimpleExpr(); // args ::= `(' [explist] `)' | tableconstructor | String @@ -244,7 +279,7 @@ class Parser // `<' namelist `>' std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); - // `<' typeAnnotation[, ...] `>' + // `<' Type[, ...] `>' AstArray parseTypeParams(); std::optional> parseCharArray(); @@ -299,16 +334,20 @@ class Parser void reportNameError(const char* context); - AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, - const char* format, ...) LUAU_PRINTF_ATTR(5, 6); + AstStatError* reportStatError( + const Location& location, + const AstArray& expressions, + const AstArray& statements, + const char* format, + ... + ) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) - LUAU_PRINTF_ATTR(4, 5); + AstTypeError* reportTypeError(const Location& location, const AstArray& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error // `astErrorLocation` is associated with the AstTypeError created // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely // define the location (possibly of zero size) where a type annotation is expected. - AstTypeError* reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) + AstTypeError* reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstExpr* reportFunctionArgsError(AstExpr* func, bool self); @@ -384,6 +423,7 @@ class Parser MatchLexeme endMismatchSuspect; std::vector functionStack; + size_t typeFunctionDepth = 0; DenseHashMap localMap; std::vector localStack; @@ -392,6 +432,7 @@ class Parser std::vector matchRecoveryStopOnToken; + std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; std::vector scratchExpr; @@ -401,8 +442,8 @@ class Parser std::vector scratchBinding; std::vector scratchLocal; std::vector scratchTableTypeProps; - std::vector scratchAnnotation; - std::vector scratchTypeOrPackAnnotation; + std::vector scratchType; + std::vector scratchTypeOrPack; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index be2828272..2259f21ce 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -4,8 +4,10 @@ #include "Luau/Common.h" #include +#include #include +#include LUAU_FASTFLAG(DebugLuauTimeTracing) @@ -54,7 +56,7 @@ struct Event struct GlobalContext; struct ThreadContext; -GlobalContext& getGlobalContext(); +std::shared_ptr getGlobalContext(); uint16_t createToken(GlobalContext& context, const char* name, const char* category); uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); @@ -66,7 +68,7 @@ struct ThreadContext ThreadContext() : globalContext(getGlobalContext()) { - threadId = createThread(globalContext, this); + threadId = createThread(*globalContext, this); } ~ThreadContext() @@ -74,16 +76,16 @@ struct ThreadContext if (!events.empty()) flushEvents(); - releaseThread(globalContext, this); + releaseThread(*globalContext, this); } void flushEvents() { - static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + static uint16_t flushToken = createToken(*globalContext, "flushEvents", "TimeTrace"); events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); - TimeTrace::flushEvents(globalContext, threadId, events, data); + TimeTrace::flushEvents(*globalContext, threadId, events, data); events.clear(); data.clear(); @@ -125,7 +127,7 @@ struct ThreadContext events.push_back({EventType::ArgValue, 0, {pos}}); } - GlobalContext& globalContext; + std::shared_ptr globalContext; uint32_t threadId; std::vector events; std::vector data; @@ -133,6 +135,14 @@ struct ThreadContext static constexpr size_t kEventFlushLimit = 8192; }; +using ThreadContextProvider = ThreadContext& (*)(); + +inline ThreadContextProvider& threadContextProvider() +{ + static ThreadContextProvider handler = nullptr; + return handler; +} + ThreadContext& getThreadContext(); struct Scope diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 88c20e001..15f462852 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -3,6 +3,13 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauNativeAttribute); + +// The default value here is 643 because the first release in which this was implemented is 644, +// and actively we want new changes to be off by default until they're enabled consciously. +// The flag is placed in AST project here to be common in all Luau libraries +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeSolverRelease, 643) + namespace Luau { @@ -15,6 +22,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) list.tailType->visit(visitor); } +AstAttr::AstAttr(const Location& location, Type type) + : AstNode(ClassIndex(), location) + , type(type) +{ +} + +void AstAttr::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + int gAstRttiIndex = 0; AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) @@ -62,9 +80,10 @@ void AstExprConstantNumber::visit(AstVisitor* visitor) visitor->visit(this); } -AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value) +AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle) : AstExpr(ClassIndex(), location) , value(value) + , quoteStyle(quoteStyle) { } @@ -127,7 +146,13 @@ void AstExprCall::visit(AstVisitor* visitor) } AstExprIndexName::AstExprIndexName( - const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op) + const Location& location, + AstExpr* expr, + const AstName& index, + const Location& indexLocation, + const Position& opPosition, + char op +) : AstExpr(ClassIndex(), location) , expr(expr) , index(index) @@ -159,11 +184,24 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, - const std::optional& argLocation) +AstExprFunction::AstExprFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + AstLocal* self, + const AstArray& args, + bool vararg, + const Location& varargLocation, + AstStatBlock* body, + size_t functionDepth, + const AstName& debugname, + const std::optional& returnAnnotation, + AstTypePack* varargAnnotation, + const std::optional& argLocation +) : AstExpr(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , self(self) @@ -175,7 +213,6 @@ AstExprFunction::AstExprFunction(const Location& location, const AstArraytype == AstAttr::Type::Native) + return true; + } + return false; +} + AstExprTable::AstExprTable(const Location& location, const AstArray& items) : AstExpr(ClassIndex(), location) , items(items) @@ -285,6 +334,8 @@ std::string toString(AstExprBinary::Op op) return "*"; case AstExprBinary::Div: return "/"; + case AstExprBinary::FloorDiv: + return "//"; case AstExprBinary::Mod: return "%"; case AstExprBinary::Pow: @@ -308,8 +359,6 @@ std::string toString(AstExprBinary::Op op) case AstExprBinary::Or: return "or"; //GIDEROS - case AstExprBinary::DivInt: - return "//"; case AstExprBinary::MinOf: return "><"; case AstExprBinary::MaxOf: @@ -398,9 +447,10 @@ void AstExprError::visit(AstVisitor* visitor) } } -AstStatBlock::AstStatBlock(const Location& location, const AstArray& body) +AstStatBlock::AstStatBlock(const Location& location, const AstArray& body, bool hasEnd) : AstStat(ClassIndex(), location) , body(body) + , hasEnd(hasEnd) { } @@ -413,15 +463,20 @@ void AstStatBlock::visit(AstVisitor* visitor) } } -AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, - const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd) +AstStatIf::AstStatIf( + const Location& location, + AstExpr* condition, + AstStatBlock* thenbody, + AstStat* elsebody, + const std::optional& thenLocation, + const std::optional& elseLocation +) : AstStat(ClassIndex(), location) , condition(condition) , thenbody(thenbody) , elsebody(elsebody) , thenLocation(thenLocation) , elseLocation(elseLocation) - , hasEnd(hasEnd) { } @@ -437,13 +492,12 @@ void AstStatIf::visit(AstVisitor* visitor) } } -AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool hasEnd) +AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation) : AstStat(ClassIndex(), location) , condition(condition) , body(body) , hasDo(hasDo) , doLocation(doLocation) - , hasEnd(hasEnd) { } @@ -456,11 +510,11 @@ void AstStatWhile::visit(AstVisitor* visitor) } } -AstStatRepeat::AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasUntil) +AstStatRepeat::AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool DEPRECATED_hasUntil) : AstStat(ClassIndex(), location) , condition(condition) , body(body) - , hasUntil(hasUntil) + , DEPRECATED_hasUntil(DEPRECATED_hasUntil) { } @@ -521,7 +575,11 @@ void AstStatExpr::visit(AstVisitor* visitor) } AstStatLocal::AstStatLocal( - const Location& location, const AstArray& vars, const AstArray& values, const std::optional& equalsSignLocation) + const Location& location, + const AstArray& vars, + const AstArray& values, + const std::optional& equalsSignLocation +) : AstStat(ClassIndex(), location) , vars(vars) , values(values) @@ -544,8 +602,16 @@ void AstStatLocal::visit(AstVisitor* visitor) } } -AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool hasEnd) +AstStatFor::AstStatFor( + const Location& location, + AstLocal* var, + AstExpr* from, + AstExpr* to, + AstExpr* step, + AstStatBlock* body, + bool hasDo, + const Location& doLocation +) : AstStat(ClassIndex(), location) , var(var) , from(from) @@ -554,7 +620,6 @@ AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, A , body(body) , hasDo(hasDo) , doLocation(doLocation) - , hasEnd(hasEnd) { } @@ -575,8 +640,16 @@ void AstStatFor::visit(AstVisitor* visitor) } } -AstStatForIn::AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, - bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation, bool hasEnd) +AstStatForIn::AstStatForIn( + const Location& location, + const AstArray& vars, + const AstArray& values, + AstStatBlock* body, + bool hasIn, + const Location& inLocation, + bool hasDo, + const Location& doLocation +) : AstStat(ClassIndex(), location) , vars(vars) , values(values) @@ -585,7 +658,6 @@ AstStatForIn::AstStatForIn(const Location& location, const AstArray& , inLocation(inLocation) , hasDo(hasDo) , doLocation(doLocation) - , hasEnd(hasEnd) { } @@ -671,8 +743,15 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, const AstArray& genericPacks, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + AstType* type, + bool exported +) : AstStat(ClassIndex(), location) , name(name) , nameLocation(nameLocation) @@ -703,9 +782,24 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) } } -AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) +AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body) + : AstStat(ClassIndex(), location) + , name(name) + , nameLocation(nameLocation) + , body(body) +{ +} + +void AstStatTypeFunction::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + body->visit(visitor); +} + +AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type) : AstStat(ClassIndex(), location) , name(name) + , nameLocation(nameLocation) , type(type) { } @@ -716,15 +810,55 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes +) + : AstStat(ClassIndex(), location) + , attributes() + , name(name) + , nameLocation(nameLocation) + , generics(generics) + , genericPacks(genericPacks) + , params(params) + , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) + , retTypes(retTypes) +{ +} + +AstStatDeclareFunction::AstStatDeclareFunction( + const Location& location, + const AstArray& attributes, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes +) : AstStat(ClassIndex(), location) + , attributes(attributes) , name(name) + , nameLocation(nameLocation) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) , retTypes(retTypes) { } @@ -738,12 +872,29 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } +bool AstStatDeclareFunction::isCheckedFunction() const +{ + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstStatDeclareClass::AstStatDeclareClass( - const Location& location, const AstName& name, std::optional superName, const AstArray& props) + const Location& location, + const AstName& name, + std::optional superName, + const AstArray& props, + AstTableIndexer* indexer +) : AstStat(ClassIndex(), location) , name(name) , superName(superName) , props(props) + , indexer(indexer) { } @@ -757,7 +908,11 @@ void AstStatDeclareClass::visit(AstVisitor* visitor) } AstStatError::AstStatError( - const Location& location, const AstArray& expressions, const AstArray& statements, unsigned messageIndex) + const Location& location, + const AstArray& expressions, + const AstArray& statements, + unsigned messageIndex +) : AstStat(ClassIndex(), location) , expressions(expressions) , statements(statements) @@ -778,11 +933,20 @@ void AstStatError::visit(AstVisitor* visitor) } AstTypeReference::AstTypeReference( - const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) + const Location& location, + std::optional prefix, + AstName name, + std::optional prefixLocation, + const Location& nameLocation, + bool hasParameterList, + const AstArray& parameters +) : AstType(ClassIndex(), location) , hasParameterList(hasParameterList) , prefix(prefix) + , prefixLocation(prefixLocation) , name(name) + , nameLocation(nameLocation) , parameters(parameters) { } @@ -824,9 +988,36 @@ void AstTypeTable::visit(AstVisitor* visitor) } } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) +AstTypeFunction::AstTypeFunction( + const Location& location, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes +) + : AstType(ClassIndex(), location) + , attributes() + , generics(generics) + , genericPacks(genericPacks) + , argTypes(argTypes) + , argNames(argNames) + , returnTypes(returnTypes) +{ + LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); +} + +AstTypeFunction::AstTypeFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes +) : AstType(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) @@ -845,6 +1036,17 @@ void AstTypeFunction::visit(AstVisitor* visitor) } } +bool AstTypeFunction::isCheckedFunction() const +{ + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) : AstType(ClassIndex(), location) , expr(expr) @@ -967,6 +1169,14 @@ void AstTypePackGeneric::visit(AstVisitor* visitor) visitor->visit(this); } +bool isLValue(const AstExpr* expr) +{ + return expr->is() + || expr->is() + || expr->is() + || expr->is(); +} + AstName getIdentifier(AstExpr* node) { if (AstExprGlobal* expr = node->as()) diff --git a/Ast/src/Confusables.cpp b/Ast/src/Confusables.cpp index 1c792156b..8f7fb56c4 100644 --- a/Ast/src/Confusables.cpp +++ b/Ast/src/Confusables.cpp @@ -1808,9 +1808,15 @@ static const Confusable kConfusables[] = const char* findConfusable(uint32_t codepoint) { - auto it = std::lower_bound(std::begin(kConfusables), std::end(kConfusables), codepoint, [](const Confusable& lhs, uint32_t rhs) { - return lhs.codepoint < rhs; - }); + auto it = std::lower_bound( + std::begin(kConfusables), + std::end(kConfusables), + codepoint, + [](const Confusable& lhs, uint32_t rhs) + { + return lhs.codepoint < rhs; + } + ); return (it != std::end(kConfusables) && it->codepoint == codepoint) ? it->text : nullptr; } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index bdbd51d10..939f37c56 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -1,13 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Lexer.h" +#include "Luau/Common.h" #include "Luau/Confusables.h" #include "Luau/StringUtils.h" #include -LUAU_FASTFLAGVARIABLE(LuauFixInterpStringMid, false) - namespace Luau { @@ -91,8 +90,10 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz , length(unsigned(size)) , data(data) { - LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || - type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); + LUAU_ASSERT( + type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment + ); } Lexeme::Lexeme(const Location& location, Type type, const char* name) @@ -101,11 +102,21 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) , length(0) , name(name) { - LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); + LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); +} + +unsigned int Lexeme::getLength() const +{ + LUAU_ASSERT( + type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment + ); + + return length; } -static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", - "repeat", "return", "then", "true", "until", "while"}; +static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", + "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"}; std::string Lexeme::toString() const { @@ -138,6 +149,9 @@ std::string Lexeme::toString() const case DoubleColon: return "'::'"; + case FloorDiv: + return "'//'"; + case AddAssign: return "'+='"; @@ -150,6 +164,9 @@ std::string Lexeme::toString() const case DivAssign: return "'/='"; + case FloorDivAssign: + return "'//='"; + case ModAssign: return "'%='"; @@ -159,9 +176,6 @@ std::string Lexeme::toString() const case ConcatAssign: return "'..='"; //GIDEROS ADDED - case DivInt: - return "'//'"; - case MaxOf: return "'<>'"; @@ -205,6 +219,9 @@ std::string Lexeme::toString() const case Comment: return "comment"; + case Attribute: + return name ? format("'%s'", name) : "attribute"; + case BrokenString: return "malformed string"; @@ -292,7 +309,7 @@ std::pair AstNameTable::getOrAddWithType(const char* name nameData[length] = 0; const_cast(entry).value = AstName(nameData); - const_cast(entry).type = Lexeme::Name; + const_cast(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name); return std::make_pair(entry.value, entry.type); } @@ -398,7 +415,7 @@ const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { // consume whitespace before the token while (isSpace(peekch())) - consume(); + consumeAny(); if (updatePrevLocation) prevLocation = lexeme.location; @@ -425,6 +442,8 @@ Lexeme Lexer::lookahead() unsigned int currentLineOffset = lineOffset; Lexeme currentLexeme = lexeme; Location currentPrevLocation = prevLocation; + size_t currentBraceStackSize = braceStack.size(); + BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back(); Lexeme result = next(); @@ -434,6 +453,11 @@ Lexeme Lexer::lookahead() lexeme = currentLexeme; prevLocation = currentPrevLocation; + if (braceStack.size() < currentBraceStackSize) + braceStack.push_back(currentBraceType); + else if (braceStack.size() > currentBraceStackSize) + braceStack.pop_back(); + return result; } @@ -463,7 +487,17 @@ Position Lexer::position() const return Position(line, offset - lineOffset); } +LUAU_FORCEINLINE void Lexer::consume() +{ + // consume() assumes current character is known to not be a newline; use consumeAny if this is not guaranteed + LUAU_ASSERT(!isNewline(buffer[offset])); + + offset++; +} + +LUAU_FORCEINLINE +void Lexer::consumeAny() { if (isNewline(buffer[offset])) { @@ -549,7 +583,7 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le } else { - consume(); + consumeAny(); } } @@ -565,7 +599,7 @@ void Lexer::readBackslashInString() case '\r': consume(); if (peekch() == '\n') - consume(); + consumeAny(); break; case 0: @@ -574,11 +608,11 @@ void Lexer::readBackslashInString() case 'z': consume(); while (isSpace(peekch())) - consume(); + consumeAny(); break; default: - consume(); + consumeAny(); } } @@ -665,9 +699,7 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT } consume(); - Lexeme lexemeOutput(Location(start, position()), FFlag::LuauFixInterpStringMid ? formatType : Lexeme::InterpStringBegin, - &buffer[startOffset], offset - startOffset - 1); - return lexemeOutput; + return Lexeme(Location(start, position()), formatType, &buffer[startOffset], offset - startOffset - 1); } default: @@ -708,7 +740,7 @@ Lexeme Lexer::readNumber(const Position& start, unsigned int startOffset) std::pair Lexer::readName() { - LUAU_ASSERT(isAlpha(peekch()) || peekch() == '_'); + LUAU_ASSERT(isAlpha(peekch()) || peekch() == '_' || peekch() == '@'); unsigned int startOffset = offset; @@ -925,20 +957,31 @@ Lexeme Lexer::readNext() return Lexeme(Location(start, 1), '+'); case '/': + { consume(); - if (peekch() == '=') + char ch = peekch(); + + if (ch == '=') { consume(); return Lexeme(Location(start, 2), Lexeme::DivAssign); } - else if (peekch() == '/') + else if (ch == '/') { - consume(); - return Lexeme(Location(start, 2), Lexeme::DivInt); + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 3), Lexeme::FloorDivAssign); + } + else + return Lexeme(Location(start, 2), Lexeme::FloorDiv); } else return Lexeme(Location(start, 1), '/'); + } case '*': consume(); @@ -1001,13 +1044,20 @@ Lexeme Lexer::readNext() case ';': case ',': case '#': + case '?': + case '&': + case '|': { char ch = peekch(); consume(); return Lexeme(Location(start, 1), ch); } - + case '@': + { + std::pair attribute = readName(); + return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value); + } default: if (isDigit(peekch())) { diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index d01d8a186..c2c66d9f2 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -4,12 +4,6 @@ namespace Luau { -Position::Position(unsigned int line, unsigned int column) - : line(line) - , column(column) -{ -} - bool Position::operator==(const Position& rhs) const { return this->column == rhs.column && this->line == rhs.line; @@ -60,30 +54,6 @@ void Position::shift(const Position& start, const Position& oldEnd, const Positi } } -Location::Location() - : begin(0, 0) - , end(0, 0) -{ -} - -Location::Location(const Position& begin, const Position& end) - : begin(begin) - , end(end) -{ -} - -Location::Location(const Position& begin, unsigned int length) - : begin(begin) - , end(begin.line, begin.column + length) -{ -} - -Location::Location(const Location& begin, const Location& end) - : begin(begin.begin) - , end(end.end) -{ -} - bool Location::operator==(const Location& rhs) const { return this->begin == rhs.begin && this->end == rhs.end; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index d810c5e2d..8fd3fbf52 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Common.h" #include "Luau/TimeTrace.h" #include @@ -11,19 +12,30 @@ #include #include -// Warning: If you are introducing new syntax, ensure that it is behind a separate -// flag so that we don't break production games by reverting syntax changes. -// See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) +LUAU_FASTINTVARIABLE(LuauTypeLengthLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) - -#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" +// Warning: If you are introducing new syntax, ensure that it is behind a separate +// flag so that we don't break production games by reverting syntax changes. +// See docs/SyntaxChanges.md for an explanation. +LUAU_FASTFLAGVARIABLE(LuauSolverV2, false) +LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2, false) +LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing, false) namespace Luau { +struct AttributeEntry +{ + const char* name; + AstAttr::Type type; +}; + +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {"@native", AstAttr::Type::Native}, {nullptr, AstAttr::Type::Checked}}; + ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -133,7 +145,7 @@ void TempVector::push_back(const T& item) size_++; } -static bool shouldParseTypePackAnnotation(Lexer& lexer) +static bool shouldParseTypePack(Lexer& lexer) { if (lexer.current().type == Lexeme::Dot3) return true; @@ -179,9 +191,18 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc functionStack.reserve(8); functionStack.push_back(top); - nameSelf = names.addStatic("self"); - nameNumber = names.addStatic("number"); - nameError = names.addStatic(kParseNameError); + if (FFlag::LuauAllowFragmentParsing) + { + nameSelf = names.getOrAdd("self"); + nameNumber = names.getOrAdd("number"); + nameError = names.getOrAdd(kParseNameError); + } + else + { + nameSelf = names.addStatic("self"); + nameNumber = names.addStatic("number"); + nameError = names.addStatic(kParseNameError); + } nameNil = names.getOrAdd("nil"); // nil is a reserved keyword matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); @@ -203,6 +224,15 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc scratchExpr.reserve(16); scratchLocal.reserve(16); scratchBinding.reserve(16); + + if (FFlag::LuauAllowFragmentParsing) + { + if (options.parseFragment) + { + localMap = options.parseFragment->localMap; + localStack = options.parseFragment->localStack; + } + } } bool Parser::blockFollow(const Lexeme& l) @@ -247,13 +277,13 @@ AstStatBlock* Parser::parseBlockNoScope() while (!blockFollow(lexer.current())) { - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("block"); AstStat* stat = parseStat(); - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; if (lexer.current().type == ';') { @@ -282,7 +312,9 @@ AstStatBlock* Parser::parseBlockNoScope() // for binding `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | // function funcname funcbody | +// attributes function funcname funcbody | // local function Name funcbody | +// local attributes function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* Parser::parseStat() @@ -301,13 +333,15 @@ AstStat* Parser::parseStat() case Lexeme::ReservedRepeat: return parseRepeat(); case Lexeme::ReservedFunction: - return parseFunctionStat(); + return parseFunctionStat(AstArray({nullptr, 0})); case Lexeme::ReservedLocal: - return parseLocal(); + return parseLocal(AstArray({nullptr, 0})); case Lexeme::ReservedReturn: return parseReturn(); case Lexeme::ReservedBreak: return parseBreak(); + case Lexeme::Attribute: + return parseAttributeStat(); default:; } @@ -330,24 +364,22 @@ AstStat* Parser::parseStat() // we know this isn't a call or an assignment; therefore it must be a context-sensitive keyword such as `type` or `continue` AstName ident = getIdentifier(expr); - if (options.allowTypeAnnotations) + if (ident == "type") + return parseTypeAlias(expr->location, /* exported= */ false); + + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") { - if (ident == "type") - return parseTypeAlias(expr->location, /* exported =*/false); - if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") - { - nextLexeme(); - return parseTypeAlias(expr->location, /* exported =*/true); - } + nextLexeme(); + return parseTypeAlias(expr->location, /* exported= */ true); } - if (options.supportContinueStatement && ident == "continue") + if (ident == "continue") return parseContinue(expr->location); - if (options.allowTypeAnnotations && options.allowDeclarationSyntax) + if (options.allowDeclarationSyntax) { if (ident == "declare") - return parseDeclaration(expr->location); + return parseDeclaration(expr->location, AstArray({nullptr, 0})); } // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) @@ -376,17 +408,16 @@ AstStat* Parser::parseIf() AstStat* elsebody = nullptr; Location end = start; std::optional elseLocation; - bool hasEnd = false; if (lexer.current().type == Lexeme::ReservedElseif) { - unsigned int recursionCounterOld = recursionCounter; + thenbody->hasEnd = true; + unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("elseif"); elseLocation = lexer.current().location; elsebody = parseIf(); end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; } else { @@ -394,6 +425,7 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElse) { + thenbody->hasEnd = true; elseLocation = lexer.current().location; matchThenElse = lexer.current(); nextLexeme(); @@ -404,10 +436,18 @@ AstStat* Parser::parseIf() end = lexer.current().location; - hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); + bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); + + if (elsebody) + { + if (AstStatBlock* elseBlock = elsebody->as()) + elseBlock->hasEnd = hasEnd; + } + else + thenbody->hasEnd = hasEnd; } - return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation, hasEnd); + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation); } // while exp do block end @@ -431,8 +471,9 @@ AstStat* Parser::parseWhile() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location); } // repeat block until exp @@ -452,6 +493,7 @@ AstStat* Parser::parseRepeat() functionStack.back().loopDepth--; bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); + body->hasEnd = hasUntil; AstExpr* cond = parseExpr(); @@ -468,11 +510,11 @@ AstStat* Parser::parseDo() Lexeme matchDo = lexer.current(); nextLexeme(); // do - AstStat* body = parseBlock(); + AstStatBlock* body = parseBlock(); body->location.begin = start.begin; - expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); return body; } @@ -548,8 +590,9 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); } else { @@ -590,9 +633,9 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = hasEnd; - return allocator.alloc( - Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); } } @@ -605,7 +648,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug // parse funcname into a chain of indexing operators AstExpr* expr = parseNameExpr("function name"); - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; while (lexer.current().type == '.') { @@ -623,7 +666,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug incrementRecursionCounter("function name"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; // finish with : if (lexer.current().type == ':') @@ -645,7 +688,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug } // function funcname funcbody -AstStat* Parser::parseFunctionStat() +AstStat* Parser::parseFunctionStat(const AstArray& attributes) { Location start = lexer.current().location; @@ -658,16 +701,126 @@ AstStat* Parser::parseFunctionStat() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; return allocator.alloc(Location(start, body->location), expr, body); } + +std::pair Parser::validateAttribute(const char* attributeName, const TempVector& attributes) +{ + AstAttr::Type type; + + // check if the attribute name is valid + + bool found = false; + + for (int i = 0; kAttributeEntries[i].name; ++i) + { + found = !strcmp(attributeName, kAttributeEntries[i].name); + if (found) + { + type = kAttributeEntries[i].type; + + if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native) + found = false; + + break; + } + } + + if (!found) + { + if (strlen(attributeName) == 1) + report(lexer.current().location, "Attribute name is missing"); + else + report(lexer.current().location, "Invalid attribute '%s'", attributeName); + } + else + { + // check that attribute is not duplicated + for (const AstAttr* attr : attributes) + { + if (attr->type == type) + { + report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName); + } + } + } + + return {found, type}; +} + +// attribute ::= '@' NAME +void Parser::parseAttribute(TempVector& attributes) +{ + LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute); + + Location loc = lexer.current().location; + + const char* name = lexer.current().name; + const auto [found, type] = validateAttribute(name, attributes); + + nextLexeme(); + + if (found) + attributes.push_back(allocator.alloc(loc, type)); +} + +// attributes ::= {attribute} +AstArray Parser::parseAttributes() +{ + Lexeme::Type type = lexer.current().type; + + LUAU_ASSERT(type == Lexeme::Attribute); + + TempVector attributes(scratchAttr); + + while (lexer.current().type == Lexeme::Attribute) + parseAttribute(attributes); + + return copy(attributes); +} + +// attributes local function Name funcbody +// attributes function funcname funcbody +// attributes `declare function' Name`(' [parlist] `)' [`:` Type] +// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' +AstStat* Parser::parseAttributeStat() +{ + AstArray attributes = parseAttributes(); + + Lexeme::Type type = lexer.current().type; + + switch (type) + { + case Lexeme::Type::ReservedFunction: + return parseFunctionStat(attributes); + case Lexeme::Type::ReservedLocal: + return parseLocal(attributes); + case Lexeme::Type::Name: + if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data)) + { + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); + return parseDeclaration(expr->location, attributes); + } + [[fallthrough]]; + default: + return reportStatError( + lexer.current().location, + {}, + {}, + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s instead", + lexer.current().toString().c_str() + ); + } +} + // local function Name funcbody | // local bindinglist [`=' explist] -AstStat* Parser::parseLocal() +AstStat* Parser::parseLocal(const AstArray& attributes) { Location start = lexer.current().location; @@ -687,7 +840,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -697,6 +850,17 @@ AstStat* Parser::parseLocal() } else { + if (attributes.size != 0) + { + return reportStatError( + lexer.current().location, + {}, + {}, + "Expected 'function' after local declaration with attribute, but got %s instead", + lexer.current().toString().c_str() + ); + } + matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); @@ -745,9 +909,18 @@ AstStat* Parser::parseReturn() return allocator.alloc(Location(start, end), copy(list)); } -// type Name [`<' varlist `>'] `=' typeannotation +// type Name [`<' varlist `>'] `=' Type AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { + // parsing a type function + if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) + { + if (lexer.current().type == Lexeme::ReservedFunction) + return parseTypeFunction(start, exported); + } + + // parsing a type alias + // note: `type` token is already parsed for us, so we just need to parse the rest std::optional name = parseNameOpt("type name"); @@ -760,15 +933,45 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) expectAndConsume('=', "type alias"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); } -AstDeclaredClassProp Parser::parseDeclaredClassMethod() +// type function Name `(' arglist `)' `=' funcbody `end' +AstStat* Parser::parseTypeFunction(const Location& start, bool exported) { + Lexeme matchFn = lexer.current(); nextLexeme(); + + if (exported) + report(start, "Type function cannot be exported"); + + // parse the name of the type function + std::optional fnName = parseNameOpt("type function name"); + if (!fnName) + fnName = Name(nameError, lexer.current().location); + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; + + size_t oldTypeFunctionDepth = typeFunctionDepth; + typeFunctionDepth = functionStack.size(); + + AstExprFunction* body = parseFunctionBody(/* hasself */ false, matchFn, fnName->name, nullptr, AstArray({nullptr, 0})).first; + + typeFunctionDepth = oldTypeFunctionDepth; + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; + + return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body); +} + +AstDeclaredClassProp Parser::parseDeclaredClassMethod() +{ Location start = lexer.current().location; + + nextLexeme(); + Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 @@ -792,16 +995,17 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0), nullptr}); - Location end = lexer.current().location; + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0), nullptr}); + Location end = lexer.previousLocation(); - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector> varNames(scratchOptArgName); if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { return AstDeclaredClassProp{ - fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; + fnName.name, fnName.location, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true + }; } // Skip the first index. @@ -812,26 +1016,37 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) report(start, "All declaration parameters aside from 'self' must be annotated"); AstType* fnType = allocator.alloc( - Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); + Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes + ); - return AstDeclaredClassProp{fnName.name, fnType, true}; + return AstDeclaredClassProp{fnName.name, fnName.location, fnType, true, Location(start, end)}; } -AstStat* Parser::parseDeclaration(const Location& start) +AstStat* Parser::parseDeclaration(const Location& start, const AstArray& attributes) { // `declare` token is already parsed at this point + + if ((attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction)) + return reportStatError( + lexer.current().location, + {}, + {}, + "Expected a function type declaration after attribute, but got %s instead", + lexer.current().toString().c_str() + ); + if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - Name globalName = parseName("global function name"); + Name globalName = parseName("global function name"); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); MatchLexeme matchParen = lexer.current(); @@ -849,10 +1064,10 @@ AstStat* Parser::parseDeclaration(const Location& start) expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0)}); + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0)}); Location end = lexer.current().location; - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector varNames(scratchArgName); for (size_t i = 0; i < args.size(); ++i) @@ -868,7 +1083,18 @@ AstStat* Parser::parseDeclaration(const Location& start) return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); return allocator.alloc( - Location(start, end), globalName.name, generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); + Location(start, end), + attributes, + globalName.name, + globalName.location, + generics, + genericPacks, + AstTypeList{copy(vars), varargAnnotation}, + copy(varNames), + vararg, + varargLocation, + retTypes + ); } else if (AstName(lexer.current().name) == "class") { @@ -884,6 +1110,7 @@ AstStat* Parser::parseDeclaration(const Location& start) } TempVector props(scratchDeclaredClassProps); + AstTableIndexer* indexer = nullptr; while (lexer.current().type != Lexeme::ReservedEnd) { @@ -892,45 +1119,72 @@ AstStat* Parser::parseDeclaration(const Location& start) { props.push_back(parseDeclaredClassMethod()); } - else if (lexer.current().type == '[') + else if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); nextLexeme(); // [ + const Location nameBegin = lexer.current().location; std::optional> chars = parseCharArray(); + const Location nameEnd = lexer.previousLocation(); + expectMatchAndConsume(']', begin); expectAndConsume(':', "property type annotation"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null + // since AstName contains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); if (chars && !containsNull) - props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); + { + props.push_back(AstDeclaredClassProp{ + AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation()) + }); + } + else + { + report(begin.location, "String literal contains malformed escape sequence or \\0"); + } + } + else if (lexer.current().type == '[') + { + if (indexer) + { + // maybe we don't need to parse the entire badIndexer... + // however, we either have { or [ to lint, not the entire table type or the bad indexer. + AstTableIndexer* badIndexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); + + // we lose all additional indexer expressions from the AST after error recovery here + report(badIndexer->location, "Cannot have more than one class indexer"); + } else - report(begin.location, "String literal contains malformed escape sequence"); + { + indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); + } } else { + Location propStart = lexer.current().location; Name propName = parseName("property name"); expectAndConsume(':', "property type annotation"); - AstType* propType = parseTypeAnnotation(); - props.push_back(AstDeclaredClassProp{propName.name, propType, false}); + AstType* propType = parseType(); + props.push_back(AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())} + ); } } Location classEnd = lexer.current().location; nextLexeme(); // skip past `end` - return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); + return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props), indexer); } else if (std::optional globalName = parseNameOpt("global variable name")) { expectAndConsume(':', "global variable declaration"); - AstType* type = parseTypeAnnotation(); - return allocator.alloc(Location(start, type->location), globalName->name, type); + AstType* type = parseType(/* in declaration context */ true); + return allocator.alloc(Location(start, type->location), globalName->name, globalName->location, type); } else { @@ -1005,7 +1259,12 @@ std::pair> Parser::prepareFunctionArguments(const // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) + bool hasself, + const Lexeme& matchFunction, + const AstName& debugname, + const Name* localName, + const AstArray& attributes +) { Location start = matchFunction.location; @@ -1030,7 +1289,7 @@ std::pair Parser::parseFunctionBody( expectMatchAndConsume(')', matchParen, true); - std::optional typelist = parseOptionalReturnTypeAnnotation(); + std::optional typelist = parseOptionalReturnType(); AstLocal* funLocal = nullptr; @@ -1055,10 +1314,27 @@ std::pair Parser::parseFunctionBody( Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); - - return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, - functionStack.size(), debugname, typelist, varargAnnotation, hasEnd, argLocation), - funLocal}; + body->hasEnd = hasEnd; + + return { + allocator.alloc( + Location(start, end), + attributes, + generics, + genericPacks, + self, + vars, + vararg, + varargLocation, + body, + functionStack.size(), + debugname, + typelist, + varargAnnotation, + argLocation + ), + funLocal + }; } // explist ::= {exp `,'} exp @@ -1088,7 +1364,7 @@ Parser::Binding Parser::parseBinding() if (!name) name = Name(nameError, lexer.current().location); - AstType* annotation = parseOptionalTypeAnnotation(); + AstType* annotation = parseOptionalType(); return Binding(*name, annotation); } @@ -1107,7 +1383,7 @@ std::tuple Parser::parseBindingList(TempVector Parser::parseBindingList(TempVector& result, TempVector>& resultNames) { while (true) { - if (shouldParseTypePackAnnotation(lexer)) - return parseTypePackAnnotation(); + if (shouldParseTypePack(lexer)) + return parseTypePack(); if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':') { @@ -1159,7 +1435,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() +std::optional Parser::parseOptionalReturnType() { - if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) + if (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow) { if (lexer.current().type == Lexeme::SkinnyArrow) report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); @@ -1186,7 +1462,7 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() unsigned int oldRecursionCount = recursionCounter; - auto [_location, result] = parseReturnTypeAnnotation(); + auto [_location, result] = parseReturnType(); // At this point, if we find a , character, it indicates that there are multiple return types // in this type annotation, but the list wasn't wrapped in parentheses. @@ -1205,27 +1481,27 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() return std::nullopt; } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -std::pair Parser::parseReturnTypeAnnotation() +// ReturnType ::= Type | `(' TypeList `)' +std::pair Parser::parseReturnType() { incrementRecursionCounter("type annotation"); - TempVector result(scratchAnnotation); - TempVector> resultNames(scratchOptArgName); - AstTypePack* varargAnnotation = nullptr; - Lexeme begin = lexer.current(); if (lexer.current().type != '(') { - if (shouldParseTypePackAnnotation(lexer)) - varargAnnotation = parseTypePackAnnotation(); - else - result.push_back(parseTypeAnnotation()); + if (shouldParseTypePack(lexer)) + { + AstTypePack* typePack = parseTypePack(); - Location resultLocation = result.size() == 0 ? varargAnnotation->location : result[0]->location; + return {typePack->location, AstTypeList{{}, typePack}}; + } + else + { + AstType* type = parseType(); - return {resultLocation, AstTypeList{copy(result), varargAnnotation}}; + return {type->location, AstTypeList{copy(&type, 1), nullptr}}; + } } nextLexeme(); @@ -1234,6 +1510,10 @@ std::pair Parser::parseReturnTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; + TempVector result(scratchType); + TempVector> resultNames(scratchOptArgName); + AstTypePack* varargAnnotation = nullptr; + // possibly () -> ReturnType if (lexer.current().type != ')') varargAnnotation = parseTypeList(result, resultNames); @@ -1249,9 +1529,9 @@ std::pair Parser::parseReturnTypeAnnotation() // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. if (result.size() == 1) { - AstType* returnType = parseTypeAnnotation(result, innerBegin); + AstType* returnType = parseTypeSuffix(result[0], innerBegin); - // If parseTypeAnnotation parses nothing, then returnType->location.end only points at the last non-type-pack + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack // type to successfully parse. We need the span of the whole annotation. Position endPos = result.size() == 1 ? location.end : returnType->location.end; @@ -1261,45 +1541,38 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; - AstArray types = copy(result); - AstArray> names = copy(resultNames); + AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation); - TempVector fallbackReturnTypes(scratchAnnotation); - fallbackReturnTypes.push_back(parseFunctionTypeAnnotationTail(begin, generics, genericPacks, types, names, varargAnnotation)); - - return {Location{location, fallbackReturnTypes[0]->location}, AstTypeList{copy(fallbackReturnTypes), varargAnnotation}}; + return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } -// TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation -AstTableIndexer* Parser::parseTableIndexerAnnotation() +// TableIndexer ::= `[' Type `]' `:' Type +AstTableIndexer* Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) { const Lexeme begin = lexer.current(); nextLexeme(); // [ - AstType* index = parseTypeAnnotation(); + AstType* index = parseType(); expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* result = parseTypeAnnotation(); + AstType* result = parseType(); - return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location)}); + return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location), access, accessLocation}); } -// TableProp ::= Name `:' TypeAnnotation +// TableProp ::= Name `:' Type // TablePropOrIndexer ::= TableProp | TableIndexer // PropList ::= TablePropOrIndexer {fieldsep TablePropOrIndexer} [fieldsep] -// TableTypeAnnotation ::= `{' PropList `}' -AstType* Parser::parseTableTypeAnnotation() +// TableType ::= `{' PropList `}' +AstType* Parser::parseTableType(bool inDeclarationContext) { incrementRecursionCounter("type annotation"); TempVector props(scratchTableTypeProps); AstTableIndexer* indexer = nullptr; - bool unsealed = false; Location start = lexer.current().location; @@ -1308,6 +1581,25 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { + AstTableAccess access = AstTableAccess::ReadWrite; + std::optional accessLocation; + + if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':') + { + if (AstName(lexer.current().name) == "read") + { + accessLocation = lexer.current().location; + access = AstTableAccess::Read; + lexer.next(); + } + else if (AstName(lexer.current().name) == "write") + { + accessLocation = lexer.current().location; + access = AstTableAccess::Write; + lexer.next(); + } + } + if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); @@ -1317,15 +1609,15 @@ AstType* Parser::parseTableTypeAnnotation() expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null + // since AstName contains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); if (chars && !containsNull) - props.push_back({AstName(chars->data), begin.location, type}); + props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation}); else - report(begin.location, "String literal contains malformed escape sequence"); + report(begin.location, "String literal contains malformed escape sequence or \\0"); } else if (lexer.current().type == '[') { @@ -1333,23 +1625,23 @@ AstType* Parser::parseTableTypeAnnotation() { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexerAnnotation(); + AstTableIndexer* badIndexer = parseTableIndexer(access, accessLocation); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexerAnnotation(); + indexer = parseTableIndexer(access, accessLocation); } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) { - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} - AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber); - indexer = allocator.alloc(AstTableIndexer{index, type, type->location}); + AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber, std::nullopt, type->location); + indexer = allocator.alloc(AstTableIndexer{index, type, type->location, access, accessLocation}); break; } @@ -1362,20 +1654,14 @@ AstType* Parser::parseTableTypeAnnotation() expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(inDeclarationContext); - props.push_back({name->name, name->location, type}); + props.push_back(AstTableProp{name->name, name->location, type, access, accessLocation}); } - if (lexer.current().type == ';') - { - nextLexeme(); - } - else if (lexer.current().type == ',' ) + if (lexer.current().type == ',' || lexer.current().type == ';') { nextLexeme(); - if (lexer.current().type == '}') - unsealed=true; } else { @@ -1389,12 +1675,12 @@ AstType* Parser::parseTableTypeAnnotation() if (!expectMatchAndConsume('}', matchBrace)) end = lexer.previousLocation(); - return allocator.alloc(Location(start, end), copy(props), indexer, unsealed); + return allocator.alloc(Location(start, end), copy(props), indexer); } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) +// ReturnType ::= Type | `(' TypeList `)' +// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType +AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray& attributes) { incrementRecursionCounter("type annotation"); @@ -1410,7 +1696,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; - TempVector params(scratchAnnotation); + TempVector params(scratchType); TempVector> names(scratchOptArgName); AstTypePack* varargAnnotation = nullptr; @@ -1442,12 +1728,18 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray> paramNames = copy(names); - return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; + return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) - +AstType* Parser::parseFunctionTypeTail( + const Lexeme& begin, + const AstArray& attributes, + AstArray generics, + AstArray genericPacks, + AstArray params, + AstArray> paramNames, + AstTypePack* varargAnnotation +) { incrementRecursionCounter("type annotation"); @@ -1461,33 +1753,39 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' after '()' when parsing function type; did you mean 'nil'?"); - return allocator.alloc(begin.location, std::nullopt, nameNil); + return allocator.alloc(begin.location, std::nullopt, nameNil, std::nullopt, begin.location); } else { expectAndConsume(Lexeme::SkinnyArrow, "function type"); } - auto [endLocation, returnTypeList] = parseReturnTypeAnnotation(); + auto [endLocation, returnTypeList] = parseReturnType(); AstTypeList paramTypes = AstTypeList{params, varargAnnotation}; - return allocator.alloc(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); + return allocator.alloc( + Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList + ); } -// typeannotation ::= +// Type ::= // nil | // Name[`.' Name] [`<' namelist `>'] | // `{' [PropList] `}' | // `(' [TypeList] `)' `->` ReturnType -// `typeof` typeannotation -AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location& begin) +// `typeof` Type +AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) { - LUAU_ASSERT(!parts.empty()); + TempVector parts(scratchType); + + if (type != nullptr) + parts.push_back(type); incrementRecursionCounter("type annotation"); bool isUnion = false; bool isIntersection = false; + bool hasOptional = false; Location location = begin; @@ -1497,20 +1795,34 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + + unsigned int oldRecursionCount = recursionCounter; + parts.push_back(parseSimpleType(/* allowPack= */ false).type); + recursionCounter = oldRecursionCount; + isUnion = true; } else if (c == '?') { + LUAU_ASSERT(parts.size() >= 1); + Location loc = lexer.current().location; nextLexeme(); - parts.push_back(allocator.alloc(loc, std::nullopt, nameNil)); + + if (!hasOptional) + parts.push_back(allocator.alloc(loc, std::nullopt, nameNil, std::nullopt, loc)); + isUnion = true; + hasOptional = true; } else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + + unsigned int oldRecursionCount = recursionCounter; + parts.push_back(parseSimpleType(/* allowPack= */ false).type); + recursionCounter = oldRecursionCount; + isIntersection = true; } else if (c == Lexeme::Dot3) @@ -1520,6 +1832,9 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location } else break; + + if (parts.size() > unsigned(FInt::LuauTypeLengthLimit) + hasOptional) + ParseError::raise(parts.back()->location, "Exceeded allowed type length; simplify your type annotation to make the code compile"); } if (parts.size() == 1) @@ -1527,8 +1842,11 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), - "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + return reportTypeError( + Location(begin, parts.back()->location), + copy(parts), + "Mixing union and intersection types is not allowed; consider wrapping in parentheses." + ); } location.end = parts.back()->location.end; @@ -1543,16 +1861,14 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } -AstTypeOrPack Parser::parseTypeOrPackAnnotation() +AstTypeOrPack Parser::parseTypeOrPack() { unsigned int oldRecursionCount = recursionCounter; - incrementRecursionCounter("type annotation"); + // recursion counter is incremented in parseSimpleType Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - - auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); + auto [type, typePack] = parseSimpleType(/* allowPack= */ true); if (typePack) { @@ -1560,40 +1876,59 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() return {{}, typePack}; } - parts.push_back(type); - recursionCounter = oldRecursionCount; - return {parseTypeAnnotation(parts, begin), {}}; + return {parseTypeSuffix(type, begin), {}}; } -AstType* Parser::parseTypeAnnotation() +AstType* Parser::parseType(bool inDeclarationContext) { unsigned int oldRecursionCount = recursionCounter; - incrementRecursionCounter("type annotation"); + // recursion counter is incremented in parseSimpleType and/or parseTypeSuffix Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + AstType* type = nullptr; + Lexeme::Type c = lexer.current().type; + if (c != '|' && c != '&') + { + type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + recursionCounter = oldRecursionCount; + } + + AstType* typeWithSuffix = parseTypeSuffix(type, begin); recursionCounter = oldRecursionCount; - return parseTypeAnnotation(parts, begin); + return typeWithSuffix; } -// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' +// Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) +AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) { incrementRecursionCounter("type annotation"); Location start = lexer.current().location; - if (lexer.current().type == Lexeme::ReservedNil) + AstArray attributes{nullptr, 0}; + + if (lexer.current().type == Lexeme::Attribute) + { + if (!inDeclarationContext) + { + return {reportTypeError(start, {}, "attributes are not allowed in declaration context")}; + } + else + { + attributes = Parser::parseAttributes(); + return parseFunctionType(allowPack, attributes); + } + } + else if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return {allocator.alloc(start, std::nullopt, nameNil), {}}; + return {allocator.alloc(start, std::nullopt, nameNil, std::nullopt, start), {}}; } else if (lexer.current().type == Lexeme::ReservedTrue) { @@ -1613,22 +1948,23 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; + return {reportTypeError(start, {}, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, "Malformed string")}; + return {reportTypeError(start, {}, "Malformed string; did you forget to finish it?")}; } else if (lexer.current().type == Lexeme::Name) { std::optional prefix; + std::optional prefixLocation; Name name = parseName("type name"); if (lexer.current().type == '.') @@ -1637,6 +1973,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); prefix = name.name; + prefixLocation = name.location; name = parseIndexName("field name", pointPosition); } else if (lexer.current().type == Lexeme::Dot3) @@ -1669,24 +2006,31 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) Location end = lexer.previousLocation(); - return {allocator.alloc(Location(start, end), prefix, name.name, hasParameters, parameters), {}}; + return { + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {} + }; } else if (lexer.current().type == '{') { - return {parseTableTypeAnnotation(), {}}; + return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionTypeAnnotation(allowPack); + return parseFunctionType(allowPack, AstArray({nullptr, 0})); } else if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, - "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " - "...any'"), - {}}; + return { + reportTypeError( + start, + {}, + "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " + "...any'" + ), + {} + }; } else { @@ -1695,12 +2039,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. Location parseErrorLocation(lexer.previousLocation().end, start.end); - return { - reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; + return {reportMissingTypeError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } -AstTypePack* Parser::parseVariadicArgumentAnnotation() +AstTypePack* Parser::parseVariadicArgumentTypePack() { // Generic: a... if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == Lexeme::Dot3) @@ -1715,19 +2058,19 @@ AstTypePack* Parser::parseVariadicArgumentAnnotation() // Variadic: T else { - AstType* variadicAnnotation = parseTypeAnnotation(); + AstType* variadicAnnotation = parseType(); return allocator.alloc(variadicAnnotation->location, variadicAnnotation); } } -AstTypePack* Parser::parseTypePackAnnotation() +AstTypePack* Parser::parseTypePack() { // Variadic: ...T if (lexer.current().type == Lexeme::Dot3) { Location start = lexer.current().location; nextLexeme(); - AstType* varargTy = parseTypeAnnotation(); + AstType* varargTy = parseType(); return allocator.alloc(Location(start, varargTy->location), varargTy); } // Generic: a... @@ -1741,7 +2084,8 @@ AstTypePack* Parser::parseTypePackAnnotation() return allocator.alloc(Location(name.location, end), name.name); } - // No type pack annotation exists here. + // TODO: shouldParseTypePack can be removed and parseTypePack can be called unconditionally instead + LUAU_ASSERT(!"parseTypePack can't be called if shouldParseTypePack() returned false"); return nullptr; } @@ -1773,6 +2117,8 @@ std::optional Parser::parseBinaryOp(const Lexeme& l) return AstExprBinary::Mul; else if (l.type == '/') return AstExprBinary::Div; + else if (l.type == Lexeme::FloorDiv) + return AstExprBinary::FloorDiv; else if (l.type == '%') return AstExprBinary::Mod; else if (l.type == '^') @@ -1796,8 +2142,6 @@ std::optional Parser::parseBinaryOp(const Lexeme& l) else if (l.type == Lexeme::ReservedOr) return AstExprBinary::Or; //GIDEROS Added - else if (l.type == Lexeme::DivInt) - return AstExprBinary::DivInt; else if (l.type == Lexeme::MaxOf) return AstExprBinary::MaxOf; else if (l.type == Lexeme::MinOf) @@ -1826,6 +2170,8 @@ std::optional Parser::parseCompoundOp(const Lexeme& l) return AstExprBinary::Mul; else if (l.type == Lexeme::DivAssign) return AstExprBinary::Div; + else if (l.type == Lexeme::FloorDivAssign) + return AstExprBinary::FloorDiv; else if (l.type == Lexeme::ModAssign) return AstExprBinary::Mod; else if (l.type == Lexeme::PowAssign) @@ -1849,7 +2195,7 @@ std::optional Parser::checkUnaryConfusables() if (curr.type == '!') { - report(start, "Unexpected '!', did you mean 'not'?"); + report(start, "Unexpected '!'; did you mean 'not'?"); return AstExprUnary::Not; } @@ -1871,20 +2217,19 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr if (curr.type == '&' && next.type == '&' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::And].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '&&', did you mean 'and'?"); + report(Location(start, next.location), "Unexpected '&&'; did you mean 'and'?"); return AstExprBinary::And; } else if (curr.type == '|' && next.type == '|' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::Or].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '||', did you mean 'or'?"); + report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?"); return AstExprBinary::Or; } - else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && - binaryPriority[AstExprBinary::CompareNe].left > limit) + else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '!=', did you mean '~='?"); + report(Location(start, next.location), "Unexpected '!='; did you mean '~='?"); return AstExprBinary::CompareNe; } @@ -1896,16 +2241,18 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr AstExpr* Parser::parseExpr(unsigned int limit) { static const BinaryOpPriority binaryPriority[] = { - {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `%' + {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `//' `%' {10, 9}, {5, 4}, // power and concat (right associative) {3, 3}, {3, 3}, // equality and inequality {3, 3}, {3, 3}, {3, 3}, {3, 3}, // order {2, 2}, {1, 1}, // logical (and/or) - {7, 7}, {7, 7}, {7, 7}, // GIDEROS (DivInt, MaxOf, MinOf) + {7, 7}, {7, 7}, // GIDEROS (MaxOf, MinOf) {6, 6}, {6, 6}, {6, 6}, {7, 7}, {7, 7} // GIDEROS (&,|,~,>>,<<) }; - unsigned int recursionCounterOld = recursionCounter; + static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op"); + + unsigned int oldRecursionCount = recursionCounter; // this handles recursive calls to parseSubExpr/parseExpr incrementRecursionCounter("expression"); @@ -1963,7 +2310,7 @@ AstExpr* Parser::parseExpr(unsigned int limit) incrementRecursionCounter("expression"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; return expr; } @@ -1982,6 +2329,12 @@ AstExpr* Parser::parseNameExpr(const char* context) { AstLocal* local = *value; + if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) + { + if (local->functionDepth < typeFunctionDepth) + return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value); + } + return allocator.alloc(name->location, local, local->functionDepth != functionStack.size() - 1); } @@ -2030,7 +2383,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) AstExpr* expr = parsePrefixExpr(); - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; while (true) { @@ -2096,21 +2449,21 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) incrementRecursionCounter("expression"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; return expr; } -// asexp -> simpleexp [`::' typeannotation] +// asexp -> simpleexp [`::' Type] AstExpr* Parser::parseAssertionExpr() { Location start = lexer.current().location; AstExpr* expr = parseSimpleExpr(); - if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) + if (lexer.current().type == Lexeme::DoubleColon) { nextLexeme(); - AstType* annotation = parseTypeAnnotation(); + AstType* annotation = parseType(); return allocator.alloc(Location(start, annotation->location), expr, annotation); } else @@ -2140,6 +2493,9 @@ static ConstantNumberParseResult parseInteger(double& result, const char* data, return base == 2 ? ConstantNumberParseResult::BinOverflow : ConstantNumberParseResult::HexOverflow; } + if (value >= (1ull << 53) && static_cast(result) != value) + return ConstantNumberParseResult::Imprecise; + return ConstantNumberParseResult::Ok; } @@ -2156,15 +2512,45 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data) char* end = nullptr; double value = strtod(data, &end); + // trailing non-numeric characters + if (*end != 0) + return ConstantNumberParseResult::Malformed; + result = value; - return *end == 0 ? ConstantNumberParseResult::Ok : ConstantNumberParseResult::Malformed; + + // for linting, we detect integer constants that are parsed imprecisely + // since the check is expensive we only perform it when the number is larger than the precise integer range + if (value >= double(1ull << 53) && strspn(data, "0123456789") == strlen(data)) + { + char repr[512]; + snprintf(repr, sizeof(repr), "%.0f", value); + + if (strcmp(repr, data) != 0) + return ConstantNumberParseResult::Imprecise; + } + + return ConstantNumberParseResult::Ok; } -// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp +// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* Parser::parseSimpleExpr() { Location start = lexer.current().location; + AstArray attributes{nullptr, 0}; + + if (FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute) + { + attributes = parseAttributes(); + + if (lexer.current().type != Lexeme::ReservedFunction) + { + return reportExprError( + start, {}, "Expected 'function' declaration after attribute, but got %s instead", lexer.current().toString().c_str() + ); + } + } + if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); @@ -2188,14 +2574,13 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, attributes).first; } else if (lexer.current().type == Lexeme::Number) { return parseNumber(); } - else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || - lexer.current().type == Lexeme::InterpStringSimple) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { return parseString(); } @@ -2206,12 +2591,12 @@ AstExpr* Parser::parseSimpleExpr() else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return reportExprError(start, {}, "Malformed string"); + return reportExprError(start, {}, "Malformed string; did you forget to finish it?"); } else if (lexer.current().type == Lexeme::BrokenInterpDoubleBrace) { nextLexeme(); - return reportExprError(start, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(start, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); } else if (lexer.current().type == Lexeme::Dot3) { @@ -2295,15 +2680,22 @@ LUAU_NOINLINE AstExpr* Parser::reportFunctionArgsError(AstExpr* func, bool self) } else { - return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), - "Expected '(', '{' or when parsing function call, got %s", lexer.current().toString().c_str()); + return reportExprError( + Location(func->location.begin, lexer.current().location.begin), + copy({func}), + "Expected '(', '{' or when parsing function call, got %s", + lexer.current().toString().c_str() + ); } } LUAU_NOINLINE void Parser::reportAmbiguousCallError() { - report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " - "new statement; use ';' to separate statements"); + report( + lexer.current().location, + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " + "new statement; use ';' to separate statements" + ); } // tableconstructor ::= `{' [fieldlist] `}' @@ -2349,7 +2741,7 @@ AstExpr* Parser::parseTableConstructor() nameString.data = const_cast(name.name.value); nameString.size = strlen(name.name.value); - AstExpr* key = allocator.alloc(name.location, nameString); + AstExpr* key = allocator.alloc(name.location, nameString, AstExprConstantString::Unquoted); AstExpr* value = parseExpr(); if (AstExprFunction* func = value->as()) @@ -2500,26 +2892,15 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - Lexeme packBegin = lexer.current(); - - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); - - namePacks.push_back({name, nameLocation, typePack}); - } - else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') - { - auto [type, typePack] = parseTypeOrPackAnnotation(); - - if (type) - report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); + AstTypePack* typePack = parseTypePack(); namePacks.push_back({name, nameLocation, typePack}); } - else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + else { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (type) report(type->location, "Expected type pack after '=', got type"); @@ -2542,7 +2923,7 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - AstType* defaultType = parseTypeAnnotation(); + AstType* defaultType = parseType(); names.push_back({name, nameLocation, defaultType}); } @@ -2579,7 +2960,7 @@ std::pair, AstArray> Parser::parseG AstArray Parser::parseTypeParams() { - TempVector parameters{scratchTypeOrPackAnnotation}; + TempVector parameters{scratchTypeOrPack}; if (lexer.current().type == '<') { @@ -2588,15 +2969,15 @@ AstArray Parser::parseTypeParams() while (true) { - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePack(); parameters.push_back({{}, typePack}); } else if (lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (typePack) parameters.push_back({{}, typePack}); @@ -2609,7 +2990,7 @@ AstArray Parser::parseTypeParams() } else { - parameters.push_back({parseTypeAnnotation(), {}}); + parameters.push_back({parseType(), {}}); } if (lexer.current().type == ',') @@ -2626,10 +3007,12 @@ AstArray Parser::parseTypeParams() std::optional> Parser::parseCharArray() { - LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || - lexer.current().type == Lexeme::InterpStringSimple); + LUAU_ASSERT( + lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || + lexer.current().type == Lexeme::InterpStringSimple + ); - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2669,12 +3052,14 @@ AstExpr* Parser::parseInterpString() do { Lexeme currentLexeme = lexer.current(); - LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || - currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple); + LUAU_ASSERT( + currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || + currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple + ); endLocation = currentLexeme.location; - scratchData.assign(currentLexeme.data, currentLexeme.length); + scratchData.assign(currentLexeme.data, currentLexeme.getLength()); if (!Lexer::fixupQuotedString(scratchData)) { @@ -2709,7 +3094,7 @@ AstExpr* Parser::parseInterpString() { errorWhileChecking = true; nextLexeme(); - expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?")); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '`'?")); break; } default: @@ -2729,10 +3114,10 @@ AstExpr* Parser::parseInterpString() break; case Lexeme::BrokenInterpDoubleBrace: nextLexeme(); - return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(endLocation, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); case Lexeme::BrokenString: nextLexeme(); - return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?"); + return reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '}'?"); default: return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); } @@ -2747,7 +3132,7 @@ AstExpr* Parser::parseNumber() { Location start = lexer.current().location; - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -2771,7 +3156,8 @@ AstLocal* Parser::pushLocal(const Binding& binding) AstLocal*& local = localMap[name.name]; local = allocator.alloc( - name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation); + name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation + ); localStack.push_back(local); @@ -2904,11 +3290,25 @@ LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Ma std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString(); if (lexer.current().location.begin.line == begin.position.line) - report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), matchString.c_str(), - begin.position.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); + report( + lexer.current().location, + "Expected %s (to close %s at column %d), got %s%s", + typeString.c_str(), + matchString.c_str(), + begin.position.column + 1, + lexer.current().toString().c_str(), + extra ? extra : "" + ); else - report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), matchString.c_str(), - begin.position.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); + report( + lexer.current().location, + "Expected %s (to close %s at line %d), got %s%s", + typeString.c_str(), + matchString.c_str(), + begin.position.line + 1, + lexer.current().toString().c_str(), + extra ? extra : "" + ); } bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin) @@ -3045,7 +3445,12 @@ LUAU_NOINLINE void Parser::reportNameError(const char* context) } AstStatError* Parser::reportStatError( - const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) + const Location& location, + const AstArray& expressions, + const AstArray& statements, + const char* format, + ... +) { va_list args; va_start(args, format); @@ -3065,7 +3470,7 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) +AstTypeError* Parser::reportTypeError(const Location& location, const AstArray& types, const char* format, ...) { va_list args; va_start(args, format); @@ -3075,7 +3480,7 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) +AstTypeError* Parser::reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { va_list args; va_start(args, format); @@ -3102,11 +3507,11 @@ void Parser::nextLexeme() return; // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!') { const char* text = lexeme.data; - unsigned int end = lexeme.length; + unsigned int end = lexeme.getLength(); while (end > 0 && isSpace(text[end - 1])) --end; diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 0df87c971..830bf57e8 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include namespace Luau { @@ -141,7 +141,8 @@ size_t editDistance(std::string_view a, std::string_view b) size_t maxDistance = a.size() + b.size(); std::vector distances((a.size() + 2) * (b.size() + 2), 0); - auto getPos = [b](size_t x, size_t y) -> size_t { + auto getPos = [b](size_t x, size_t y) -> size_t + { return (x * (b.size() + 2)) + y; }; @@ -168,7 +169,9 @@ size_t editDistance(std::string_view a, std::string_view b) for (size_t y = 1; y <= b.size(); ++y) { - size_t x1 = seenCharToRow[b[y - 1]]; + // The value of b[N] can be negative with unicode characters + unsigned char bSeenCharIndex = static_cast(b[y - 1]); + size_t x1 = seenCharToRow[bSeenCharIndex]; size_t y1 = lastMatchedY; size_t cost = 1; @@ -188,7 +191,9 @@ size_t editDistance(std::string_view a, std::string_view b) distances[getPos(x + 1, y + 1)] = std::min(std::min(insertion, deletion), std::min(substitution, transposition)); } - seenCharToRow[a[x - 1]] = x; + // The value of a[N] can be negative with unicode characters + unsigned char aSeenCharIndex = static_cast(a[x - 1]); + seenCharToRow[aSeenCharIndex] = x; } return distances[getPos(a.size() + 1, b.size() + 1)]; diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e38076830..8bccffce2 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -3,6 +3,7 @@ #include "Luau/StringUtils.h" +#include #include #include @@ -15,7 +16,7 @@ #ifndef NOMINMAX #define NOMINMAX #endif -#include +#include #endif #ifdef __APPLE__ @@ -40,7 +41,7 @@ static double getClockPeriod() mach_timebase_info_data_t result = {}; mach_timebase_info(&result); return double(result.numer) / double(result.denom) * 1e-9; -#elif defined(__linux__) +#elif defined(__linux__) || defined(__FreeBSD__) return 1e-9; #else return 1.0 / double(CLOCKS_PER_SEC); @@ -55,7 +56,7 @@ static double getClockTimestamp() return double(result.QuadPart); #elif defined(__APPLE__) return double(mach_absolute_time()); -#elif defined(__linux__) +#elif defined(__linux__) || defined(__FreeBSD__) timespec now; clock_gettime(CLOCK_MONOTONIC, &now); return now.tv_sec * 1e9 + now.tv_nsec; @@ -90,17 +91,8 @@ namespace TimeTrace { struct GlobalContext { - GlobalContext() = default; ~GlobalContext() { - // Ideally we would want all ThreadContext destructors to run - // But in VS, not all thread_local object instances are destroyed - for (ThreadContext* context : threads) - { - if (!context->events.empty()) - context->flushEvents(); - } - if (traceFile) fclose(traceFile); } @@ -110,11 +102,15 @@ struct GlobalContext uint32_t nextThreadId = 0; std::vector tokens; FILE* traceFile = nullptr; + +private: + friend std::shared_ptr getGlobalContext(); + GlobalContext() = default; }; -GlobalContext& getGlobalContext() +std::shared_ptr getGlobalContext() { - static GlobalContext context; + static std::shared_ptr context = std::shared_ptr{new GlobalContext}; return context; } @@ -189,8 +185,14 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector +#include +#include +#include +#include +#include +#include + #ifdef CALLGRIND #include #endif LUAU_FASTFLAG(DebugLuauTimeTracing) +LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile) enum class ReportFormat { @@ -55,8 +64,13 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); else - report(format, humanReadableName.c_str(), error.location, "TypeError", - Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); + report( + format, + humanReadableName.c_str(), + error.location, + "TypeError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str() + ); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) @@ -64,28 +78,29 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin report(format, name, warning.location, Luau::LintWarning::getName(warning.code), warning.text.c_str()); } -static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) +static bool reportModuleResult(Luau::Frontend& frontend, const Luau::ModuleName& name, ReportFormat format, bool annotate) { - Luau::CheckResult cr; + std::optional cr = frontend.getCheckResult(name, false); - if (frontend.isDirty(name)) - cr = frontend.check(name); + if (!cr) + { + fprintf(stderr, "Failed to find result for %s\n", name.c_str()); + return false; + } if (!frontend.getSourceModule(name)) { - fprintf(stderr, "Error opening %s\n", name); + fprintf(stderr, "Error opening %s\n", name.c_str()); return false; } - for (auto& error : cr.errors) + for (auto& error : cr->errors) reportError(frontend, format, error); - Luau::LintResult lr = frontend.lint(name); - std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); - for (auto& error : lr.errors) + for (auto& error : cr->lintResult.errors) reportWarning(format, humanReadableName.c_str(), error); - for (auto& warning : lr.warnings) + for (auto& warning : cr->lintResult.warnings) reportWarning(format, humanReadableName.c_str(), warning); if (annotate) @@ -100,7 +115,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat printf("%s", annotated.c_str()); } - return cr.errors.empty() && lr.errors.empty(); + return cr->errors.empty() && cr->lintResult.errors.empty(); } static void displayHelp(const char* argv0) @@ -121,6 +136,7 @@ static void displayHelp(const char* argv0) static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + fflush(stdout); return 1; } @@ -217,6 +233,77 @@ struct CliConfigResolver : Luau::ConfigResolver } }; +struct TaskScheduler +{ + TaskScheduler(unsigned threadCount) + : threadCount(threadCount) + { + for (unsigned i = 0; i < threadCount; i++) + { + workers.emplace_back( + [this] + { + workerFunction(); + } + ); + } + } + + ~TaskScheduler() + { + for (unsigned i = 0; i < threadCount; i++) + push({}); + + for (std::thread& worker : workers) + worker.join(); + } + + std::function pop() + { + std::unique_lock guard(mtx); + + cv.wait( + guard, + [this] + { + return !tasks.empty(); + } + ); + + std::function task = tasks.front(); + tasks.pop(); + return task; + } + + void push(std::function task) + { + { + std::unique_lock guard(mtx); + tasks.push(std::move(task)); + } + + cv.notify_one(); + } + + static unsigned getThreadCount() + { + return std::max(std::thread::hardware_concurrency(), 1u); + } + +private: + void workerFunction() + { + while (std::function task = pop()) + task(); + } + + unsigned threadCount = 1; + std::mutex mtx; + std::condition_variable cv; + std::vector workers; + std::queue> tasks; +}; + int main(int argc, char** argv) { Luau::assertHandler() = assertionHandler; @@ -232,6 +319,8 @@ int main(int argc, char** argv) ReportFormat format = ReportFormat::Default; Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; + int threadCount = 0; + std::string basePath = ""; for (int i = 1; i < argc; ++i) { @@ -250,6 +339,10 @@ int main(int argc, char** argv) FFlag::DebugLuauTimeTracing.value = true; else if (strncmp(argv[i], "--fflags=", 9) == 0) setLuauFlags(argv[i] + 9); + else if (strncmp(argv[i], "-j", 2) == 0) + threadCount = int(strtol(argv[i] + 2, nullptr, 10)); + else if (strncmp(argv[i], "--logbase=", 10) == 0) + basePath = std::string{argv[i] + 10}; } #if !defined(LUAU_ENABLE_TIME_TRACE) @@ -262,13 +355,33 @@ int main(int argc, char** argv) Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; + frontendOptions.runLintChecks = true; CliFileResolver fileResolver; CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinGlobals(frontend.typeChecker); - Luau::freeze(frontend.typeChecker.globalTypes); + if (FFlag::DebugLuauLogSolverToJsonFile) + { + frontend.writeJsonLog = [&basePath](const Luau::ModuleName& moduleName, std::string log) + { + std::string path = moduleName + ".log.json"; + size_t pos = moduleName.find_last_of('/'); + if (pos != std::string::npos) + path = moduleName.substr(pos + 1); + + if (!basePath.empty()) + path = joinPaths(basePath, path); + + std::ofstream os(path); + + os << log << std::endl; + printf("Wrote JSON log to %s\n", path.c_str()); + }; + } + + Luau::registerBuiltinGlobals(frontend, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); #ifdef CALLGRIND CALLGRIND_ZERO_STATS; @@ -276,10 +389,51 @@ int main(int argc, char** argv) std::vector files = getSourceFiles(argc, argv); + for (const std::string& path : files) + frontend.queueModuleCheck(path); + + std::vector checkedModules; + + // If thread count is not set, try to use HW thread count, but with an upper limit + // When we improve scalability of typechecking, upper limit can be adjusted/removed + if (threadCount <= 0) + threadCount = std::min(TaskScheduler::getThreadCount(), 8u); + + try + { + TaskScheduler scheduler(threadCount); + + checkedModules = frontend.checkQueuedModules( + std::nullopt, + [&](std::function f) + { + scheduler.push(std::move(f)); + } + ); + } + catch (const Luau::InternalCompilerError& ice) + { + Luau::Location location = ice.location ? *ice.location : Luau::Location(); + + std::string moduleName = ice.moduleName ? *ice.moduleName : ""; + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(moduleName); + + Luau::TypeError error(location, moduleName, Luau::InternalError{ice.message}); + + report( + format, + humanReadableName.c_str(), + location, + "InternalCompilerError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str() + ); + return 1; + } + int failed = 0; - for (const std::string& path : files) - failed += !analyzeFile(frontend, path.c_str(), format, annotate); + for (const Luau::ModuleName& name : checkedModules) + failed += !reportModuleResult(frontend, name, format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/Ast.cpp b/CLI/Ast.cpp index 99c583936..b5a922aaa 100644 --- a/CLI/Ast.cpp +++ b/CLI/Ast.cpp @@ -64,8 +64,6 @@ int main(int argc, char** argv) Luau::ParseOptions options; options.captureComments = true; - options.supportContinueStatement = true; - options.allowTypeAnnotations = true; options.allowDeclarationSyntax = true; Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); diff --git a/CLI/Bytecode.cpp b/CLI/Bytecode.cpp new file mode 100644 index 000000000..2da9570b2 --- /dev/null +++ b/CLI/Bytecode.cpp @@ -0,0 +1,299 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Luau/CodeGen.h" +#include "Luau/Compiler.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Parser.h" +#include "Luau/BytecodeSummary.h" +#include "FileUtils.h" +#include "Flags.h" + +#include + +using Luau::CodeGen::FunctionBytecodeSummary; + +struct GlobalOptions +{ + int optimizationLevel = 1; + int debugLevel = 1; +} globalOptions; + +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = globalOptions.optimizationLevel; + result.debugLevel = globalOptions.debugLevel; + result.typeInfoLevel = 1; + + return result; +} + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [options] [file list]\n", argv0); + printf("\n"); + printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); + printf(" --fflags=: flags to be enabled.\n"); + printf(" --summary-file=: file in which bytecode analysis summary will be recorded (default 'bytecode-summary.json').\n"); + + exit(0); +} + +static bool parseArgs(int argc, char** argv, std::string& summaryFile) +{ + for (int i = 1; i < argc; i++) + { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return false; + } + globalOptions.optimizationLevel = level; + } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return false; + } + globalOptions.debugLevel = level; + } + else if (strncmp(argv[i], "--summary-file=", 15) == 0) + { + summaryFile = argv[i] + 15; + + if (summaryFile.size() == 0) + { + fprintf(stderr, "Error: filename missing for '--summary-file'.\n\n"); + return false; + } + } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + setLuauFlags(argv[i] + 9); + } + else if (argv[i][0] == '-') + { + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + } + } + + return true; +} + +static void report(const char* name, const Luau::Location& location, const char* type, const char* message) +{ + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); +} + +static void reportError(const char* name, const Luau::ParseError& error) +{ + report(name, error.getLocation(), "SyntaxError", error.what()); +} + +static void reportError(const char* name, const Luau::CompileError& error) +{ + report(name, error.getLocation(), "CompileError", error.what()); +} + +static bool analyzeFile(const char* name, const unsigned nestingLimit, std::vector& summaries) +{ + std::optional source = readFile(name); + + if (!source) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + try + { + Luau::BytecodeBuilder bcb; + + compileOrThrow(bcb, *source, copts()); + + const std::string& bytecode = bcb.getBytecode(); + + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) + { + summaries = Luau::CodeGen::summarizeBytecode(L, -1, nestingLimit); + return true; + } + else + { + fprintf(stderr, "Error loading bytecode %s\n", name); + return false; + } + } + catch (Luau::ParseErrors& e) + { + for (auto& error : e.getErrors()) + reportError(name, error); + return false; + } + catch (Luau::CompileError& e) + { + reportError(name, e); + return false; + } + + return true; +} + +static std::string escapeFilename(const std::string& filename) +{ + std::string escaped; + escaped.reserve(filename.size()); + + for (const char ch : filename) + { + switch (ch) + { + case '\\': + escaped.push_back('/'); + break; + case '"': + escaped.push_back('\\'); + escaped.push_back(ch); + break; + default: + escaped.push_back(ch); + } + } + + return escaped; +} + +static void serializeFunctionSummary(const FunctionBytecodeSummary& summary, FILE* fp) +{ + const unsigned nestingLimit = summary.getNestingLimit(); + const unsigned opLimit = summary.getOpLimit(); + + fprintf(fp, " {\n"); + fprintf(fp, " \"source\": \"%s\",\n", summary.getSource().c_str()); + fprintf(fp, " \"name\": \"%s\",\n", summary.getName().c_str()); + fprintf(fp, " \"line\": %d,\n", summary.getLine()); + fprintf(fp, " \"nestingLimit\": %u,\n", nestingLimit); + fprintf(fp, " \"counts\": ["); + + for (unsigned nesting = 0; nesting <= nestingLimit; ++nesting) + { + fprintf(fp, "\n ["); + + for (unsigned i = 0; i < opLimit; ++i) + { + fprintf(fp, "%d", summary.getCount(nesting, uint8_t(i))); + if (i < opLimit - 1) + fprintf(fp, ", "); + } + + fprintf(fp, "]"); + if (nesting < nestingLimit) + fprintf(fp, ","); + } + + fprintf(fp, "\n ]"); + fprintf(fp, "\n }"); +} + +static void serializeScriptSummary(const std::string& file, const std::vector& scriptSummary, FILE* fp) +{ + std::string escaped(escapeFilename(file)); + const size_t functionCount = scriptSummary.size(); + + fprintf(fp, " \"%s\": [\n", escaped.c_str()); + + for (size_t i = 0; i < functionCount; ++i) + { + serializeFunctionSummary(scriptSummary[i], fp); + fprintf(fp, i == (functionCount - 1) ? "\n" : ",\n"); + } + + fprintf(fp, " ]"); +} + +static bool serializeSummaries( + const std::vector& files, + const std::vector>& scriptSummaries, + const std::string& summaryFile +) +{ + + FILE* fp = fopen(summaryFile.c_str(), "w"); + const size_t fileCount = files.size(); + + if (!fp) + { + fprintf(stderr, "Unable to open '%s'.\n", summaryFile.c_str()); + return false; + } + + fprintf(fp, "{\n"); + + for (size_t i = 0; i < fileCount; ++i) + { + serializeScriptSummary(files[i], scriptSummaries[i], fp); + fprintf(fp, i < (fileCount - 1) ? ",\n" : "\n"); + } + + fprintf(fp, "}"); + fclose(fp); + + return true; +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + setLuauFlagsDefault(); + + std::string summaryFile("bytecode-summary.json"); + unsigned nestingLimit = 0; + + if (!parseArgs(argc, argv, summaryFile)) + return 1; + + const std::vector files = getSourceFiles(argc, argv); + size_t fileCount = files.size(); + + std::vector> scriptSummaries; + scriptSummaries.reserve(fileCount); + + for (size_t i = 0; i < fileCount; ++i) + { + if (!analyzeFile(files[i].c_str(), nestingLimit, scriptSummaries[i])) + return 1; + } + + if (!serializeSummaries(files, scriptSummaries, summaryFile)) + return 1; + + fprintf(stdout, "Bytecode summary written to '%s'\n", summaryFile.c_str()); + + return 0; +} diff --git a/CLI/Compile.cpp b/CLI/Compile.cpp new file mode 100644 index 000000000..6ecb44f0e --- /dev/null +++ b/CLI/Compile.cpp @@ -0,0 +1,701 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Luau/CodeGen.h" +#include "Luau/Compiler.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + +#include "FileUtils.h" +#include "Flags.h" + +#include + +#ifdef _WIN32 +#include +#include +#endif + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CompileFormat +{ + Text, + Binary, + Remarks, + Codegen, // Prints annotated native code including IR and assembly + CodegenAsm, // Prints annotated native code assembly + CodegenIr, // Prints annotated native code IR + CodegenVerbose, // Prints annotated native code including IR, assembly and outlined code + CodegenNull, + Null +}; + +enum class RecordStats +{ + None, + Total, + File, + Function +}; + +struct GlobalOptions +{ + int optimizationLevel = 1; + int debugLevel = 1; + int typeInfoLevel = 0; + + const char* vectorLib = nullptr; + const char* vectorCtor = nullptr; + const char* vectorType = nullptr; +} globalOptions; + +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = globalOptions.optimizationLevel; + result.debugLevel = globalOptions.debugLevel; + result.typeInfoLevel = globalOptions.typeInfoLevel; + + result.vectorLib = globalOptions.vectorLib; + result.vectorCtor = globalOptions.vectorCtor; + result.vectorType = globalOptions.vectorType; + + return result; +} + +static std::optional getCompileFormat(const char* name) +{ + if (strcmp(name, "text") == 0) + return CompileFormat::Text; + else if (strcmp(name, "binary") == 0) + return CompileFormat::Binary; + else if (strcmp(name, "text") == 0) + return CompileFormat::Text; + else if (strcmp(name, "remarks") == 0) + return CompileFormat::Remarks; + else if (strcmp(name, "codegen") == 0) + return CompileFormat::Codegen; + else if (strcmp(name, "codegenasm") == 0) + return CompileFormat::CodegenAsm; + else if (strcmp(name, "codegenir") == 0) + return CompileFormat::CodegenIr; + else if (strcmp(name, "codegenverbose") == 0) + return CompileFormat::CodegenVerbose; + else if (strcmp(name, "codegennull") == 0) + return CompileFormat::CodegenNull; + else if (strcmp(name, "null") == 0) + return CompileFormat::Null; + else + return std::nullopt; +} + +static void report(const char* name, const Luau::Location& location, const char* type, const char* message) +{ + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); +} + +static void reportError(const char* name, const Luau::ParseError& error) +{ + report(name, error.getLocation(), "SyntaxError", error.what()); +} + +static void reportError(const char* name, const Luau::CompileError& error) +{ + report(name, error.getLocation(), "CompileError", error.what()); +} + +#ifndef NO_CODEGEN +static std::string getCodegenAssembly( + const char* name, + const std::string& bytecode, + Luau::CodeGen::AssemblyOptions options, + Luau::CodeGen::LoweringStats* stats +) +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) + return Luau::CodeGen::getAssembly(L, -1, options, stats); + + fprintf(stderr, "Error loading bytecode %s\n", name); + return ""; +} +#endif + +static void annotateInstruction(void* context, std::string& text, int fid, int instpos) +{ + Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; + + bcb.annotateInstruction(text, fid, instpos); +} + +struct CompileStats +{ + size_t lines; + size_t bytecode; + size_t bytecodeInstructionCount; + size_t codegen; + + double readTime; + double miscTime; + double parseTime; + double compileTime; + double codegenTime; + + Luau::CodeGen::LoweringStats lowerStats; + + CompileStats& operator+=(const CompileStats& that) + { + this->lines += that.lines; + this->bytecode += that.bytecode; + this->bytecodeInstructionCount += that.bytecodeInstructionCount; + this->codegen += that.codegen; + this->readTime += that.readTime; + this->miscTime += that.miscTime; + this->parseTime += that.parseTime; + this->compileTime += that.compileTime; + this->codegenTime += that.codegenTime; + this->lowerStats += that.lowerStats; + + return *this; + } + + CompileStats operator+(const CompileStats& other) const + { + CompileStats result(*this); + result += other; + return result; + } +}; + +#define WRITE_NAME(INDENT, NAME) fprintf(fp, INDENT "\"" #NAME "\": ") +#define WRITE_PAIR(INDENT, NAME, FORMAT) fprintf(fp, INDENT "\"" #NAME "\": " FORMAT, stats.NAME) +#define WRITE_PAIR_STRING(INDENT, NAME, FORMAT) fprintf(fp, INDENT "\"" #NAME "\": " FORMAT, stats.NAME.c_str()) + +void serializeFunctionStats(FILE* fp, const Luau::CodeGen::FunctionStats& stats) +{ + fprintf(fp, " {\n"); + WRITE_PAIR_STRING(" ", name, "\"%s\",\n"); + WRITE_PAIR(" ", line, "%d,\n"); + WRITE_PAIR(" ", bcodeCount, "%u,\n"); + WRITE_PAIR(" ", irCount, "%u,\n"); + WRITE_PAIR(" ", asmCount, "%u,\n"); + WRITE_PAIR(" ", asmSize, "%u,\n"); + + WRITE_NAME(" ", bytecodeSummary); + const size_t nestingLimit = stats.bytecodeSummary.size(); + + if (nestingLimit == 0) + fprintf(fp, "[]"); + else + { + fprintf(fp, "[\n"); + for (size_t i = 0; i < nestingLimit; ++i) + { + const std::vector& counts = stats.bytecodeSummary[i]; + fprintf(fp, " ["); + for (size_t j = 0; j < counts.size(); ++j) + { + fprintf(fp, "%u", counts[j]); + if (j < counts.size() - 1) + fprintf(fp, ", "); + } + fprintf(fp, "]"); + if (i < stats.bytecodeSummary.size() - 1) + fprintf(fp, ",\n"); + } + fprintf(fp, "\n ]"); + } + + fprintf(fp, "\n }"); +} + +void serializeBlockLinearizationStats(FILE* fp, const Luau::CodeGen::BlockLinearizationStats& stats) +{ + fprintf(fp, "{\n"); + + WRITE_PAIR(" ", constPropInstructionCount, "%u,\n"); + WRITE_PAIR(" ", timeSeconds, "%f\n"); + + fprintf(fp, " }"); +} + +void serializeLoweringStats(FILE* fp, const Luau::CodeGen::LoweringStats& stats) +{ + fprintf(fp, "{\n"); + + WRITE_PAIR(" ", totalFunctions, "%u,\n"); + WRITE_PAIR(" ", skippedFunctions, "%u,\n"); + WRITE_PAIR(" ", spillsToSlot, "%d,\n"); + WRITE_PAIR(" ", spillsToRestore, "%d,\n"); + WRITE_PAIR(" ", maxSpillSlotsUsed, "%u,\n"); + WRITE_PAIR(" ", blocksPreOpt, "%u,\n"); + WRITE_PAIR(" ", blocksPostOpt, "%u,\n"); + WRITE_PAIR(" ", maxBlockInstructions, "%u,\n"); + WRITE_PAIR(" ", regAllocErrors, "%d,\n"); + WRITE_PAIR(" ", loweringErrors, "%d,\n"); + + WRITE_NAME(" ", blockLinearizationStats); + serializeBlockLinearizationStats(fp, stats.blockLinearizationStats); + fprintf(fp, ",\n"); + + WRITE_NAME(" ", functions); + const size_t functionCount = stats.functions.size(); + + if (functionCount == 0) + fprintf(fp, "[]"); + else + { + fprintf(fp, "[\n"); + for (size_t i = 0; i < functionCount; ++i) + { + serializeFunctionStats(fp, stats.functions[i]); + if (i < functionCount - 1) + fprintf(fp, ",\n"); + } + fprintf(fp, "\n ]"); + } + + fprintf(fp, "\n }"); +} + +void serializeCompileStats(FILE* fp, const CompileStats& stats) +{ + fprintf(fp, "{\n"); + + WRITE_PAIR(" ", lines, "%zu,\n"); + WRITE_PAIR(" ", bytecode, "%zu,\n"); + WRITE_PAIR(" ", bytecodeInstructionCount, "%zu,\n"); + WRITE_PAIR(" ", codegen, "%zu,\n"); + WRITE_PAIR(" ", readTime, "%f,\n"); + WRITE_PAIR(" ", miscTime, "%f,\n"); + WRITE_PAIR(" ", parseTime, "%f,\n"); + WRITE_PAIR(" ", compileTime, "%f,\n"); + WRITE_PAIR(" ", codegenTime, "%f,\n"); + + WRITE_NAME(" ", lowerStats); + serializeLoweringStats(fp, stats.lowerStats); + + fprintf(fp, "\n }"); +} + +#undef WRITE_NAME +#undef WRITE_PAIR +#undef WRITE_PAIR_STRING + +static double recordDeltaTime(double& timer) +{ + double now = Luau::TimeTrace::getClock(); + double delta = now - timer; + timer = now; + return delta; +} + +static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::AssemblyOptions::Target assemblyTarget, CompileStats& stats) +{ + double currts = Luau::TimeTrace::getClock(); + + std::optional source = readFile(name); + if (!source) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + stats.readTime += recordDeltaTime(currts); + + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) + // This function is much more complicated because it supports many output human-readable formats through internal interfaces + + try + { + Luau::BytecodeBuilder bcb; + + Luau::CodeGen::AssemblyOptions options; + options.target = assemblyTarget; + options.outputBinary = format == CompileFormat::CodegenNull; + + if (!options.outputBinary) + { + options.includeAssembly = format != CompileFormat::CodegenIr; + options.includeIr = format != CompileFormat::CodegenAsm; + options.includeIrTypes = format != CompileFormat::CodegenAsm; + options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; + } + + options.annotator = annotateInstruction; + options.annotatorContext = &bcb; + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags( + Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks | Luau::BytecodeBuilder::Dump_Types + ); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Remarks) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || format == CompileFormat::CodegenVerbose) + { + bcb.setDumpFlags( + Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks + ); + bcb.setDumpSource(*source); + } + + stats.miscTime += recordDeltaTime(currts); + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + stats.lines += result.lines; + stats.parseTime += recordDeltaTime(currts); + + Luau::compileOrThrow(bcb, result, names, copts()); + stats.bytecode += bcb.getBytecode().size(); + stats.bytecodeInstructionCount = bcb.getTotalInstructionCount(); + stats.compileTime += recordDeltaTime(currts); + + switch (format) + { + case CompileFormat::Text: + printf("%s", bcb.dumpEverything().c_str()); + break; + case CompileFormat::Remarks: + printf("%s", bcb.dumpSourceRemarks().c_str()); + break; + case CompileFormat::Binary: + fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); + break; +#ifndef NO_CODEGEN + case CompileFormat::Codegen: + case CompileFormat::CodegenAsm: + case CompileFormat::CodegenIr: + case CompileFormat::CodegenVerbose: + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).c_str()); + break; + case CompileFormat::CodegenNull: + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).size(); + stats.codegenTime += recordDeltaTime(currts); + break; +#endif + case CompileFormat::Null: + break; + } + + return true; + } + catch (Luau::ParseErrors& e) + { + for (auto& error : e.getErrors()) + reportError(name, error); + return false; + } + catch (Luau::CompileError& e) + { + reportError(name, e); + return false; + } +} + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [--mode] [options] [file list]\n", argv0); + printf("\n"); + printf("Available modes:\n"); + printf(" binary, text, remarks, codegen\n"); + printf("\n"); + printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); + printf(" --target=: compile code for specific architecture (a64, x64, a64_nf, x64_ms).\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); + printf(" --record-stats=: granularity of compilation stats (total, file, function).\n"); + printf(" --bytecode-summary: Compute bytecode operation distribution.\n"); + printf(" --stats-file=: file in which compilation stats will be recored (default 'stats.json').\n"); + printf(" --vector-lib=: name of the library providing vector type operations.\n"); + printf(" --vector-ctor=: name of the function constructing a vector value.\n"); + printf(" --vector-type=: name of the vector type.\n"); +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +std::string escapeFilename(const std::string& filename) +{ + std::string escaped; + escaped.reserve(filename.size()); + + for (const char ch : filename) + { + switch (ch) + { + case '\\': + escaped.push_back('/'); + break; + case '"': + escaped.push_back('\\'); + escaped.push_back(ch); + break; + default: + escaped.push_back(ch); + } + } + + return escaped; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + setLuauFlagsDefault(); + + CompileFormat compileFormat = CompileFormat::Text; + Luau::CodeGen::AssemblyOptions::Target assemblyTarget = Luau::CodeGen::AssemblyOptions::Host; + RecordStats recordStats = RecordStats::None; + std::string statsFile("stats.json"); + bool bytecodeSummary = false; + + for (int i = 1; i < argc; i++) + { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.optimizationLevel = level; + } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.debugLevel = level; + } + else if (strncmp(argv[i], "-t", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 1) + { + fprintf(stderr, "Error: Type info level must be between 0 and 1 inclusive.\n"); + return 1; + } + globalOptions.typeInfoLevel = level; + } + else if (strncmp(argv[i], "--target=", 9) == 0) + { + const char* value = argv[i] + 9; + + if (strcmp(value, "a64") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::A64; + else if (strcmp(value, "a64_nf") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::A64_NoFeatures; + else if (strcmp(value, "x64") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::X64_SystemV; + else if (strcmp(value, "x64_ms") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::X64_Windows; + else + { + fprintf(stderr, "Error: unknown target\n"); + return 1; + } + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; + } + else if (strncmp(argv[i], "--record-stats=", 15) == 0) + { + const char* value = argv[i] + 15; + + if (strcmp(value, "total") == 0) + recordStats = RecordStats::Total; + else if (strcmp(value, "file") == 0) + recordStats = RecordStats::File; + else if (strcmp(value, "function") == 0) + recordStats = RecordStats::Function; + else + { + fprintf(stderr, "Error: unknown 'granularity' for '--record-stats'.\n"); + return 1; + } + } + else if (strncmp(argv[i], "--bytecode-summary", 18) == 0) + { + bytecodeSummary = true; + } + else if (strncmp(argv[i], "--stats-file=", 13) == 0) + { + statsFile = argv[i] + 13; + + if (statsFile.size() == 0) + { + fprintf(stderr, "Error: filename missing for '--stats-file'.\n\n"); + return 1; + } + } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + setLuauFlags(argv[i] + 9); + } + else if (strncmp(argv[i], "--vector-lib=", 13) == 0) + { + globalOptions.vectorLib = argv[i] + 13; + } + else if (strncmp(argv[i], "--vector-ctor=", 14) == 0) + { + globalOptions.vectorCtor = argv[i] + 14; + } + else if (strncmp(argv[i], "--vector-type=", 14) == 0) + { + globalOptions.vectorType = argv[i] + 14; + } + else if (argv[i][0] == '-' && argv[i][1] == '-' && getCompileFormat(argv[i] + 2)) + { + compileFormat = *getCompileFormat(argv[i] + 2); + } + else if (argv[i][0] == '-') + { + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } + + if (bytecodeSummary && (recordStats != RecordStats::Function)) + { + fprintf(stderr, "'Error: Required '--record-stats=function' for '--bytecode-summary'.\n"); + return 1; + } + +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; + } +#endif + + const std::vector files = getSourceFiles(argc, argv); + +#ifdef _WIN32 + if (compileFormat == CompileFormat::Binary) + _setmode(_fileno(stdout), _O_BINARY); +#endif + + const size_t fileCount = files.size(); + CompileStats stats = {}; + + std::vector fileStats; + if (recordStats == RecordStats::File || recordStats == RecordStats::Function) + fileStats.reserve(fileCount); + + int failed = 0; + unsigned functionStats = (recordStats == RecordStats::Function ? Luau::CodeGen::FunctionStats_Enable : 0) | + (bytecodeSummary ? Luau::CodeGen::FunctionStats_BytecodeSummary : 0); + for (const std::string& path : files) + { + CompileStats fileStat = {}; + fileStat.lowerStats.functionStatsFlags = functionStats; + failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, fileStat); + stats += fileStat; + if (recordStats == RecordStats::File || recordStats == RecordStats::Function) + fileStats.push_back(fileStat); + } + + if (compileFormat == CompileFormat::Null) + { + printf( + "Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", + int(stats.lines / 1000), + int(stats.bytecode / 1024), + stats.readTime, + stats.parseTime, + stats.compileTime + ); + } + else if (compileFormat == CompileFormat::CodegenNull) + { + printf( + "Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", + int(stats.lines / 1000), + int(stats.bytecode / 1024), + int(stats.codegen / 1024), + stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), + stats.readTime, + stats.parseTime, + stats.compileTime, + stats.codegenTime + ); + + printf( + "Lowering: regalloc failed: %d, lowering failed %d; spills to stack: %d, spills to restore: %d, max spill slot %u\n", + stats.lowerStats.regAllocErrors, + stats.lowerStats.loweringErrors, + stats.lowerStats.spillsToSlot, + stats.lowerStats.spillsToRestore, + stats.lowerStats.maxSpillSlotsUsed + ); + } + + if (recordStats != RecordStats::None) + { + FILE* fp = fopen(statsFile.c_str(), "w"); + + if (!fp) + { + fprintf(stderr, "Unable to open 'stats.json'\n"); + return 1; + } + + if (recordStats == RecordStats::Total) + { + serializeCompileStats(fp, stats); + } + else if (recordStats == RecordStats::File || recordStats == RecordStats::Function) + { + fprintf(fp, "{\n"); + for (size_t i = 0; i < fileCount; ++i) + { + std::string escaped(escapeFilename(files[i])); + fprintf(fp, " \"%s\": ", escaped.c_str()); + serializeCompileStats(fp, fileStats[i]); + fprintf(fp, i == (fileCount - 1) ? "\n" : ",\n"); + } + fprintf(fp, "}"); + } + + fclose(fp); + } + + return failed ? 1 : 0; +} diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index 39a14ec71..e9f40a09a 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -10,7 +10,8 @@ #ifndef NOMINMAX #define NOMINMAX #endif -#include +#include +#include #else #include #include @@ -44,6 +45,142 @@ static std::string toUtf8(const std::wstring& path) } #endif +bool isAbsolutePath(std::string_view path) +{ +#ifdef _WIN32 + // Must either begin with "X:/", "X:\", "/", or "\", where X is a drive letter + return (path.size() >= 3 && isalpha(path[0]) && path[1] == ':' && (path[2] == '/' || path[2] == '\\')) || + (path.size() >= 1 && (path[0] == '/' || path[0] == '\\')); +#else + // Must begin with '/' + return path.size() >= 1 && path[0] == '/'; +#endif +} + +std::optional getCurrentWorkingDirectory() +{ + // 2^17 - derived from the Windows path length limit + constexpr size_t maxPathLength = 131072; + constexpr size_t initialPathLength = 260; + + std::string directory(initialPathLength, '\0'); + char* cstr = nullptr; + + while (!cstr && directory.size() <= maxPathLength) + { +#ifdef _WIN32 + cstr = _getcwd(directory.data(), static_cast(directory.size())); +#else + cstr = getcwd(directory.data(), directory.size()); +#endif + if (cstr) + { + directory.resize(strlen(cstr)); + return directory; + } + else if (errno != ERANGE || directory.size() * 2 > maxPathLength) + { + return std::nullopt; + } + else + { + directory.resize(directory.size() * 2); + } + } + return std::nullopt; +} + +// Returns the normal/canonical form of a path (e.g. "../subfolder/../module.luau" -> "../module.luau") +std::string normalizePath(std::string_view path) +{ + return resolvePath(path, ""); +} + +// Takes a path that is relative to the file at baseFilePath and returns the path explicitly rebased onto baseFilePath. +// For absolute paths, baseFilePath will be ignored, and this function will resolve the path to a canonical path: +// (e.g. "/Users/.././Users/johndoe" -> "/Users/johndoe"). +std::string resolvePath(std::string_view path, std::string_view baseFilePath) +{ + std::vector pathComponents; + std::vector baseFilePathComponents; + + // Dependent on whether the final resolved path is absolute or relative + // - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty + // - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc. + std::string resolvedPathPrefix; + + if (isAbsolutePath(path)) + { + // path is absolute, we use path's prefix and ignore baseFilePath + size_t afterPrefix = path.find_first_of("\\/") + 1; + resolvedPathPrefix = path.substr(0, afterPrefix); + pathComponents = splitPath(path.substr(afterPrefix)); + } + else + { + pathComponents = splitPath(path); + if (isAbsolutePath(baseFilePath)) + { + // path is relative and baseFilePath is absolute, we use baseFilePath's prefix + size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1; + resolvedPathPrefix = baseFilePath.substr(0, afterPrefix); + baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix)); + } + else + { + // path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative) + baseFilePathComponents = splitPath(baseFilePath); + } + } + + // Remove filename from components + if (!baseFilePathComponents.empty()) + baseFilePathComponents.pop_back(); + + // Resolve the path by applying pathComponents to baseFilePathComponents + int numPrependedParents = 0; + for (std::string_view component : pathComponents) + { + if (component == "..") + { + if (baseFilePathComponents.empty()) + { + if (resolvedPathPrefix.empty()) // only when final resolved path will be relative + numPrependedParents++; // "../" will later be added to the beginning of the resolved path + } + else if (baseFilePathComponents.back() != "..") + { + baseFilePathComponents.pop_back(); // Resolve cases like "folder/subfolder/../../file" to "file" + } + } + else if (component != "." && !component.empty()) + { + baseFilePathComponents.push_back(component); + } + } + + // Join baseFilePathComponents to form the resolved path + std::string resolvedPath = resolvedPathPrefix; + // Only when resolvedPath will be relative + for (int i = 0; i < numPrependedParents; i++) + { + resolvedPath += "../"; + } + for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter) + { + if (iter != baseFilePathComponents.begin()) + resolvedPath += "/"; + + resolvedPath += *iter; + } + if (resolvedPath.size() > resolvedPathPrefix.size() && resolvedPath.back() == '/') + { + // Remove trailing '/' if present + resolvedPath.pop_back(); + } + return resolvedPath; +} + std::optional readFile(const std::string& name) { #ifdef _WIN32 @@ -165,11 +302,14 @@ static bool traverseDirectoryRec(const std::string& path, const std::function splitPath(std::string_view path) +{ + std::vector components; + + size_t pos = 0; + size_t nextPos = path.find_first_of("\\/", pos); + + while (nextPos != std::string::npos) + { + components.push_back(path.substr(pos, nextPos - pos)); + pos = nextPos + 1; + nextPos = path.find_first_of("\\/", pos); + } + components.push_back(path.substr(pos)); + + return components; +} + std::string joinPaths(const std::string& lhs, const std::string& rhs) { std::string result = lhs; @@ -267,6 +439,10 @@ std::vector getSourceFiles(int argc, char** argv) for (int i = 1; i < argc; ++i) { + // Early out once we reach --program-args,-a since the remaining args are passed to lua + if (strcmp(argv[i], "--program-args") == 0 || strcmp(argv[i], "-a") == 0) + return files; + // Treat '-' as a special file whose source is read from stdin // All other arguments that start with '-' are skipped if (argv[i][0] == '-' && argv[i][1] != '\0') @@ -274,12 +450,16 @@ std::vector getSourceFiles(int argc, char** argv) if (isDirectory(argv[i])) { - traverseDirectory(argv[i], [&](const std::string& name) { - std::string ext = getExtension(name); - - if (ext == ".lua" || ext == ".luau") - files.push_back(name); - }); + traverseDirectory( + argv[i], + [&](const std::string& name) + { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + } + ); } else { diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index 97471cdc0..dce94ace0 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -3,15 +3,24 @@ #include #include +#include #include #include +std::optional getCurrentWorkingDirectory(); + +std::string normalizePath(std::string_view path); +std::string resolvePath(std::string_view relativePath, std::string_view baseFilePath); + std::optional readFile(const std::string& name); std::optional readStdin(); +bool isAbsolutePath(std::string_view path); +bool isFile(const std::string& path); bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); +std::vector splitPath(std::string_view path); std::string joinPaths(const std::string& lhs, const std::string& rhs); std::optional getParentPath(const std::string& path); diff --git a/CLI/Flags.cpp b/CLI/Flags.cpp index 4e261171a..c0bb485f4 100644 --- a/CLI/Flags.cpp +++ b/CLI/Flags.cpp @@ -2,11 +2,14 @@ #include "Luau/Common.h" #include "Luau/ExperimentalFlags.h" +#include // TODO: remove with LuauTypeSolverRelease #include #include #include +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + static void setLuauFlag(std::string_view name, bool state) { for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) @@ -23,6 +26,13 @@ static void setLuauFlag(std::string_view name, bool state) static void setLuauFlags(bool state) { + if (state) + { + // Setting flags to 'true' means enabling all Luau flags including new type solver + // In that case, it is provided with all fixes enabled (as if each fix had its own boolean flag) + DFInt::LuauTypeSolverRelease.value = std::numeric_limits::max(); + } + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) if (strncmp(flag->name, "Luau", 4) == 0) flag->value = state; @@ -54,8 +64,9 @@ void setLuauFlags(const char* list) else if (value == "false" || value == "False") setLuauFlag(key, false); else - fprintf(stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), - key.data()); + fprintf( + stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), key.data() + ); } else { diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index d3ad4e996..3cf0aea2e 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -131,8 +131,13 @@ void profilerDump(const char* path) fclose(f); - printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, - static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); + printf( + "Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", + path, + double(total) / 1e6, + static_cast(gProfiler.samples.load()), + static_cast(gProfiler.data.size()) + ); uint64_t totalgc = 0; for (uint64_t p : gProfiler.gc) diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index b7c780128..7f8c459c7 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -15,7 +15,7 @@ #define VERBOSE 0 // 1 - print out commandline invocations. 2 - print out stdout -#ifdef _WIN32 +#if defined(_WIN32) && !defined(__MINGW32__) const auto popen = &_popen; const auto pclose = &_pclose; @@ -56,10 +56,9 @@ struct Reducer ParseResult parseResult; AstStatBlock* root; - std::string tempScriptName; + std::string scriptName; - std::string appName; - std::vector appArgs; + std::string command; std::string_view searchText; Reducer() @@ -99,10 +98,10 @@ struct Reducer } while (true); } - FILE* f = fopen(tempScriptName.c_str(), "w"); + FILE* f = fopen(scriptName.c_str(), "w"); if (!f) { - printf("Unable to open temp script to %s\n", tempScriptName.c_str()); + printf("Unable to open temp script to %s\n", scriptName.c_str()); exit(2); } @@ -113,7 +112,7 @@ struct Reducer if (written != source.size()) { printf("??? %zu %zu\n", written, source.size()); - printf("Unable to write to temp script %s\n", tempScriptName.c_str()); + printf("Unable to write to temp script %s\n", scriptName.c_str()); exit(3); } @@ -142,12 +141,18 @@ struct Reducer { writeTempScript(); - std::string command = appName + " " + escape(tempScriptName); - for (const auto& arg : appArgs) - command += " " + escape(arg); + std::string cmd = command; + while (true) + { + auto pos = cmd.find("{}"); + if (std::string::npos == pos) + break; + + cmd = cmd.substr(0, pos) + escape(scriptName) + cmd.substr(pos + 2); + } #if VERBOSE >= 1 - printf("running %s\n", command.c_str()); + printf("running %s\n", cmd.c_str()); #endif TestResult result = TestResult::NoBug; @@ -155,7 +160,7 @@ struct Reducer ++step; printf("Step %4d...\n", step); - FILE* p = popen(command.c_str(), "r"); + FILE* p = popen(cmd.c_str(), "r"); while (!feof(p)) { @@ -179,7 +184,8 @@ struct Reducer { std::vector result; - auto append = [&](AstStatBlock* block) { + auto append = [&](AstStatBlock* block) + { if (block) result.insert(result.end(), block->body.data, block->body.data + block->body.size); }; @@ -245,7 +251,8 @@ struct Reducer std::vector> result; - auto append = [&result](Span a, Span b) { + auto append = [&result](Span a, Span b) + { if (a.first == a.second && b.first == b.second) return; else @@ -424,30 +431,19 @@ struct Reducer } } - void run(const std::string scriptName, const std::string appName, const std::vector& appArgs, std::string_view source, - std::string_view searchText) + void run(const std::string scriptName, const std::string command, std::string_view source, std::string_view searchText) { - tempScriptName = scriptName; - if (tempScriptName.substr(tempScriptName.size() - 4) == ".lua") - { - tempScriptName.erase(tempScriptName.size() - 4); - tempScriptName += "-reduced.lua"; - } - else - { - this->tempScriptName = scriptName + "-reduced"; - } + this->scriptName = scriptName; #if 0 // Handy debugging trick: VS Code will update its view of the file in realtime as it is edited. - std::string wheee = "code " + tempScriptName; + std::string wheee = "code " + scriptName; system(wheee.c_str()); #endif - printf("Temp script: %s\n", tempScriptName.c_str()); + printf("Script: %s\n", scriptName.c_str()); - this->appName = appName; - this->appArgs = appArgs; + this->command = command; this->searchText = searchText; parseResult = Parser::parse(source.data(), source.size(), nameTable, allocator, parseOptions); @@ -470,13 +466,14 @@ struct Reducer writeTempScript(/* minify */ true); - printf("Done! Check %s\n", tempScriptName.c_str()); + printf("Done! Check %s\n", scriptName.c_str()); } }; [[noreturn]] void help(const std::vector& args) { - printf("Syntax: %s script application \"search text\" [arguments]\n", args[0].data()); + printf("Syntax: %s script command \"search text\"\n", args[0].data()); + printf(" Within command, use {} as a stand-in for the script being reduced\n"); exit(1); } @@ -484,7 +481,7 @@ int main(int argc, char** argv) { const std::vector args(argv, argv + argc); - if (args.size() < 4) + if (args.size() != 4) help(args); for (size_t i = 1; i < args.size(); ++i) @@ -496,7 +493,6 @@ int main(int argc, char** argv) const std::string scriptName = argv[1]; const std::string appName = argv[2]; const std::string searchText = argv[3]; - const std::vector appArgs(begin(args) + 4, end(args)); std::optional source = readFile(scriptName); @@ -507,5 +503,5 @@ int main(int argc, char** argv) } Reducer reducer; - reducer.run(scriptName, appName, appArgs, *source, searchText); + reducer.run(scriptName, appName, *source, searchText); } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 00e72f468..018480a61 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Repl.h" +#include "Luau/Common.h" #include "lua.h" #include "lualib.h" @@ -8,13 +9,14 @@ #include "Luau/CodeGen.h" #endif #include "Luau/Compiler.h" -#include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" +#include "Luau/TimeTrace.h" #include "Coverage.h" #include "FileUtils.h" #include "Flags.h" #include "Profiler.h" +#include "Require.h" #include "isocline.h" @@ -28,6 +30,10 @@ #include #endif +#ifdef __linux__ +#include +#endif + #ifdef CALLGRIND #include #endif @@ -37,30 +43,12 @@ LUAU_FASTFLAG(DebugLuauTimeTracing) -enum class CliMode -{ - Unknown, - Repl, - Compile, - RunSourceFiles -}; - -enum class CompileFormat -{ - Text, - Binary, - Remarks, - Codegen, // Prints annotated native code including IR and assembly - CodegenAsm, // Prints annotated native code assembly - CodegenIr, // Prints annotated native code IR - CodegenVerbose, // Prints annotated native code including IR, assembly and outlined code - CodegenNull, - Null -}; constexpr int MaxTraversalLimit = 50; static bool codegen = false; +static int program_argc = 0; +char** program_argv = nullptr; // Ctrl-C handling static void sigintCallback(lua_State* L, int gc) @@ -102,6 +90,7 @@ static Luau::CompileOptions copts() Luau::CompileOptions result = {}; result.optimizationLevel = globalOptions.optimizationLevel; result.debugLevel = globalOptions.debugLevel; + result.typeInfoLevel = 1; result.coverageLevel = coverageActive() ? 2 : 0; return result; @@ -135,27 +124,15 @@ static int finishrequire(lua_State* L) static int lua_require(lua_State* L) { std::string name = luaL_checkstring(L, 1); - std::string chunkname = "=" + name; - luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + RequireResolver::ResolvedRequire resolvedRequire = RequireResolver::resolveRequire(L, std::move(name)); - // return the module from the cache - lua_getfield(L, -1, name.c_str()); - if (!lua_isnil(L, -1)) - { - // L stack: _MODULES result + if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached) return finishrequire(L); - } - - lua_pop(L, 1); - - std::optional source = readFile(name + ".luau"); - if (!source) - { - source = readFile(name + ".lua"); // try .lua if .luau doesn't exist - if (!source) - luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error - } + else if (resolvedRequire.status == RequireResolver::ModuleStatus::Ambiguous) + luaL_errorL(L, "require path could not be resolved to a unique file"); + else if (resolvedRequire.status == RequireResolver::ModuleStatus::NotFound) + luaL_errorL(L, "error requiring module"); // module needs to run in a new thread, isolated from the rest // note: we create ML on main thread so that it doesn't inherit environment of L @@ -167,12 +144,19 @@ static int lua_require(lua_State* L) luaL_sandboxthread(ML); // now we can compile & run module on the new thread +/*HEAD std::string bytecode = Luau::compile(*source, chunkname, copts()); if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) + */ + std::string bytecode = Luau::compile(resolvedRequire.sourceCode, resolvedRequire.chunkName, copts()); + if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { #ifndef NO_CODEGEN if (codegen) - Luau::CodeGen::compile(ML, -1); + { + Luau::CodeGen::CompilationOptions nativeOptions; + Luau::CodeGen::compile(ML, -1, nativeOptions); + } #endif if (coverageActive()) coverageTrack(ML, -1); @@ -199,7 +183,7 @@ static int lua_require(lua_State* L) // there's now a return value on top of ML; L stack: _MODULES ML lua_xmove(ML, L, 1); lua_pushvalue(L, -1); - lua_setfield(L, -4, name.c_str()); + lua_setfield(L, -4, resolvedRequire.absolutePath.c_str()); // L stack: _MODULES ML result return finishrequire(L); @@ -280,8 +264,18 @@ void setupState(lua_State* L) luaL_sandbox(L); } +void setupArguments(lua_State* L, int argc, char** argv) +{ + lua_checkstack(L, argc); + + for (int i = 0; i < argc; ++i) + lua_pushstring(L, argv[i]); +} + std::string runCode(lua_State* L, const std::string& source) { + lua_checkstack(L, LUA_MINSTACK); + std::string bytecode = Luau::compile(source, "=stdin", copts()); if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) @@ -408,8 +402,13 @@ static void safeGetTable(lua_State* L, int tableIndex) // completePartialMatches finds keys that match the specified 'prefix' // Note: the table/object to be searched must be on the top of the Lua stack -static void completePartialMatches(lua_State* L, bool completeOnlyFunctions, const std::string& editBuffer, std::string_view prefix, - const AddCompletionCallback& addCompletionCallback) +static void completePartialMatches( + lua_State* L, + bool completeOnlyFunctions, + const std::string& editBuffer, + std::string_view prefix, + const AddCompletionCallback& addCompletionCallback +) { for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++) { @@ -456,6 +455,8 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A std::string_view lookup = editBuffer; bool completeOnlyFunctions = false; + lua_checkstack(L, LUA_MINSTACK); + // Push the global variable table to begin the search lua_pushvalue(L, LUA_GLOBALSINDEX); @@ -501,9 +502,14 @@ static void icGetCompletions(ic_completion_env_t* cenv, const char* editBuffer) { auto* L = reinterpret_cast(ic_completion_arg(cenv)); - getCompletions(L, std::string(editBuffer), [cenv](const std::string& completion, const std::string& display) { - ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); - }); + getCompletions( + L, + std::string(editBuffer), + [cenv](const std::string& completion, const std::string& display) + { + ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); + } + ); } static bool isMethodOrFunctionChar(const char* s, long len) @@ -626,16 +632,21 @@ static bool runFile(const char* name, lua_State* GL, bool repl) std::string bytecode = Luau::compile(*source, chunkname, copts()); int status = 0; - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == -1) { #ifndef NO_CODEGEN if (codegen) - Luau::CodeGen::compile(L, -1); + { + Luau::CodeGen::CompilationOptions nativeOptions; + Luau::CodeGen::compile(L, -1, nativeOptions); + } #endif + if (coverageActive()) coverageTrack(L, -1); - status = lua_resume(L, NULL, 0); + setupArguments(L, program_argc, program_argv); + status = lua_resume(L, NULL, program_argc); } else { @@ -669,161 +680,11 @@ static bool runFile(const char* name, lua_State* GL, bool repl) return status == 0; } -static void report(const char* name, const Luau::Location& location, const char* type, const char* message) -{ - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); -} - -static void reportError(const char* name, const Luau::ParseError& error) -{ - report(name, error.getLocation(), "SyntaxError", error.what()); -} - -static void reportError(const char* name, const Luau::CompileError& error) -{ - report(name, error.getLocation(), "CompileError", error.what()); -} - -#ifndef NO_CODEGEN -static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) -{ - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssembly(L, -1, options); - - fprintf(stderr, "Error loading bytecode %s\n", name); - return ""; -} -#endif -static void annotateInstruction(void* context, std::string& text, int fid, int instpos) -{ - Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; - - bcb.annotateInstruction(text, fid, instpos); -} - -struct CompileStats -{ - size_t lines; - size_t bytecode; - size_t codegen; -}; - -static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) -{ - std::optional source = readFile(name); - if (!source) - { - fprintf(stderr, "Error opening %s\n", name); - return false; - } - - // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) - // This function is much more complicated because it supports many output human-readable formats through internal interfaces - - try - { - Luau::BytecodeBuilder bcb; - bcb.setChunkName(name); - -#ifndef NO_CODEGEN - Luau::CodeGen::AssemblyOptions options; - options.outputBinary = format == CompileFormat::CodegenNull; - - if (!options.outputBinary) - { - options.includeAssembly = format != CompileFormat::CodegenIr; - options.includeIr = format != CompileFormat::CodegenAsm; - options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; - } - - options.annotator = annotateInstruction; - options.annotatorContext = &bcb; -#endif - - if (format == CompileFormat::Text) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | - Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - else if (format == CompileFormat::Remarks) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || - format == CompileFormat::CodegenVerbose) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | - Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); - - if (!result.errors.empty()) - throw Luau::ParseErrors(result.errors); - - stats.lines += result.lines; - - Luau::compileOrThrow(bcb, result, names, copts()); - stats.bytecode += bcb.getBytecode().size(); - - switch (format) - { - case CompileFormat::Text: - printf("%s", bcb.dumpEverything().c_str()); - break; - case CompileFormat::Remarks: - printf("%s", bcb.dumpSourceRemarks().c_str()); - break; - case CompileFormat::Binary: - fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); - break; -#ifndef NO_CODEGEN - case CompileFormat::Codegen: - case CompileFormat::CodegenAsm: - case CompileFormat::CodegenIr: - case CompileFormat::CodegenVerbose: - printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); - break; - case CompileFormat::CodegenNull: - stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); - break; -#endif - case CompileFormat::Null: - break; - } - - return true; - } - catch (Luau::ParseErrors& e) - { - for (auto& error : e.getErrors()) - reportError(name, error); - return false; - } - catch (Luau::CompileError& e) - { - reportError(name, e); - return false; - } -} - static void displayHelp(const char* argv0) { - printf("Usage: %s [--mode] [options] [file list]\n", argv0); + printf("Usage: %s [options] [file list] [-a] [arg list]\n", argv0); printf("\n"); - printf("When mode and file list are omitted, an interactive REPL is started instead.\n"); - printf("\n"); - printf("Available modes:\n"); - printf(" omitted: compile and run input files one by one\n"); - printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n"); + printf("When file list is omitted, an interactive REPL is started instead.\n"); printf("\n"); printf("Available options:\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); @@ -834,6 +695,7 @@ static void displayHelp(const char* argv0) printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); printf(" --codegen: execute code using native code generation\n"); + printf(" --program-args,-a: declare start of arguments to be passed to the Luau program\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -848,66 +710,17 @@ int replMain(int argc, char** argv) setLuauFlagsDefault(); - CliMode mode = CliMode::Unknown; - CompileFormat compileFormat{}; +#ifdef _WIN32 + SetConsoleOutputCP(CP_UTF8); +#endif + int profile = 0; bool coverage = false; bool interactive = false; + bool codegenPerf = false; + int program_args = argc; - // Set the mode if the user has explicitly specified one. - int argStart = 1; - if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) - { - argStart++; - mode = CliMode::Compile; - if (strcmp(argv[1], "--compile") == 0) - { - compileFormat = CompileFormat::Text; - } - else if (strcmp(argv[1], "--compile=binary") == 0) - { - compileFormat = CompileFormat::Binary; - } - else if (strcmp(argv[1], "--compile=text") == 0) - { - compileFormat = CompileFormat::Text; - } - else if (strcmp(argv[1], "--compile=remarks") == 0) - { - compileFormat = CompileFormat::Remarks; - } - else if (strcmp(argv[1], "--compile=codegen") == 0) - { - compileFormat = CompileFormat::Codegen; - } - else if (strcmp(argv[1], "--compile=codegenasm") == 0) - { - compileFormat = CompileFormat::CodegenAsm; - } - else if (strcmp(argv[1], "--compile=codegenir") == 0) - { - compileFormat = CompileFormat::CodegenIr; - } - else if (strcmp(argv[1], "--compile=codegenverbose") == 0) - { - compileFormat = CompileFormat::CodegenVerbose; - } - else if (strcmp(argv[1], "--compile=codegennull") == 0) - { - compileFormat = CompileFormat::CodegenNull; - } - else if (strcmp(argv[1], "--compile=null") == 0) - { - compileFormat = CompileFormat::Null; - } - else - { - fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); - return 1; - } - } - - for (int i = argStart; i < argc; i++) + for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { @@ -950,6 +763,11 @@ int replMain(int argc, char** argv) { codegen = true; } + else if (strcmp(argv[i], "--codegen-perf") == 0) + { + codegen = true; + codegenPerf = true; + } else if (strcmp(argv[i], "--coverage") == 0) { coverage = true; @@ -962,6 +780,11 @@ int replMain(int argc, char** argv) { setLuauFlags(argv[i] + 9); } + else if (strcmp(argv[i], "--program-args") == 0 || strcmp(argv[i], "-a") == 0) + { + program_args = i + 1; + break; + } else if (argv[i][0] == '-') { fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); @@ -970,6 +793,10 @@ int replMain(int argc, char** argv) } } + program_argc = argc - program_args; + program_argv = &argv[program_args]; + + #if !defined(LUAU_ENABLE_TIME_TRACE) if (FFlag::DebugLuauTimeTracing) { @@ -978,55 +805,39 @@ int replMain(int argc, char** argv) } #endif -#if !LUA_CUSTOM_EXECUTION - if (codegen) + if (codegenPerf) { - fprintf(stderr, "To run with --codegen, Luau has to be built with LUA_CUSTOM_EXECUTION enabled\n"); - return 1; - } -#endif +#if __linux__ + char path[128]; + snprintf(path, sizeof(path), "/tmp/perf-%d.map", getpid()); - const std::vector files = getSourceFiles(argc, argv); - if (mode == CliMode::Unknown) - { - mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; - } -#ifndef NO_CODEGEN - if (mode != CliMode::Compile && codegen && !Luau::CodeGen::isSupported()) - { - fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n"); + // note, there's no need to close the log explicitly as it will be closed when the process exits + FILE* codegenPerfLog = fopen(path, "w"); + + Luau::CodeGen::setPerfLog( + codegenPerfLog, + [](void* context, uintptr_t addr, unsigned size, const char* symbol) + { + fprintf(static_cast(context), "%016lx %08x %s\n", long(addr), size, symbol); + } + ); +#else + fprintf(stderr, "--codegen-perf option is only supported on Linux\n"); return 1; - } -#endif - switch (mode) - { - case CliMode::Compile: - { -#ifdef _WIN32 - if (compileFormat == CompileFormat::Binary) - _setmode(_fileno(stdout), _O_BINARY); #endif + } - CompileStats stats = {}; - int failed = 0; - - for (const std::string& path : files) - failed += !compileFile(path.c_str(), compileFormat, stats); - - if (compileFormat == CompileFormat::Null) - printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); - else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), - int(stats.codegen / 1024)); +/* if (codegen && !Luau::CodeGen::isSupported()) + fprintf(stderr, "Warning: Native code generation is not supported in current configuration\n"); +*/ + const std::vector files = getSourceFiles(argc, argv); - return failed ? 1 : 0; - } - case CliMode::Repl: + if (files.empty()) { runRepl(); return 0; } - case CliMode::RunSourceFiles: + else { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -1058,9 +869,4 @@ int replMain(int argc, char** argv) return failed ? 1 : 0; } - case CliMode::Unknown: - default: - LUAU_ASSERT(!"Unhandled cli mode."); - return 1; - } } diff --git a/CLI/Require.cpp b/CLI/Require.cpp new file mode 100644 index 000000000..9a00597af --- /dev/null +++ b/CLI/Require.cpp @@ -0,0 +1,284 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Require.h" + +#include "FileUtils.h" +#include "Luau/Common.h" + +#include +#include +#include + +RequireResolver::RequireResolver(lua_State* L, std::string path) + : pathToResolve(std::move(path)) + , L(L) +{ + lua_Debug ar; + lua_getinfo(L, 1, "s", &ar); + sourceChunkname = ar.source; + + if (!isRequireAllowed(sourceChunkname)) + luaL_errorL(L, "require is not supported in this context"); + + if (isAbsolutePath(pathToResolve)) + luaL_argerrorL(L, 1, "cannot require an absolute path"); + + std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/'); + + if (!isPrefixValid()) + luaL_argerrorL(L, 1, "require path must start with a valid prefix: ./, ../, or @"); + + substituteAliasIfPresent(pathToResolve); +} + +[[nodiscard]] RequireResolver::ResolvedRequire RequireResolver::resolveRequire(lua_State* L, std::string path) +{ + RequireResolver resolver(L, std::move(path)); + ModuleStatus status = resolver.findModule(); + if (status != ModuleStatus::FileRead) + return ResolvedRequire{status}; + else + return ResolvedRequire{status, std::move(resolver.chunkname), std::move(resolver.absolutePath), std::move(resolver.sourceCode)}; +} + +RequireResolver::ModuleStatus RequireResolver::findModule() +{ + resolveAndStoreDefaultPaths(); + + // Put _MODULES table on stack for checking and saving to the cache + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + + return findModuleImpl(); +} + +RequireResolver::ModuleStatus RequireResolver::findModuleImpl() +{ + if (isPathAmbiguous(absolutePath)) + return ModuleStatus::Ambiguous; + + static const std::array possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"}; + + size_t unsuffixedAbsolutePathSize = absolutePath.size(); + + for (const char* possibleSuffix : possibleSuffixes) + { + absolutePath += possibleSuffix; + + // Check cache for module + lua_getfield(L, -1, absolutePath.c_str()); + if (!lua_isnil(L, -1)) + { + return ModuleStatus::Cached; + } + lua_pop(L, 1); + + // Try to read the matching file + std::optional source = readFile(absolutePath); + if (source) + { + chunkname = "=" + chunkname + possibleSuffix; + sourceCode = *source; + return ModuleStatus::FileRead; + } + + absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix + } + + return ModuleStatus::NotFound; +} + +bool RequireResolver::isPathAmbiguous(const std::string& path) +{ + bool found = false; + for (const char* suffix : {".luau", ".lua"}) + { + if (isFile(path + suffix)) + { + if (found) + return true; + else + found = true; + } + } + if (isDirectory(path) && found) + return true; + + return false; +} + +bool RequireResolver::isRequireAllowed(std::string_view sourceChunkname) +{ + LUAU_ASSERT(!sourceChunkname.empty()); + return (sourceChunkname[0] == '=' || sourceChunkname[0] == '@'); +} + +bool RequireResolver::isPrefixValid() +{ + return pathToResolve.compare(0, 2, "./") == 0 || pathToResolve.compare(0, 3, "../") == 0 || pathToResolve.compare(0, 1, "@") == 0; +} + +void RequireResolver::resolveAndStoreDefaultPaths() +{ + if (!isAbsolutePath(pathToResolve)) + { + std::string chunknameContext = getRequiringContextRelative(); + std::optional absolutePathContext = getRequiringContextAbsolute(); + + if (!absolutePathContext) + luaL_errorL(L, "error requiring module"); + + // resolvePath automatically sanitizes/normalizes the paths + std::optional chunknameOpt = resolvePath(pathToResolve, chunknameContext); + std::optional absolutePathOpt = resolvePath(pathToResolve, *absolutePathContext); + + if (!chunknameOpt || !absolutePathOpt) + luaL_errorL(L, "error requiring module"); + + chunkname = std::move(*chunknameOpt); + absolutePath = std::move(*absolutePathOpt); + } + else + { + // Here we must explicitly sanitize, as the path is taken as is + std::optional sanitizedPath = normalizePath(pathToResolve); + if (!sanitizedPath) + luaL_errorL(L, "error requiring module"); + + chunkname = *sanitizedPath; + absolutePath = std::move(*sanitizedPath); + } +} + +std::optional RequireResolver::getRequiringContextAbsolute() +{ + std::string requiringFile; + if (isAbsolutePath(sourceChunkname.substr(1))) + { + // We already have an absolute path for the requiring file + requiringFile = sourceChunkname.substr(1); + } + else + { + // Requiring file's stored path is relative to the CWD, must make absolute + std::optional cwd = getCurrentWorkingDirectory(); + if (!cwd) + return std::nullopt; + + if (sourceChunkname.substr(1) == "stdin") + { + // Require statement is being executed from REPL input prompt + // The requiring context is the pseudo-file "stdin" in the CWD + requiringFile = joinPaths(*cwd, "stdin"); + } + else + { + // Require statement is being executed in a file, must resolve relative to CWD + std::optional requiringFileOpt = resolvePath(sourceChunkname.substr(1), joinPaths(*cwd, "stdin")); + if (!requiringFileOpt) + return std::nullopt; + + requiringFile = *requiringFileOpt; + } + } + std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/'); + return requiringFile; +} + +std::string RequireResolver::getRequiringContextRelative() +{ + std::string baseFilePath; + if (sourceChunkname.substr(1) != "stdin") + baseFilePath = sourceChunkname.substr(1); + + return baseFilePath; +} + +void RequireResolver::substituteAliasIfPresent(std::string& path) +{ + if (path.size() < 1 || path[0] != '@') + return; + + // To ignore the '@' alias prefix when processing the alias + const size_t aliasStartPos = 1; + + // If a directory separator was found, the length of the alias is the + // distance between the start of the alias and the separator. Otherwise, + // the whole string after the alias symbol is the alias. + size_t aliasLen = path.find_first_of("\\/"); + if (aliasLen != std::string::npos) + aliasLen -= aliasStartPos; + + const std::string potentialAlias = path.substr(aliasStartPos, aliasLen); + + // Not worth searching when potentialAlias cannot be an alias + if (!Luau::isValidAlias(potentialAlias)) + luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str()); + + std::optional alias = getAlias(potentialAlias); + if (alias) + { + path = *alias + path.substr(potentialAlias.size() + 1); + } + else + { + luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str()); + } +} + +std::optional RequireResolver::getAlias(std::string alias) +{ + std::transform( + alias.begin(), + alias.end(), + alias.begin(), + [](unsigned char c) + { + return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; + } + ); + while (!config.aliases.count(alias) && !isConfigFullyResolved) + { + parseNextConfig(); + } + if (!config.aliases.count(alias) && isConfigFullyResolved) + return std::nullopt; // could not find alias + + return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName)); +} + +void RequireResolver::parseNextConfig() +{ + if (isConfigFullyResolved) + return; // no config files left to parse + + std::optional directory; + if (lastSearchedDir.empty()) + { + std::optional requiringFile = getRequiringContextAbsolute(); + if (!requiringFile) + luaL_errorL(L, "error requiring module"); + + directory = getParentPath(*requiringFile); + } + else + directory = getParentPath(lastSearchedDir); + + if (directory) + { + lastSearchedDir = *directory; + parseConfigInDirectory(*directory); + } + else + isConfigFullyResolved = true; +} + +void RequireResolver::parseConfigInDirectory(const std::string& directory) +{ + std::string configPath = joinPaths(directory, Luau::kConfigName); + + if (std::optional contents = readFile(configPath)) + { + std::optional error = Luau::parseConfig(*contents, config); + if (error) + luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str()); + } +} diff --git a/CLI/Require.h b/CLI/Require.h new file mode 100644 index 000000000..9c86c3cca --- /dev/null +++ b/CLI/Require.h @@ -0,0 +1,64 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lua.h" +#include "lualib.h" + +#include "Luau/Config.h" + +#include +#include + +class RequireResolver +{ +public: + std::string chunkname; + std::string absolutePath; + std::string sourceCode; + + enum class ModuleStatus + { + Cached, + FileRead, + Ambiguous, + NotFound + }; + + struct ResolvedRequire + { + ModuleStatus status; + std::string chunkName; + std::string absolutePath; + std::string sourceCode; + }; + + [[nodiscard]] ResolvedRequire static resolveRequire(lua_State* L, std::string path); + +private: + std::string pathToResolve; + std::string_view sourceChunkname; + + RequireResolver(lua_State* L, std::string path); + + ModuleStatus findModule(); + lua_State* L; + Luau::Config config; + std::string lastSearchedDir; + bool isConfigFullyResolved = false; + + bool isRequireAllowed(std::string_view sourceChunkname); + bool isPrefixValid(); + + void resolveAndStoreDefaultPaths(); + ModuleStatus findModuleImpl(); + bool isPathAmbiguous(const std::string& path); + + std::optional getRequiringContextAbsolute(); + std::string getRequiringContextRelative(); + + void substituteAliasIfPresent(std::string& path); + std::optional getAlias(std::string alias); + + void parseNextConfig(); + void parseConfigInDirectory(const std::string& path); +}; diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e15e5f88..51fa919ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,6 @@ option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) -option(LUAU_NATIVE "Enable support for native code generation" OFF) cmake_policy(SET CMP0054 NEW) cmake_policy(SET CMP0091 NEW) @@ -24,9 +23,12 @@ endif() project(Luau LANGUAGES CXX C) add_library(Luau.Common INTERFACE) +add_library(Luau.CLI.lib STATIC) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) +add_library(Luau.Config STATIC) add_library(Luau.Analysis STATIC) +add_library(Luau.EqSat STATIC) add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) add_library(isocline STATIC) @@ -36,12 +38,16 @@ if(LUAU_BUILD_CLI) add_executable(Luau.Analyze.CLI) add_executable(Luau.Ast.CLI) add_executable(Luau.Reduce.CLI) + add_executable(Luau.Compile.CLI) + add_executable(Luau.Bytecode.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) set_target_properties(Luau.Ast.CLI PROPERTIES OUTPUT_NAME luau-ast) set_target_properties(Luau.Reduce.CLI PROPERTIES OUTPUT_NAME luau-reduce) + set_target_properties(Luau.Compile.CLI PROPERTIES OUTPUT_NAME luau-compile) + set_target_properties(Luau.Bytecode.CLI PROPERTIES OUTPUT_NAME luau-bytecode) endif() if(LUAU_BUILD_TESTS) @@ -61,17 +67,29 @@ include(Sources.cmake) target_include_directories(Luau.Common INTERFACE Common/include) +target_compile_features(Luau.CLI.lib PUBLIC cxx_std_17) +target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common) + target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) -target_link_libraries(Luau.Ast PUBLIC Luau.Common) +target_link_libraries(Luau.Ast PUBLIC Luau.Common Luau.CLI.lib) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) target_link_libraries(Luau.Compiler PUBLIC Luau.Ast) +target_compile_features(Luau.Config PUBLIC cxx_std_17) +target_include_directories(Luau.Config PUBLIC Config/include) +target_link_libraries(Luau.Config PUBLIC Luau.Ast) + target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) -target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) +target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config) +target_link_libraries(Luau.Analysis PRIVATE Luau.Compiler Luau.VM) + +target_compile_features(Luau.EqSat PUBLIC cxx_std_17) +target_include_directories(Luau.EqSat PUBLIC EqSat/include) +target_link_libraries(Luau.EqSat PUBLIC Luau.Common) target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) @@ -94,6 +112,7 @@ if(MSVC) list(APPEND LUAU_OPTIONS "/we4388") # Also signed/unsigned mismatch else() list(APPEND LUAU_OPTIONS -Wall) # All warnings + list(APPEND LUAU_OPTIONS -Wimplicit-fallthrough) list(APPEND LUAU_OPTIONS -Wsign-compare) # This looks to be included in -Wall for GCC but not clang endif() @@ -105,6 +124,8 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") # Some gcc versions treat var in `if (type var = val)` as unused # Some gcc versions treat variables used in constexpr if blocks as unused list(APPEND LUAU_OPTIONS -Wno-unused) + # some gcc versions warn maybe uninitialized on optional members on structs + list(APPEND LUAU_OPTIONS -Wno-maybe-uninitialized) endif() # Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere @@ -129,6 +150,8 @@ endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.EqSat PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.CLI.lib PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) @@ -139,10 +162,7 @@ if(LUAU_EXTERN_C) target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1) target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\") target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") -endif() - -if(LUAU_NATIVE) - target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) + target_compile_definitions(Luau.CodeGen PUBLIC LUACODEGEN_API=extern\"C\") endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND MSVC_VERSION GREATER_EQUAL 1924) @@ -178,30 +198,38 @@ if(MSVC_IDE) target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis) endif() +# On Windows and Android threads are provided, on Linux/Mac/iOS we use pthreads +add_library(osthreads INTERFACE) +if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin|iOS") + target_link_libraries(osthreads INTERFACE "-lpthread") +endif () + if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Reduce.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Compile.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Bytecode.CLI PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) - target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline) + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.CLI.lib isocline) - if(UNIX) - find_library(LIBPTHREAD pthread) - if (LIBPTHREAD) - target_link_libraries(Luau.Repl.CLI PRIVATE pthread) - endif() - endif() + target_link_libraries(Luau.Repl.CLI PRIVATE osthreads) + target_link_libraries(Luau.Analyze.CLI PRIVATE osthreads) - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis Luau.CLI.lib) - target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis) + target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis Luau.CLI.lib) target_compile_features(Luau.Reduce.CLI PRIVATE cxx_std_17) target_include_directories(Luau.Reduce.CLI PUBLIC Reduce/include) - target_link_libraries(Luau.Reduce.CLI PRIVATE Luau.Common Luau.Ast Luau.Analysis) + target_link_libraries(Luau.Reduce.CLI PRIVATE Luau.Common Luau.Ast Luau.Analysis Luau.CLI.lib) + + target_link_libraries(Luau.Compile.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen Luau.CLI.lib) + + target_link_libraries(Luau.Bytecode.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen Luau.CLI.lib) endif() if(LUAU_BUILD_TESTS) @@ -211,18 +239,20 @@ if(LUAU_BUILD_TESTS) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) + target_compile_definitions(Luau.Conformance PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM) + if(CMAKE_SYSTEM_NAME MATCHES "Android|iOS") + set(LUAU_CONFORMANCE_SOURCE_DIR "Client/Luau/tests/conformance") + else () + file(REAL_PATH "tests/conformance" LUAU_CONFORMANCE_SOURCE_DIR) + endif () + target_compile_definitions(Luau.Conformance PRIVATE LUAU_CONFORMANCE_SOURCE_DIR="${LUAU_CONFORMANCE_SOURCE_DIR}") target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.CLI.Test PRIVATE extern CLI) - target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline) - if(UNIX) - find_library(LIBPTHREAD pthread) - if (LIBPTHREAD) - target_link_libraries(Luau.CLI.Test PRIVATE pthread) - endif() - endif() + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.CLI.lib isocline) + target_link_libraries(Luau.CLI.Test PRIVATE osthreads) endif() @@ -239,3 +269,21 @@ if(LUAU_BUILD_WEB) # the output is a single .js file with an embedded wasm blob target_link_options(Luau.Web PRIVATE -sSINGLE_FILE=1) endif() + +add_subdirectory(fuzz) + +# validate dependencies for internal libraries +foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.EqSat Luau.CodeGen Luau.VM) + if(TARGET ${LIB}) + get_target_property(DEPENDS ${LIB} LINK_LIBRARIES) + if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler") + message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components") + endif() + if(LIB MATCHES "Ast|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM") + message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components") + endif() + if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config") + message(FATAL_ERROR ${LIB} " is a compiler component but it depends on one of the analysis components") + endif() + endif() +endforeach() diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 000000000..94bb6f023 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,47 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "fuzz", + "displayName": "Fuzz", + "description": "Configures required fuzzer settings.", + "binaryDir": "build", + "condition": { + "type": "anyOf", + "conditions": [ + { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + }, + { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + } + ] + }, + "cacheVariables": { + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_STANDARD": "17", + "CMAKE_CXX_EXTENSIONS": false + }, + "warnings": { + "dev": false + } + } + ], + "buildPresets": [ + { + "name": "fuzz-proto", + "displayName": "Protobuf Fuzzer", + "description": "Builds the protobuf-based fuzzer and transpiler tools.", + "configurePreset": "fuzz", + "targets": [ + "Luau.Fuzz.Proto", + "Luau.Fuzz.ProtoTest" + ] + } + ] +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 460e4b43a..265797406 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,20 +2,20 @@ Thanks for deciding to contribute to Luau! These guidelines will try to help mak ## Questions -If you have a question regarding the language usage/implementation, please [use GitHub discussions](https://github.com/Roblox/luau/discussions). +If you have a question regarding the language usage/implementation, please [use GitHub discussions](https://github.com/luau-lang/luau/discussions). Some questions just need answers, but it's nice to keep them for future reference in case other people want to know the same thing. Some questions help improve the language, implementation or documentation by inspiring future changes. ## Documentation -This repository hosts the language documentation in addition to implementation, which is accessible on https://luau-lang.org. +A [separate site repository](https://github.com/luau-lang/site) hosts the language documentation, which is accessible on https://luau-lang.org. Changes to this documentation that improve clarity, fix grammatical issues, explain aspects that haven't been explained before and the like are warmly welcomed. Please feel free to [create a pull request](https://help.github.com/articles/about-pull-requests/) to improve our documentation. Note that at this point the documentation is English-only. ## Bugs -If the language implementation doesn't compile on your system, compiles with warnings, doesn't seem to run correctly for your code or if anything else is amiss, please [open a GitHub issue](https://github.com/Roblox/luau/issues/new). +If the language implementation doesn't compile on your system, compiles with warnings, doesn't seem to run correctly for your code or if anything else is amiss, please [open a GitHub issue](https://github.com/luau-lang/luau/issues/new). It helps if you note the Git revision issue happens in, the version of your compiler for compilation issues, and a reproduction case for runtime bugs. Of course, feel free to [create a pull request](https://help.github.com/articles/about-pull-requests/) to fix the bug yourself. @@ -25,7 +25,7 @@ Of course, feel free to [create a pull request](https://help.github.com/articles If you're thinking of adding a new feature to the language, library, analysis tools, etc., please *don't* start by submitting a pull request. Luau team has internal priorities and a roadmap that may or may not align with specific features, so before starting to work on a feature please submit an issue describing the missing feature that you'd like to add. -For features that result in observable change of language syntax or semantics, you'd need to [create an RFC](https://github.com/Roblox/luau/blob/master/rfcs/README.md) to make sure that the feature is needed and well designed. +For features that result in observable change of language syntax or semantics, you'd need to [create an RFC](https://github.com/luau-lang/rfcs/blob/master/README.md) to make sure that the feature is needed and well designed. Finally, please note that Luau tries to carry a minimal feature set. All features must be evaluated not just for the benefits that they provide, but also for the downsides/costs in terms of language simplicity, maintainability, cross-feature interaction etc. As such, feature requests may not be accepted even if a comprehensive RFC is written - don't expect Luau to gain a feature just because another programming language has it. diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 2c852046c..fbac3ec38 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -3,6 +3,8 @@ #include "Luau/RegisterA64.h" +#include + namespace Luau { namespace CodeGen @@ -12,35 +14,36 @@ namespace A64 enum class AddressKindA64 : uint8_t { - imm, // reg + imm - reg, // reg + reg - - // TODO: - // reg + reg << shift - // reg + sext(reg) << shift - // reg + uext(reg) << shift + reg, // reg + reg + imm, // reg + imm + pre, // reg + imm, reg += imm + post, // reg, reg += imm }; struct AddressA64 { - AddressA64(RegisterA64 base, int off = 0) - : kind(AddressKindA64::imm) + // This is a little misleading since AddressA64 can encode offsets up to 1023*size where size depends on the load/store size + // For example, ldr x0, [reg+imm] is limited to 8 KB offsets assuming imm is divisible by 8, but loading into w0 reduces the range to 4 KB + static constexpr size_t kMaxOffset = 1023; + + constexpr AddressA64(RegisterA64 base, int off = 0, AddressKindA64 kind = AddressKindA64::imm) + : kind(kind) , base(base) , offset(xzr) , data(off) { - LUAU_ASSERT(base.kind == KindA64::x || base == sp); - LUAU_ASSERT(off >= -256 && off < 4096); + CODEGEN_ASSERT(base.kind == KindA64::x || base == sp); + CODEGEN_ASSERT(kind != AddressKindA64::reg); } - AddressA64(RegisterA64 base, RegisterA64 offset) + constexpr AddressA64(RegisterA64 base, RegisterA64 offset) : kind(AddressKindA64::reg) , base(base) , offset(offset) , data(0) { - LUAU_ASSERT(base.kind == KindA64::x); - LUAU_ASSERT(offset.kind == KindA64::x); + CODEGEN_ASSERT(base.kind == KindA64::x); + CODEGEN_ASSERT(offset.kind == KindA64::x); } AddressKindA64 kind; diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 94d8f8114..a4d857a49 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -16,36 +16,54 @@ namespace CodeGen namespace A64 { +enum FeaturesA64 +{ + Feature_JSCVT = 1 << 0, +}; + class AssemblyBuilderA64 { public: - explicit AssemblyBuilderA64(bool logText); + explicit AssemblyBuilderA64(bool logText, unsigned int features = 0); ~AssemblyBuilderA64(); // Moves void mov(RegisterA64 dst, RegisterA64 src); - void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void mov(RegisterA64 dst, int src); // macro + + // Moves of 32-bit immediates get decomposed into one or more of these + void movz(RegisterA64 dst, uint16_t src, int shift = 0); + void movn(RegisterA64 dst, uint16_t src, int shift = 0); void movk(RegisterA64 dst, uint16_t src, int shift = 0); // Arithmetics void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); - void add(RegisterA64 dst, RegisterA64 src1, int src2); + void add(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); - void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void neg(RegisterA64 dst, RegisterA64 src); // Comparisons // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm void cmp(RegisterA64 src1, RegisterA64 src2); - void cmp(RegisterA64 src1, int src2); + void cmp(RegisterA64 src1, uint16_t src2); + void csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); + void cset(RegisterA64 dst, ConditionA64 cond); // Bitwise - // Note: shifted-register support and bitfield operations are omitted for simplicity - // TODO: support immediate arguments (they have odd encoding and forbid many values) - void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); - void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); - void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); - void mvn(RegisterA64 dst, RegisterA64 src); + void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void bic(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void tst(RegisterA64 src1, RegisterA64 src2, int shift = 0); + void mvn_(RegisterA64 dst, RegisterA64 src); + + // Bitwise with immediate + // Note: immediate must have a single contiguous sequence of 1 bits set of length 1..31 + void and_(RegisterA64 dst, RegisterA64 src1, uint32_t src2); + void orr(RegisterA64 dst, RegisterA64 src1, uint32_t src2); + void eor(RegisterA64 dst, RegisterA64 src1, uint32_t src2); + void tst(RegisterA64 src1, uint32_t src2); // Shifts void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -54,6 +72,20 @@ class AssemblyBuilderA64 void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void clz(RegisterA64 dst, RegisterA64 src); void rbit(RegisterA64 dst, RegisterA64 src); + void rev(RegisterA64 dst, RegisterA64 src); + + // Shifts with immediates + // Note: immediate value must be in [0, 31] or [0, 63] range based on register type + void lsl(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + void lsr(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + void asr(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + void ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + + // Bitfields + void ubfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void ubfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); // Load // Note: paired loads are currently omitted for simplicity @@ -63,27 +95,76 @@ class AssemblyBuilderA64 void ldrsb(RegisterA64 dst, AddressA64 src); void ldrsh(RegisterA64 dst, AddressA64 src); void ldrsw(RegisterA64 dst, AddressA64 src); + void ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src); // Store void str(RegisterA64 src, AddressA64 dst); void strb(RegisterA64 src, AddressA64 dst); void strh(RegisterA64 src, AddressA64 dst); + void stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst); // Control flow - // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks void b(Label& label); - void b(ConditionA64 cond, Label& label); - void cbz(RegisterA64 src, Label& label); - void cbnz(RegisterA64 src, Label& label); + void bl(Label& label); void br(RegisterA64 src); void blr(RegisterA64 src); void ret(); + // Conditional control flow + void b(ConditionA64 cond, Label& label); + void cbz(RegisterA64 src, Label& label); + void cbnz(RegisterA64 src, Label& label); + void tbz(RegisterA64 src, uint8_t bit, Label& label); + void tbnz(RegisterA64 src, uint8_t bit, Label& label); + // Address of embedded data void adr(RegisterA64 dst, const void* ptr, size_t size); void adr(RegisterA64 dst, uint64_t value); void adr(RegisterA64 dst, double value); + // Address of code (label) + void adr(RegisterA64 dst, Label& label); + + // Floating-point scalar/vector moves + // Note: constant must be compatible with immediate floating point moves (see isFmovSupported) + void fmov(RegisterA64 dst, RegisterA64 src); + void fmov(RegisterA64 dst, double src); + + // Floating-point scalar/vector math + void fabs(RegisterA64 dst, RegisterA64 src); + void fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fneg(RegisterA64 dst, RegisterA64 src); + void fsqrt(RegisterA64 dst, RegisterA64 src); + void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + + // Vector component manipulation + void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + void ins_4s(RegisterA64 dst, uint8_t dstIndex, RegisterA64 src, uint8_t srcIndex); + void dup_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + + // Floating-point rounding and conversions + void frinta(RegisterA64 dst, RegisterA64 src); + void frintm(RegisterA64 dst, RegisterA64 src); + void frintp(RegisterA64 dst, RegisterA64 src); + void fcvt(RegisterA64 dst, RegisterA64 src); + void fcvtzs(RegisterA64 dst, RegisterA64 src); + void fcvtzu(RegisterA64 dst, RegisterA64 src); + void scvtf(RegisterA64 dst, RegisterA64 src); + void ucvtf(RegisterA64 dst, RegisterA64 src); + + // Floating-point conversion to integer using JS rules (wrap around 2^32) and set Z flag + // note: this is part of ARM8.3 (JSCVT feature); support of this instruction needs to be checked at runtime + void fjcvtzs(RegisterA64 dst, RegisterA64 src); + + // Floating-point comparisons + void fcmp(RegisterA64 src1, RegisterA64 src2); + void fcmpz(RegisterA64 src); + void fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); + + void udf(); + // Run final checks bool finalize(); @@ -93,10 +174,19 @@ class AssemblyBuilderA64 // Assigns label position to the current location void setLabel(Label& label); + // Extracts code offset (in bytes) from label + uint32_t getLabelOffset(const Label& label) + { + CODEGEN_ASSERT(label.location != ~0u); + return label.location * 4; + } + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); uint32_t getCodeSize() const; + unsigned getInstructionCount() const; + // Resulting data and code that need to be copied over one after the other // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' std::vector data; @@ -105,26 +195,61 @@ class AssemblyBuilderA64 std::string text; const bool logText = false; + const unsigned int features = 0; + + // Maximum immediate argument to functions like add/sub/cmp + static constexpr size_t kMaxImmediate = (1 << 12) - 1; + + // Check if immediate mode mask is supported for bitwise operations (and/or/xor) + static bool isMaskSupported(uint32_t mask); + + // Check if fmov can be used to synthesize a constant + static bool isFmovSupported(double value); private: // Instruction archetypes void place0(const char* name, uint32_t word); - void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); + void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0, int N = 0); void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); - void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint16_t opsize, int sizelog); + void placeB(const char* name, Label& label, uint8_t op); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); void placeBR(const char* name, RegisterA64 src, uint32_t op); + void placeBTR(const char* name, Label& label, uint8_t op, RegisterA64 cond, uint8_t bit); void placeADR(const char* name, RegisterA64 src, uint8_t op); + void placeADR(const char* name, RegisterA64 src, uint8_t op, Label& label); + void placeP(const char* name, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src, uint8_t op, uint8_t opc, int sizelog); + void placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc, int invert = 0); + void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); + void placeFMOV(const char* name, RegisterA64 dst, double src, uint32_t op); + void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op); + void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms); + void placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift); + void placeVR(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint16_t op, uint8_t op2); void place(uint32_t word); - void patchLabel(Label& label); - void patchImm19(uint32_t location, int value); + struct Patch + { + enum Kind + { + Imm26, + Imm19, + Imm14, + }; + + Kind kind : 2; + uint32_t label : 30; + uint32_t location; + }; + + void patchLabel(Label& label, Patch::Kind kind); + void patchOffset(uint32_t location, int value, Patch::Kind kind); void commit(); LUAU_NOINLINE void extend(); @@ -138,16 +263,19 @@ class AssemblyBuilderA64 LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, double src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); - LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label, int imm = -1); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(RegisterA64 reg); LUAU_NOINLINE void log(AddressA64 addr); uint32_t nextLabel = 1; - std::vector(A) -> A` (take a look at [generics](#generics)), whereas in nonstrict we infer `(any) -> any`. We know this is true because `f` can take anything and then return that. If we used `x` with another concrete type, then we would end up inferring that. - -Similarly, we can infer the types of the parameters with ease. By passing a parameter into *anything* that also has a type, we are saying "this and that has the same type." - -```lua -local function greetingsHelper(name: string) - return "Hello, " .. name -end - -local function greetings(name) - return greetingsHelper(name) -end - -print(greetings("Alexander")) -- ok -print(greetings({name = "Alexander"})) -- not ok -``` - -## Table types - -From the type checker perspective, each table can be in one of three states. They are: `unsealed table`, `sealed table`, and `generic table`. This is intended to represent how the table's type is allowed to change. - -### Unsealed tables - -An unsealed table is a table which supports adding new properties, which updates the tables type. Unsealed tables are created using table literals. This is one way to accumulate knowledge of the shape of this table. - -```lua -local t = {x = 1} -- {x: number} -t.y = 2 -- {x: number, y: number} -t.z = 3 -- {x: number, y: number, z: number} -``` - -However, if this local were written as `local t: { x: number } = { x = 1 }`, it ends up sealing the table, so the two assignments henceforth will not be ok. - -Furthermore, once we exit the scope where this unsealed table was created in, we seal it. - -```lua -local function vec2(x, y) - local t = {} - t.x = x - t.y = y - return t -end - -local v2 = vec2(1, 2) -v2.z = 3 -- not ok -``` - -Unsealed tables are *exact* in that any property of the table must be named by the type. Since Luau treats missing properties as having value `nil`, this means that we can treat an unsealed table which does not mention a property as if it mentioned the property, as long as that property is optional. - -```lua -local t = {x = 1} -local u : { x : number, y : number? } = t -- ok because y is optional -local v : { x : number, z : number } = t -- not ok because z is not optional -``` - -### Sealed tables - -A sealed table is a table that is now locked down. This occurs when the table type is spelled out explicitly via a type annotation, or if it is returned from a function. - -```lua -local t : { x: number } = {x = 1} -t.y = 2 -- not ok -``` - -Sealed tables are *inexact* in that the table may have properties which are not mentioned in the type. -As a result, sealed tables support *width subtyping*, which allows a table with more properties to be used as a table with fewer - -```lua -type Point1D = { x : number } -type Point2D = { x : number, y : number } -local p : Point2D = { x = 5, y = 37 } -local q : Point1D = p -- ok because Point2D has more properties than Point1D -``` - -### Generic tables - -This typically occurs when the symbol does not have any annotated types or were not inferred anything concrete. In this case, when you index on a parameter, you're requesting that there is a table with a matching interface. - -```lua -local function f(t) - return t.x + t.y - --^ --^ {x: _, y: _} -end - -f({x = 1, y = 2}) -- ok -f({x = 1, y = 2, z = 3}) -- ok -f({x = 1}) -- not ok -``` - -## Table indexers - -These are particularly useful for when your table is used similarly to an array. - -```lua -local t = {"Hello", "world!"} -- {[number]: string} -print(table.concat(t, ", ")) -``` - -Luau supports a concise declaration for array-like tables, `{T}` (for example, `{string}` is equivalent to `{[number]: string}`); the more explicit definition of an indexer is still useful when the key isn't a number, or when the table has other fields like `{ [number]: string, n: number }`. - -## Generics - -The type inference engine was built from the ground up to recognize generics. A generic is simply a type parameter in which another type could be slotted in. It's extremely useful because it allows the type inference engine to remember what the type actually is, unlike `any`. - -```lua -type Pair = {first: T, second: T} - -local strings: Pair = {first="Hello", second="World"} -local numbers: Pair = {first=1, second=2} -``` - -## Generic functions - -As well as generic type aliases like `Pair`, Luau supports generic functions. These are functions that, as well as their regular data parameters, take type parameters. For example, a function which reverses an array is: -```lua -function reverse(a) - local result = {} - for i = #a, 1, -1 do - table.insert(result, a[i]) - end - return result -end -``` -The type of this function is that it can reverse an array, and return an array of the same type. Luau can infer this type, but if you want to be explicit, you can declare the type parameter `T`, for example: -```lua -function reverse(a: {T}): {T} - local result: {T} = {} - for i = #a, 1, -1 do - table.insert(result, a[i]) - end - return result -end -``` -When a generic function is called, Luau infers type arguments, for example -```lua -local x: {number} = reverse({1, 2, 3}) -local y: {string} = reverse({"a", "b", "c"}) -``` -Generic types are used for built-in functions as well as user functions, -for example the type of two-argument `table.insert` is: -```lua -({T}, T) -> () -``` - -## Union types - -A union type represents *one of* the types in this set. If you try to pass a union onto another thing that expects a *more specific* type, it will fail. - -For example, what if this `string | number` was passed into something that expects `number`, but the passed in value was actually a `string`? - -```lua -local stringOrNumber: string | number = "foo" - -local onlyString: string = stringOrNumber -- not ok -local onlyNumber: number = stringOrNumber -- not ok -``` - -Note: it's impossible to be able to call a function if there are two or more function types in this union. - -## Intersection types - -An intersection type represents *all of* the types in this set. It's useful for two main things: to join multiple tables together, or to specify overloadable functions. - -```lua -type XCoord = {x: number} -type YCoord = {y: number} -type ZCoord = {z: number} - -type Vector2 = XCoord & YCoord -type Vector3 = XCoord & YCoord & ZCoord - -local vec2: Vector2 = {x = 1, y = 2} -- ok -local vec3: Vector3 = {x = 1, y = 2, z = 3} -- ok -``` - -```lua -type SimpleOverloadedFunction = ((string) -> number) & ((number) -> string) - -local f: SimpleOverloadedFunction - -local r1: number = f("foo") -- ok -local r2: number = f(12345) -- not ok -local r3: string = f("foo") -- not ok -local r4: string = f(12345) -- ok -``` - -Note: it's impossible to create an intersection type of some primitive types, e.g. `string & number`, or `string & boolean`, or other variations thereof. - -Note: Luau still does not support user-defined overloaded functions. Some of Roblox and Lua 5.1 functions have different function signature, so inherently requires overloaded functions. - -## Singleton types (aka literal types) - -Luau's type system also supports singleton types, which means it's a type that represents one single value at runtime. At this time, both string and booleans are representable in types. - -> We do not currently support numbers as types. For now, this is intentional. - -```lua -local foo: "Foo" = "Foo" -- ok -local bar: "Bar" = foo -- not ok -local baz: string = foo -- ok - -local t: true = true -- ok -local f: false = false -- ok -``` - -This happens all the time, especially through [type refinements](#type-refinements) and is also incredibly useful when you want to enforce program invariants in the type system! See [tagged unions](#tagged-unions) for more information. - -## Variadic types - -Luau permits assigning a type to the `...` variadic symbol like any other parameter: - -```lua -local function f(...: number) -end - -f(1, 2, 3) -- ok -f(1, "string") -- not ok -``` - -`f` accepts any number of `number` values. - -In type annotations, this is written as `...T`: - -```lua -type F = (...number) -> ...string -``` - -## Type packs - -Multiple function return values as well as the function variadic parameter use a type pack to represent a list of types. - -When a type alias is defined, generic type pack parameters can be used after the type parameters: - -```lua -type Signal = { f: (T, U...) -> (), data: T } -``` - -> Keep in mind that `...T` is a variadic type pack (many elements of the same type `T`), while `U...` is a generic type pack that can contain zero or more types and they don't have to be the same. - -It is also possible for a generic function to reference a generic type pack from the generics list: - -```lua -local function call(s: Signal, ...: U...) - s.f(s.data, ...) -end -``` - -Generic types with type packs can be instantiated by providing a type pack: - -```lua -local signal: Signal = -- - -call(signal, 1, 2, false) -``` - -There are also other ways to instantiate types with generic type pack parameters: - -```lua -type A = (T) -> U... - -type B = A -- with a variadic type pack -type C = A -- with a generic type pack -type D = A -- with an empty type pack -``` - -Trailing type pack argument can also be provided without parentheses by specifying variadic type arguments: - -```lua -type List = (Head, Rest...) -> () - -type B = List -- Rest... is () -type C = List -- Rest is (string, boolean) - -type Returns = () -> T... - --- When there are no type parameters, the list can be left empty -type D = Returns<> -- T... is () -``` - -Type pack parameters are not limited to a single one, as many as required can be specified: - -```lua -type Callback = { f: (Args...) -> Rets... } - -type A = Callback<(number, string), ...number> -``` - -## Typing idiomatic OOP - -One common pattern we see throughout Roblox is this OOP idiom. A downside with this pattern is that it does not automatically create a type binding for an instance of that class, so one has to write `type Account = typeof(Account.new("", 0))`. - -```lua -local Account = {} -Account.__index = Account - -function Account.new(name, balance) - local self = {} - self.name = name - self.balance = balance - - return setmetatable(self, Account) -end - -function Account:deposit(credit) - self.balance += credit -end - -function Account:withdraw(debit) - self.balance -= debit -end - -local account: Account = Account.new("Alexander", 500) - --^^^^^^^ not ok, 'Account' does not exist -``` - -## Tagged unions - -Tagged unions are just union types! In particular, they're union types of tables where they have at least _some_ common properties but the structure of the tables are different enough. Here's one example: - -```lua -type Ok = { type: "ok", value: T } -type Err = { type: "err", error: E } -type Result = Ok | Err -``` - -This `Result` type can be discriminated by using type refinements on the property `type`, like so: - -```lua -if result.type == "ok" then - -- result is known to be Ok - -- and attempting to index for error here will fail - print(result.value) -elseif result.type == "err" then - -- result is known to be Err - -- and attempting to index for value here will fail - print(result.error) -end -``` - -Which works out because `value: T` exists only when `type` is in actual fact `"ok"`, and `error: E` exists only when `type` is in actual fact `"err"`. - -## Type refinements - -When we check the type of any lvalue (a global, a local, or a property), what we're doing is we're refining the type, hence "type refinement." The support for this is arbitrarily complex, so go crazy! - -Here are all the ways you can refine: -1. Truthy test: `if x then` will refine `x` to be truthy. -2. Type guards: `if type(x) == "number" then` will refine `x` to be `number`. -3. Equality: `x == "hello"` will refine `x` to be a singleton type `"hello"`. - -And they can be composed with many of `and`/`or`/`not`. `not`, just like `~=`, will flip the resulting refinements, that is `not x` will refine `x` to be falsy. - -Using truthy test: -```lua -local maybeString: string? = nil - -if maybeString then - local onlyString: string = maybeString -- ok - local onlyNil: nil = maybeString -- not ok -end - -if not maybeString then - local onlyString: string = maybeString -- not ok - local onlyNil: nil = maybeString -- ok -end -``` - -Using `type` test: -```lua -local stringOrNumber: string | number = "foo" - -if type(stringOrNumber) == "string" then - local onlyString: string = stringOrNumber -- ok - local onlyNumber: number = stringOrNumber -- not ok -end - -if type(stringOrNumber) ~= "string" then - local onlyString: string = stringOrNumber -- not ok - local onlyNumber: number = stringOrNumber -- ok -end -``` - -Using equality test: -```lua -local myString: string = f() - -if myString == "hello" then - local hello: "hello" = myString -- ok because it is absolutely "hello"! - local copy: string = myString -- ok -end -``` - -And as said earlier, we can compose as many of `and`/`or`/`not` as we wish with these refinements: -```lua -local function f(x: any, y: any) - if (x == "hello" or x == "bye") and type(y) == "string" then - -- x is of type "hello" | "bye" - -- y is of type string - end - - if not (x ~= "hi") then - -- x is of type "hi" - end -end -``` - -`assert` can also be used to refine in all the same ways: -```lua -local stringOrNumber: string | number = "foo" - -assert(type(stringOrNumber) == "string") - -local onlyString: string = stringOrNumber -- ok -local onlyNumber: number = stringOrNumber -- not ok -``` - -## Type casts - -Expressions may be typecast using `::`. Typecasting is useful for specifying the type of an expression when the automatically inferred type is too generic. - -For example, consider the following table constructor where the intent is to store a table of names: -```lua -local myTable = {names = {}} -table.insert(myTable.names, 42) -- Inserting a number ought to cause a type error, but doesn't -``` - -In order to specify the type of the `names` table a typecast may be used: - -```lua -local myTable = {names = {} :: {string}} -table.insert(myTable.names, 42) -- not ok, invalid 'number' to 'string' conversion -``` - -A typecast itself is also type checked to ensure the conversion is made to a subtype of the expression's type or `any`: -```lua -local numericValue = 1 -local value = numericValue :: any -- ok, all expressions may be cast to 'any' -local flag = numericValue :: boolean -- not ok, invalid 'number' to 'boolean' conversion -``` - -## Roblox types - -Roblox supports a rich set of classes and data types, [documented here](https://developer.roblox.com/en-us/api-reference). All of them are readily available for the type checker to use by their name (e.g. `Part` or `RaycastResult`). - -When one type inherits from another type, the type checker models this relationship and allows to cast a subclass to the parent class implicitly, so you can pass a `Part` to a function that expects an `Instance`. - -All enums are also available to use by their name as part of the `Enum` type library, e.g. `local m: Enum.Material = part.Material`. - -Finally, we can automatically deduce what calls like `Instance.new` and `game:GetService` are supposed to return: - -```lua -local part = Instance.new("Part") -local basePart: BasePart = part -``` - -Note that many of these types provide some properties and methods in both lowerCase and UpperCase; the lowerCase variants are deprecated, and the type system will ask you to use the UpperCase variants instead. - -## Module interactions - -Let's say that we have two modules, `Foo` and `Bar`. Luau will try to resolve the paths if it can find any `require` in any scripts. In this case, when you say `script.Parent.Bar`, Luau will resolve it as: relative to this script, go to my parent and get that script named Bar. - -```lua --- Module Foo -local Bar = require(script.Parent.Bar) - -local baz1: Bar.Baz = 1 -- not ok -local baz2: Bar.Baz = "foo" -- ok - -print(Bar.Quux) -- ok -print(Bar.FakeProperty) -- not ok - -Bar.NewProperty = true -- not ok -``` - -```lua --- Module Bar -export type Baz = string - -local module = {} - -module.Quux = "Hello, world!" - -return module -``` - -There are some caveats here though. For instance, the require path must be resolvable statically, otherwise Luau cannot accurately type check it. - -### Cyclic module dependencies - -Cyclic module dependencies can cause problems for the type checker. In order to break a module dependency cycle a typecast of the module to `any` may be used: -```lua -local myModule = require(MyModule) :: any -``` diff --git a/docs/_pages/why.md b/docs/_pages/why.md deleted file mode 100644 index 4b097a56b..000000000 --- a/docs/_pages/why.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -permalink: /why -title: Why Luau? ---- - -Around 2006, [Roblox](https://www.roblox.com) started using Lua 5.1 as a scripting language for games. Over the years the runtime had to be tweaked to provide a safe, secure sandboxed environment; we gradually started accumulating small library changes and tweaks. - -Over the course of the last few years, instead of using Web-based stack for our player-facing application, Lua-based in-game UI and Qt-based editor UI, we've started consolidating a lot of the efforts and developing all of these using Roblox engine and Lua as a scripting language. - -Having grown a substantial internal codebase that needed to be correct and performant, and with the focus shifting a bit from novice game developers to professional studios building games on Roblox and our own teams of engineers building applications, there was a need to improve performance and quality of the code we were writing. - -Unlike mainline Lua, we also could not afford to do major breaking changes to the language (hence the 5.1 language baseline that remained unchanged for more than a decade). While faster implementations of Lua 5.1 like LuaJIT were available, they didn't meet our needs in terms of portability, ease of change and they didn't address the problem of developing robust code at scale. - -All of these motivated us to start reshaping Lua 5.1 that we started from into a new, derivative language that we call Luau. Our focus is on making the language more performant and feature-rich, and make it easier to write robust code through a combination of linting and type checking using a gradual type system. - -## Complete rewrite? - -A very large part of Luau codebase is written from scratch. We needed a set of tools to be able to write language analysis tools; Lua has a parser that is integrated with the bytecode compiler, which makes it unsuitable for complex semantic analysis. For bytecode compilation, while a single pass compiler can deliver better compilation throughput and be simpler than a full frontend/backend, it significantly limits the optimizations that can be done at the bytecode level. - -Luau compiler and analysis tools are thus written from scratch, closely following the syntax and semantics of Lua. Our compiler is not single-pass, and instead relies on a set of analysis passes that run over the AST to produce efficient bytecode, followed by some post-process optimizations. - -As for the runtime, we had to rewrite the interpreter from scratch to get substantially faster performance; using a combination of techniques pioneered by LuaJIT and custom optimizations that are able to improve performance by taking control over the entire stack (language, compiler, interpreter, virtual machine), we're able to get close to LuaJIT interpreter performance while using C as an implementation language. - -The garbage collector and the core libraries represent more of an incremental change, where we used Lua 5.1 as a baseline but we're continuing to rewrite these as well with performance in mind. - -While Luau doesn't currently implement JIT/AOT, this is likely to happen at some point; beyond the usual implementation challenges and security concerns, one significant limitation is that we don't have access to JIT on many platforms so for us maintaining excellent interpreted performance for gameplay and application code is more important than reaching peak FLOPS on numerical code. diff --git a/docs/_posts/2019-11-11-luau-recap-november-2019.md b/docs/_posts/2019-11-11-luau-recap-november-2019.md deleted file mode 100644 index 253d59c55..000000000 --- a/docs/_posts/2019-11-11-luau-recap-november-2019.md +++ /dev/null @@ -1,108 +0,0 @@ ---- -layout: single -title: "Luau Recap: November 2019" ---- - -A few months ago, we’ve released our new Lua implementation, Luau ([Faster Lua VM Released](https://devforum.roblox.com/t/faster-lua-vm-released/339587)) and made it the default for most platforms and configurations. Since then we’ve shipped many smaller changes that improved performance and expanded the usability of the VM. Many of them have been noted in release notes but some haven’t, so here’s a recap of everything that has happened in the Lua land since September! - -[Originally posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-november-2019/).] - -## Debugger beta - -When we launched the new VM, we did it without the full debugger support. The reason for this is that the new VM is substantially different and the old implementation of the debugger (that relied on line hooks) just doesn’t work. - -We had to rebuild the low level implementation of the debugger from scratch - this is a tricky problem and it took time! We are excited to share a beta preview of this with you today. - -To use this, simply make sure that you’re enrolled in the new Lua VM beta: - -![Enable New Lua VM]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-november-2019-option.png) - -After this you can use the debugger as usual. If you see any bugs, please feel free to report them! - -## Performance improvements - - * The for loop optimization that specializes `pairs/ipairs` now works for localized versions of these globals as well, as well as `next, table` expressions - * a^k expressions are now faster for some trivial values of k such as 2 and 0.5 - * Calling methods and accessing properties on deeply nested Roblox objects is now significantly faster than it used to be (~2x faster for objects that have an 8-deep nesting) - the cost is now independent of the hierarchy depth. - * Accessing .X/.Y/.Z properties on Vector types is now ~30% faster - * On Windows and Xbox, we’ve tuned our interpreter to be ~5-6% faster on Lua-intensive code - * For a set of builtin functions, we now support very quickly calling them from VM via a new fastcall mechanism. - -Fastcall requires the function call to be present in source as a global or localized global access (e.g. either `math.max(x, 1)` or `max(x, 1) where local max = math.max`). This can be substantially faster than normal calls, e.g. this makes SHA256 benchmark ~1.7x faster. We are currently optimizing calls to `bit32`, `math` libraries and additionally `assert` and `type`. Also, just like other global-based optimizations, this one is disabled if you use `getfenv`/`setfenv`. - -## Lua library extensions - -We’ve implemented most library features available in later versions of upstream Lua, including: - - * `table.pack` and `table.unpack` from Lua 5.2 (the latter is same as global `unpack`, the former helps by storing the true argument count in `.n` field) - * `table.move` from Lua 5.3 (useful for copying data between arrays) - * `coroutine.isyieldable` from Lua 5.3 - * `math.log` now accepts a second optional argument (as seen in Lua 5.2) for the logarithm base - -We’ve also introduced two new functions in the table library: - - * `table.create(count, value)` can create an array-like table quickly - * `table.find(table, value [, init])` can quickly find the numeric index of the element in the table - -Autocomplete support for `table.create`/`table.find` will ship next week - -## Lua syntax extensions - -We’ve started taking a look at improving the Lua syntax. To that end, we’ve incorporated a few changes from later versions of Lua into the literal syntax: - - * String literals now support `\z` (skip whitespace), `\x` (hexadecimal byte) and `\u` (Unicode codepoint) escape sequences - -and implemented a few extra changes: - - * Number literals now support binary literals, e.g. `0b010101` - * Number literals now support underscores anywhere in the literal for easier digit grouping, e.g. `1_000_000` - -Note that the literal extensions aren’t currently supported in syntax highlighter in Studio but this will be amended soon. - -## Error messages - -Error messages are slowly getting a bit of love. We’ve improved some runtime errors to be nicer, in particular: - - * When indexing operation fails, we now specify the key name or type, e.g. “attempt to index foo with ‘Health’†- * When arithmetic operations fails, we now specify the type of arithmetic operation, e.g. “attempt to perform arithmetic (add) on table and number†- -We’ve also improved some parse errors to look nicer by providing extra context - for example, if you forget parentheses after function name in a function declaration, we will now say `Expected '(' when parsing function, got 'local'`. - -We are looking into some reports of misplaced line numbers on errors in multi-line expressions but this will only ship later. - -## Correctness fixes - -There are always a few corner cases that we miss - a new Lua implementation is by necessity subtly different in a few places. Our goal is to find and correct as many of these issues as possible. In particular, we’ve: - - * Fixed some cases where we wouldn’t preserve negative zero (`-0`) - * Fixed cases where `getfenv(0)` wouldn’t properly deoptimize access to builtin globals - * Fixed cases where calling a function with >255 parameters would overflow the stack - * Fixed errors with very very very long scripts and control flow around large blocks (thousands of lines of code in a single if/for statement) - * Fixed cases where in Studio on Windows, constant-time comparisons with `NaNs` didn’t behave properly (`0/0==1`) - -Also, the upvalue limit in the new VM has been raised to 200 from 60; the limit in Lua 5.2 is 255 but we decided for now to match the local limit. - -## Script analysis - -Along with the compiler and virtual machine, we’ve implemented a new linting framework on top of Luau which is similar to our old script analysis code but is richer. In particular, we support a few more checks that are enabled by default: - - * Unreachable code warning, for cases where function provably doesn’t reach a specific point, such as redundant return after a set of if/else statements where every branch returns or errors. - * Unknown type warning, which was emitted before for `Instance.new/GetService/IsA` calls, is now also emitted when the result of `type/typeof` is compared to a string literal - * We now recognize and flag mistaken attempts to iterate downwards with a for loop (such as `for i=9,1` or `for i=#t,1` as well as cases where numeric for loop doesn’t reach the stated target (`for i=1,4.5`) - * We now detect and flag cases where in assignment expressions variables are implicitly initialized with nil or values are dropped during assignment - * “Statement spans multiple lines†warning now does not trigger on idiomatic constructs involving setting up locals in a block (`local name do ... name = value ... end`) - -We also have implemented a few more warnings for common style/correctness issues but they aren’t enabled yet - we’re looking into ways for us to enable them without too much user impact: - - * Local variables that shadow other local variables / global variables - * Local variables that are assigned but never used - * Implicit returns, where functions that explicitly return values in some codepaths can reach the end of the function and implicitly return no values (which is error-prone) - -## Future plans - -There’s a fair bit of performance improvements that we haven’t gotten to yet that are on the roadmap - this includes general VM optimizations (faster function calls, faster conditional checks, faster error handling including `pcall`) and some library optimizations (in particular, Vector3 math performance improvements). And we’re starting to look into some exciting ways for us to make performance even better in the future. - -Also we’re still working on the type system! It’s starting to take shape and we should have something ready for you by the end of the year, but you’ll learn about it in a separate post :smiley: - -As always don’t hesitate to reach out if you have any issues or have any suggestions for improvements. - diff --git a/docs/_posts/2020-01-16-luau-type-checking-beta.md b/docs/_posts/2020-01-16-luau-type-checking-beta.md deleted file mode 100644 index eb6a8b791..000000000 --- a/docs/_posts/2020-01-16-luau-type-checking-beta.md +++ /dev/null @@ -1,160 +0,0 @@ ---- -layout: single -title: "Luau Type Checking Beta" ---- - -Hello! - -We’ve been quietly working on building a type checker for Lua for quite some time now. It is now far enough along that we’d really like to hear what you think about it. - -I am very happy to offer a beta test into the second half of the Luau effort. - -[Originally posted on the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-type-checking-beta/).] - -## Beta Test - -First, a word of caution: In this test, we are changing the syntax of Lua. We are pretty sure that we’ve mostly gotten things right, but part of the reason we’re calling this a beta is that, if we learn that we’ve made a mistake, we’re going to go back and fix it even if it breaks compatibility. - -Please try it out and tell us what you think, but be aware that this is not necessarily our final form. 🙂 - -Beta testers can try it out by enabling the “Enable New Lua Script Analysis†beta feature in Roblox Studio. - -## Overview - -Luau is an ahead-of-time typechecking system that sits atop ordinary Lua code. It does not (yet) feed into the runtime system; it behaves like a super powerful lint tool to help you find bugs in your code quickly. - -It is also what we call a gradual type system. This means that you can choose to add type annotations in some parts of your code but not others. - -## Two Modes - -Luau runs in one of two modes: strict, and nonstrict. - -### Nonstrict Mode - -Nonstrict mode is intended to be as helpful as possible for programs that are written without type annotations. We want to report whatever we can without reporting an error in reasonable Lua code. - - * If a local variable does not have a type annotation and it is not initially assigned a table, its type is any - * Unannotated function parameters have type any - * We do not check the number of values returned by a function - * Passing too few or too many arguments to a function is ok - -### Strict Mode - -Strict mode is expected to be more useful for more complex programs, but as a side effect, programs may need a bit of adjustment to pass without any errors. - - * The types of local variables, function parameters, and return types are deduced from how they are used - * Errors are produced if a function returns an inconsistent number of parameters, or if it is passed the wrong number of arguments - -Strict mode is not enabled by default. To turn it on, you need to add a special comment to the top of your source file. -``` ---!strict -``` - -## New syntax - -You can write type annotations in 5 places: - - * After a local variable - * After a function parameter - * After a function declaration (to declare the function’s return type) - * In a type alias, and - * After an expression using the new as keyword. - -``` -local foo: number = 55 - -function is_empty(param: string) => boolean - return 0 == param:len() -end - -type Point = {x: number, y: number} - -local baz = quux as number -``` - -## Type syntax -### Primitive types - -`nil`, `number`, `string`, and `boolean` - -### any -The special type any signifies that Luau shouldn’t try to track the type at all. You can do anything with an any. - -### Tables -Table types are surrounded by curly braces. Within the braces, you write a list of name: type pairs: -``` -type Point = {x: number, y: number} -``` -Table types can also have indexers. This is how you describe a table that is used like a hash table or an array. -``` -type StringArray = {[number]: string} - -type StringNumberMap = {[string]: number} -``` - -### Functions - -Function types use a `=>` to separate the argument types from the return types. -``` -type Callback = (string) => number -``` -If a function returns more than one value, put parens around them all. -``` -type MyFunction = (string) => (boolean, number) -``` - -### Unions - -You can use a `|` symbol to indicate an “or†combination between two types. Use this when a value can have different types as the program runs. -``` -function ordinals(limit) - local i = 0 - return function() => number | nil - if i < limit then - local t = i - i = i + 1 - return t - else - return nil - end - end -end -``` - -### Options - -It’s pretty commonplace to have optional data, so there is extra syntax for describing a union between a type and `nil`. Just put a `?` on the end. Function arguments that can be `nil` are understood to be optional. -``` -function foo(x: number, y: string?) end - -foo(5, 'five') -- ok -foo(5) -- ok -foo(5, 4) -- not ok -``` - -### Type Inference - -If you don’t write a type annotation, Luau will try to figure out what it is. -``` ---!strict -local Counter = {count=0} - -function Counter:incr() - self.count = 1 - return self.count -end - -print(Counter:incr()) -- ok -print(Counter.incr()) -- Error! -print(Counter.amount) -- Error! -``` - -## Future Plans - -This is just the first step! - -We’re excited about a whole bunch of stuff: - - * Nonstrict mode is way more permissive than we’d like - * Generics! - * Editor integration \ No newline at end of file diff --git a/docs/_posts/2020-02-25-luau-recap-february-2020.md b/docs/_posts/2020-02-25-luau-recap-february-2020.md deleted file mode 100644 index fc8b6e6df..000000000 --- a/docs/_posts/2020-02-25-luau-recap-february-2020.md +++ /dev/null @@ -1,93 +0,0 @@ ---- -layout: single -title: "Luau Recap: February 2020" ---- - -We continue to iterate on our language stack, working on many features for type checking, performance, and quality of life. Some of them come with announcements, some come with release notes, and some just ship - here we will talk about all things that happened since November last year. - -A lot of people work on these improvements; thanks @Apakovtac, @EthicalRobot, @fun_enthusiast, @xyzzyismagic, @zeuxcg! - -[Originally posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-february-2020/).] - -We were originally intending to ship the beta last year but had to delay it due to last minute bugs. However, it’s now live as a beta option on production! Go here to learn more: - -EDIT: Please DO NOT publish places with type annotations just yet as they will not work on production! This is why it’s a beta 🙂 However, please continue to experiment in Studio and give us feedback. We are reading everything and will be fixing reported bugs and discussing syntax / semantics issues some people brought up. Hello! We’ve been quietly working on building a type checker for Lua for quite some time now. It is now far enough along that we’d really like to hear what… - -We’re continuing to iterate on the feedback we have received here. Something that will happen next is that we will enable type annotations on live server/clients - which will mean that you will be able to publish source code with type annotations without breaking your games. We still have work to do on the non-strict and strict mode type checking before the feature can move out of beta though, in particular we’ve implemented support for require statement and that should ship next week 🤞 - -We also fixed a few bugs in the type definitions for built-in functions/API and the type checker itself: - - * `table.concat` was accidentally treating the arguments as required - * `string.byte` and `string.find` now have a correct precise type - * `typeof` comparisons in if condition incorrectly propagated the inferred type into `elseif` branches - -We are also making the type checker more ergonomic and more correct. Two changes I want to call out are: - - * Type aliases declared with `type X = Y` are now co-recursive, meaning that they can refer to each other, e.g. - -``` -type array = { [number]: T } - -type Wheel = { radius: number, car: Car } -type Car = { wheels: array } -``` - -* We now support type intersections `(A & B)` in addition to type unions `(A | B)`. Intersections are critical to modeling overloaded functions correctly - while Lua as a language doesn’t support function overloads, we have various APIs that have complex overloaded semantics - one of them that people who tried the beta had problems with was UDim2.new - and it turns out that to correctly specify the type, we had to add support for intersections. This isn’t really intended as a feature that is used often in scripts developers write, but it’s important for internal use. - -## Debugger (beta) - -When we shipped the original version of the VM last year, we didn’t have the debugger fully working. Debugger relies on low-level implementation of the old VM that we decided to remove from the new VM - as such we had to make a new low-level debugging engine. - -This is now live under the Luau VM beta feature, see [this post](https://devforum.roblox.com/t/luau-in-studio-beta/456529) for details. - -If you use the debugger at all, please enable the beta feature and try it out - we want to fix all the bugs we can find, and this is blocking us enabling the new VM everywhere. - -(a quick aside: today the new VM is enabled on all servers and all clients, and it’s enabled in Studio “edit†mode for plugins - but not in Studio play modes, and full debugger support is what prevents us from doing so) - -## Language - -This section is short and sweet this time around: - -* You can now use continue statement in for/while/repeat loops. :tada: - -Please note that we only support this in the new VM, so you have to be enrolled in Luau VM beta to be able to use it in Studio. It will work in game regardless of the beta setting as per above. - -## Performance - -While we have some really forward looking ideas around multi-threading and native code compilation that we’re starting to explore, we also continue to improve performance across the board based on our existing performance backlog and your feedback. - -In particular, there are several memory and performance optimizations that shipped in the last few months: - - * Checking for truth (`if foo or foo and bar`) is now a bit faster, giving 2-3% performance improvements on some benchmarks - * `table.create` (with value argument) and `table.pack` have been reimplemented and are ~1.5x faster than before - * Internal mechanism for filling arrays has been made faster as well, which makes `Terrain:ReadVoxels` ~10% faster - * Catching engine-generated errors with pcall/xpcall is now ~1.5x faster (this only affects performance of calls that generated errors) - * String objects now take 8 bytes less memory per object (and in an upcoming change we’ll save a further 4 bytes) - * Capturing local variables that are never assigned to in closures is now much faster, takes much less memory and generates much less GC pressure. This can make closure creation up to 2x faster, and improves some Roact benchmarks by 10%. This is live in Studio and will ship everywhere else shortly. - * The performance of various for loops (numeric & ipairs) on Windows regressed after a VS2017 upgrade; this regression has been fixed, making all types of loops perform roughly equally. VS2017 upgrade also improved Luau performance on Windows by ~10% across the board. - * Lua function calls have been optimized a bit more, gaining an extra 10% of performance in call-heavy benchmarks on Windows. - * Variadic table constructors weren’t compiled very efficiently, resulting in surprisingly low performance of constructs like `{...}`. Fixing that made `{...}` ~3x faster for a typical number of variadic arguments. - -## Diagnostics - -We spent some time to improve error messages in various layers of the stack based on the reports from community. Specifically: - - * The static analysis warning about incorrect bounds for numeric for loops is now putting squigglies in the right place. - * Fixed false positive static analysis warnings about unreachable code inside repeat…until loops in certain cases. - * Multiline table construction expressions have a more precise line information now which helps in debugging since callstacks are now easier to understand - * Incomplete statements (e.g. foo) now produce a more easily understandable parsing error - * In some cases when calling the method with a `.` instead of `:`, we emitted a confusing error message at runtime (e.g. humanoid.LoadAnimation(animation)). We now properly emit the error message asking the user if `:` was intended. - * The legacy global `ypcall` is now flagged as deprecated by script analysis - * If you use a Unicode symbol in your source program outside of comments or string literals, we now produce a much more clear message, for example: -``` -local pi = 3․13 -- spoiler alert: this is not a dot! -``` -produces `Unexpected Unicode character: U+2024. Did you mean '.'?` - -## LoadLibrary removal - -Last but not least, let’s all press [F for LoadLibrary](https://devforum.roblox.com/t/loadlibrary-is-going-to-be-removed-on-february-3rd/382516). - -It was fun while it lasted, but supporting it caused us a lot of pain over the years and prevented some forward-looking changes to the VM. We don’t like removing APIs from the platform, but in this case it was necessary. Thanks to the passionate feedback from the community we adjusted our initial rollout plans to be less aggressive and batch-processed a lot of gear items that used this function to stop using this function. The update is in effect and LoadLibrary is no more. - -As usual, if you have any feedback about any of these updates, suggestions, bug reports, etc., post them in this thread or (preferably for bugs) as separate posts in the bug report category. diff --git a/docs/_posts/2020-05-18-luau-recap-may-2020.md b/docs/_posts/2020-05-18-luau-recap-may-2020.md deleted file mode 100644 index a2b4cdbca..000000000 --- a/docs/_posts/2020-05-18-luau-recap-may-2020.md +++ /dev/null @@ -1,100 +0,0 @@ ---- -layout: single -title: "Luau Recap: May 2020" ---- - -Luau (lowercase u, “l-wowâ€) is an umbrella initiative to improve our language stack - the syntax, compiler, virtual machine, builtin Lua libraries, type checker, linter (known as Script Analysis in Studio), and more related components. We continuously develop the language and runtime to improve performance, robustness and quality of life. Here we will talk about all things that happened since the update in March! - -[Originally posted on the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-may-2020/).] - -## New function type annotation syntax - -As noted in the previous update, the function type annotation syntax now uses `:` on function definitions and `->` on standalone function types: -``` -type FooFunction = (number, number) -> number - -function foo(a: number, b: number): number - return a + b -end -``` -This was done to make our syntax more consistent with other modern languages, and is easier to read in type context compared to our old `=>`. - -This change is now live; the old syntax is still accepted but it will start producing warnings at some point and will be removed eventually. - -## Number of locals in each function is now limited to 200 -As detailed in [Upcoming change to (correctly) limit the local count to 200](https://devforum.roblox.com/t/upcoming-change-to-correctly-limit-the-local-count-to-200/528417) (which is now live), when we first shipped Luau we accidentally set the local limit to 255 instead of 200. This resulted in confusing error messages and code that was using close to 250 locals was very fragile as it could easily break due to minor codegen changes in our compiler. - -This was fixed, and now we’re correctly applying limits of 200 locals, 200 upvalues and 255 registers (per function) - and emit proper error messages pointing to the right place in the code when either limit is exceeded. - -This is technically a breaking change but scripts with >200 locals didn’t work in our old VM and we felt like we had to make this change to ensure long-term stability. - -## Require handling improvements in type checker + export type - -We’re continuing to flesh out the type checker support for modules. As part of this, we overhauled the require path tracing - type checker is now much better at correctly recognizing (statically) which module you’re trying to require, including support for `game:GetService`. - -Additionally, up until now we have been automatically exporting all type aliases declared in the module (via `type X = Y`); requiring the module via `local Foo = require(path)` made these types available under `Foo.` namespace. - -This is different from the explicit handling of module entries, that must be added to the table returned from the `ModuleScript`. This was highlighted as a concern, and to fix this we’ve introduced `export type` syntax. - -Now the only types that are available after require are types that are declared with `export type X = Y`. If you declare a type without exporting it, it’s available inside the module, but the type alias can’t be used outside of the module. That allows to cleanly separate the public API (types and functions exposed through the module interface) from implementation details (local functions etc.). - -## Improve type checker robustness - -As we’re moving closer to enabling type checking for everyone to use (no ETA at the moment), we’re making sure that the type checker is as robust as possible. - -This includes never crashing and always computing the type information in a reasonable time frame, even on obscure scripts like this one: -``` -type ( ... ) ( ) ; -( ... ) ( - - ... ) ( - ... ) -type = ( ... ) ; -( ... ) ( ) ( ... ) ; -( ... ) "" -``` -To that end we’ve implemented a few changes, most of them being live, that fix crashes and unbounded recursion/iteration issues. This work is ongoing, as we’re fixing issues we encounter in the testing process. - -## Better types for Lua and Roblox builtin APIs - -In addition to improving the internals of the type checker, we’re still working on making sure that the builtin APIs have correct type information exposed to the type checker. - -In the last few weeks we’ve done a major audit and overhaul of that type information. We used to have many builtin methods “stubbed†to have a very generic type like `any` or `(...) -> any`, and while we still have a few omissions we’re much closer to full type coverage. - -One notable exception here is the `coroutine.` library which we didn’t get to fully covering, so the types for many of the functions there are imprecise. - -If you find cases where builtin Roblox APIs have omitted or imprecise type information, please let us know by commenting on this thread or filing a bug report. - -The full set of types we expose as of today is listed here for inquisitive minds: [https://gist.github.com/zeux/d169c1416c0c65bb88d3a3248582cd13](https://gist.github.com/zeux/d169c1416c0c65bb88d3a3248582cd13) - -## Removal of __gc from the VM -A bug with `continue` and local variables was reported to us a few weeks ago; the bug was initially believed to be benign but it was possible to turn this bug into a security vulnerability by getting access to `__gc` implementation for builtin Roblox objects. After fixing the bug itself (the turnaround time on the bug fix was about 20 hours from the bug report), we decided to make sure that future bugs like this don’t compromise the security of the VM by removing `__gc`. - -`__gc` is a metamethod that Lua 5.1 supports on userdata, and future versions of Lua extend to all tables; it runs when the object is ready to be garbage collected, and the primary use of that is to let the userdata objects implemented in C to do memory cleanup. This mechanism has several problems: - - * `__gc` is invoked by the garbage collector without context of the original thread. Because of how our sandboxing works this means that this code runs at highest permission level, which is why `__gc` for newproxy-created userdata was disabled in Roblox a long time ago (10 years?) - * `__gc` for builtin userdata objects puts the object into non-determinate state; due to how Lua handles `__gc` in weak keys (see [https://www.lua.org/manual/5.2/manual.html#2.5.2](https://www.lua.org/manual/5.2/manual.html#2.5.2)), these objects can be observed by external code. This has caused crashes in some Roblox code in the past; we changed this behavior at some point last year. - * Because `__gc` for builtin objects puts the object into non-determinate state, calling it on the same object again, or calling any other methods on the object can result in crashes or vulnerabilities where the attacker gains access to arbitrarily mutating the process memory from a Lua script. We normally don’t expose `__gc` because the metatables of builtin objects are locked but if it accidentally gets exposed the results are pretty catastrophic. - * Because `__gc` can result in object resurrection (if a custom Lua method adds the object back to the reachable set), during garbage collection the collector has to traverse the set of userdatas twice - once, to run `__gc` and a second time to mark the survivors. - -For all these reasons, we decided that the `__gc` mechanism just doesn’t pull its weight, and completely removed it from the VM - builtin userdata objects don’t use it for memory reclamation anymore, and naturally declaring `__gc` on custom userdata objects still does nothing. - -Aside from making sure we’re protected against these kinds of vulnerabilities in the future, this makes garbage collection ~25% faster. - -## Memory and performance improvements - -It’s probably not a surprise at this point but we’re never fully satisfied with the level of performance we get. From a language implementation point of view, any performance improvements we can make without changing the semantics are great, since they automatically result in Lua code running faster. To that end, here’s a few changes we’ve implemented recently: - - * ~~A few string. methods, notably string.byte and string.char, were optimized to make it easier to write performant deserialization code. string.byte is now ~4x faster than before for small numbers of returned characters. For optimization to be effective, it’s important to call the function directly (`string.byte(foo, 5)`) instead of using method calls (`foo:byte(5)`).~~ This had to be disabled due to a rare bug in some cases, this optimization will come back in a couple of weeks. - * `table.unpack` was carefully tuned for a few common cases, making it ~15% faster; `unpack` and `table.unpack` now share implementations (and the function objects are equal to each other). - * While we already had a very efficient parser, one long standing bottleneck in identifier parsing was fixed, making script compilation ~5% faster across the board, which can slightly benefit server startup times. - * Some builtin APIs that use floating point numbers as arguments, such as various `Vector3` constructors and operators, are now a tiny bit faster. - * All string objects are now 8 bytes smaller on 64-bit platforms, which isn’t a huge deal but can save a few megabytes of Lua heap in some games. - * Debug information is using a special compact format that results in ~3.2x smaller line tables, which ends up making function bytecode up to ~1.5x smaller overall. This can be important for games with a lot of scripts. - * Garbage collector heap size accounting was cleaned up and made more accurate, which in some cases makes Lua heap ~10% smaller; the gains highly depend on the workload. - -## Library changes - -The standard library doesn’t see a lot of changes at this point, but we did have a couple of small fixes here: - - * `coroutine.wrap` and `coroutine.create` now support C functions. This was the only API that treated Lua and C functions differently, and now it doesn’t. - * `require` silently skipped errors in module scripts that occurred after the module scripts yielding at least once; this was a regression from earlier work on yieldable pcall and has been fixed. - -As usual, if you have questions, comments, or any other feedback on these changes, feel free to share it in this thread or create separate posts for bug reports. \ No newline at end of file diff --git a/docs/_posts/2020-06-20-luau-recap-june-2020.md b/docs/_posts/2020-06-20-luau-recap-june-2020.md deleted file mode 100644 index 182762190..000000000 --- a/docs/_posts/2020-06-20-luau-recap-june-2020.md +++ /dev/null @@ -1,115 +0,0 @@ ---- -layout: single -title: "Luau Recap: June 2020" ---- - -… otherwise known as “This Month in Luau†I guess? You know the drill by now. We’ll talk about exciting things that happened to Luau - our new language stack. - -anxiously glances at FIB3 thread that casts a huge shadow on this announcement, but hopefully somebody will read this - -Many people work on these improvements; thanks @Apakovtac, @EthicalRobot, @fun_enthusiast, @zeuxcg! - -[Originally posted on the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-june-2020/).] - -## We have a website! - -Many developers told us on many occasions that as much as they love the recaps, it’s hard to know what the state of the language or libraries is if the only way to find out is to read through all updates. What’s the syntax extensions that Luau supports now? How do I use type checking? What’s the status of from Lua 5.x? - -Well, you can find all of this out here now: [https://roblox.github.io/luau/](https://roblox.github.io/luau/) - -Please let us know if this documentation can be improved - what are you missing, what could be improved. For now to maximize change velocity this documentation is separate from DevHub; it’s also meant as an external resource for people who don’t really use the language but are curious about the differences and novelties. - -Also, `_VERSION` now returns "Luau" because we definitely aren’t using Lua 5.1 anymore. - -## Compound assignments - -A long-standing feature request for Lua is compound assignments. Somehow Lua never got this feature, but Luau now implements `+=`, `-=`, `*=`, `/=`, `%=`, `^=` and `..=` operators. We decided to implement them because they are absolutely ubiquitous among most frequently used programming languages, both those with C descent and those with different lineage (Ruby, Python). They result in code that’s easier to read and harder to make mistakes in. - -We do not implement `++` and `--`. These aren’t universally adopted, `--` conflicts with comment syntax and they are arguably not as intuitively obvious. We trust everyone to type a few extra characters for `+= 1` without too much trouble. - -Two important semantical notes are that the expressions on the left hand side are only evaluated once, so for example `table[makeIndex()] += 1` only runs `makeIndex` once, and that compound assignments still call all the usual metamethod (`__add` et al, and `__index`/`__newindex`) when necessary - you don’t need to change any data structures to work with these. - -There’s no noticeable performance improvement from these operators (nor does using them carry a cost) - use them when they make sense for readability. - -## Nicer error messages - -Good errors are critical to be able to use Luau easily. We’ve spent some time to improve the quality of error messages during parsing and runtime execution: - - * In runtime type errors, we now often use the “Roblox†type name instead of plain userdata, e.g. `math.abs(v)` now says `number` expected, got `Vector3` - * When arguments are just missing, we now explicitly say that they are missing in libraries like math/table; the old message was slightly more confusing - * `string.format` in some cases produced error messages that confused missing arguments for incorrect types, which has been fixed - * When a builtin function such as `math.abs` fails, we now add the function name to the error message. This is something that used to happen in Lua, then we lost this in Luau because Luau removes a very fragile mechanism that supported that, but we now have a new, robust way to report this so you can have the function name back! The message looks like this now: `invalid argument #1 to 'abs' (number expected, got nil)` - * In compile-time type errors, we now can identify the case when the field was mistyped with a wrong case (ha), and tell you to use the correct case instead. - * When you forget an `end` statement, we now try to be more helpful and point you to the problematic statement instead of telling you that the end is missing at the very end of the program. This one is using indentation as a heuristic so it doesn’t always work perfectly. - * We now have slightly more helpful messages for cases when you forget parentheses after a function call - * We now have slightly more helpful messages for some cases when you accidentally use `( ... )` instead of `{ ... }` to create a table literal -Additionally two places had very lax error checking that made the code more fragile, and we fixed those: - - * `xpcall` now fails immediately when the error function argument is not a function; it used to work up until you get an error, and failed at that point, which made it hard to find these bugs - * `tostring` now enforces the return type of the result to be a string - previously `__tostring` could return a non-string result, which worked fine up until you tried to do something like passing the resulting value to `string.format` for `%s`. Now `tostring` will fail early. -Our next focus here is better error messages during type checking - please let us know if there are other errors you find confusing and we could improve! - -## Type checker improvements - -We’re getting closer and closer to be able to move out of beta. A big focus this month was on fixing all critical bugs in the type checker - it now should never hang or crash Studio during type checking, which took a bit of work to iron out all the problems. - -Notably, typing function string.length no longer crashes Studio (although why you’d do that is unclear), and Very Large Scripts With Tons Of Nested Statements And Expressions should be stable as well. - -We’ve also cleaned up the type information for builtin libraries to make it even more precise, including a few small fixes to `string/math` functions, and a much more precise coroutine library type information. For the latter we’ve introduced a primitive type `thread`, which is what `coroutine` library works with. - -## Linter improvements - -Linter is the component that produces warnings about scripts; it’s otherwise known as “Static Analysis†in Studio, although that is now serving as a place where we show type errors as well. - -Most of the changes here this month are internal as they concern warnings that aren’t yet enabled in Studio (the web site linked above documents all warnings including ones that aren’t active yet but may become active), but once notable feature is that you can now opt out of individual warnings on a script-by-script basis by adding a --!nolint comment to the top of the script. For example, if you really REALLY *REALLY* like the `Game` global, you can add this to the top of the script: - -``` ---!nolint DeprecatedGlobal -``` -Or, if you basically just want us to not issue any warnings ever, I guess you can add this: -``` ---!nocheck ---!nolint -``` -and live happily ignorant of all possible errors up until you run your code. (please don’t do that) - -## os. enhancements - -Our overall goal is to try to be reasonably compatible with Lua 5.x in terms of library functions we expose. This doesn’t always work - in some cases we have to remove library features for sandboxing reasons, and in others the library functions don’t make sense in context of Roblox. However, some of these decisions can be revised later. In particular, when we re-added `os.` library to Roblox, we limited it to `os.date`, `os.time` and `os.difftime` (although why `difftime` is a thing isn’t clear), omitting `os.clock` and restricting inputs to `os.date` to return a table with date components, whereas Lua 5.x supports format strings. - -Well, this changes today. `os.clock` is now available if you need a high-precision time for benchmarking, and `os.date` can now return formatted date using Lua 5.x format string that you can read about here [https://www.lua.org/pil/22.1.html](https://www.lua.org/pil/22.1.html) (we support all these specifiers: aAbBcdHIjmMpSUwWxXyYzZ). - -While `os.date()` is hopefully welcome, `os.clock` may raise some eyebrows - aren’t there enough timing functions in Roblox already? Well, this is nice if you are trying to port code from Lua 5.x to Luau, and there’s this - -![Oblig. xkcd]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-june-2020-xkcd.png) - -But really, most existing Roblox timing functions are… problematic. - - * `time()` returns the total amount of time the game has been running simulation for, it’s monotonic and has reasonable precision. It’s fine - you can use it to update internal gameplay systems without too much trouble. It should’ve been called “tick†perhaps but that ship has sailed. - * `elapsedTime` and its close cousin `ElapsedTime`, are telling you “how much time has elapsed since the current instance of Roblox was started.â€. While technically true, this isn’t actually useful because on mobile the “start†time here can be days in the past. It’s also inadequate for performance measurements as on Windows, it has a 1ms resolution which isn’t really enough for anything interesting. We’re going to deprecate this in the future. - * `tick()` sounds perfect - it has a high resolution (usually around 1 microsecond), and a well-defined baseline - it counts since UNIX epoch! Or, well, it actually doesn’t. On Windows, it returns you a variant of the UNIX timestamp in local time zone. In addition, it can be off by 1 second from the actual, real UNIX timestamp, and might have other idiosyncrasies on non-Windows platforms. We’re going to deprecate this in the future - -So, if you need a UNIX timestamp, you should use `os.time()`. You get a stable baseline (from 1970’s) and 1s resolution. If you need to measure performance, you should use `os.clock()`. You don’t get a stable baseline, but you get ~1us resolution. If you need to do anything else, you should probably use `time()`. - -## Performance optimizations - -As you can never have too much performance, we’re continuing to work on performance! We’re starting to look into making Vector3 faster and improving the garbage collector, with some small changes already shipping, but overall it’s a long way out so here are the things that did get visibly better: - - * A few `string.` methods, notably `string.byte` and `string.char`, were optimized to make it easier to write performant deserialization code. string.byte is now ~4x faster than before for small numbers of returned characters. For optimization to be effective, it’s important to call the function directly ( `string.byte(foo, 5)` ) instead of using method calls ( `foo:byte(5)` ) - * Optimize coroutine resumption, making some code that is heavily reliant on `coroutine`. library ~10% faster. We have plans to improve this further, watch this space. - * Optimize `typeof()` to run ~6x faster. It used to be that `type()` was much faster than `typeof()` but they now should be more or less comparable. - * Some secret internal optimizations make some scripts a few percent faster - * The memory allocator used in Luau was rewritten using a new, more efficient, implementation. There might be more changes here in the future to save some memory, but for now this makes some allocation-intensive benchmarks ~15% faster. - * Using tables with keys that are not strings or numbers is a fair bit more efficient now (most commonly comes up when Instance is used as a key in a hash table), on par with using strings. - -Also we found a bug with some of our optimizations (which delayed the string. performance improvement above, but also could affect some math. calls) where in some complex functions you would see valid calls to math. etc. breaking with non-sensical errors such as “expected number, got table†- this has been fixed! - -## Memory optimizations - -As with performance, our goal here is simple - the more efficient internal Luau structures can become, the less memory will Lua heap take. This is great for both memory consumption, and for garbage collection performance as the collector needs to traverse less data. There’s a few exciting changes in this area this month: - - * Non-array-like tables now take 20% less space. This doesn’t affect arrays but can be observed on object-like tables, both big and small. This is great because some of you are using a lot of large tables apparently, since this resulted in very visible reduction in overall Lua heap sizes across all games. - * Function objects now take up to 30% less space. This isn’t as impactful since typically function objects are not created very frequently and/or don’t live for very long, but it’s nice nonetheless. - * New allocator mentioned in the previous section can save up to 5-6% of Lua heap memory as well, although these gains are highly dependent on the workload, and we usually see savings in the 1-2% range. - -And that’s it! Till next time. As usual let us know if you have questions, suggestions or bug reports. diff --git a/docs/_posts/2020-08-11-luau-recap-august-2020.md b/docs/_posts/2020-08-11-luau-recap-august-2020.md deleted file mode 100644 index 5d149c18c..000000000 --- a/docs/_posts/2020-08-11-luau-recap-august-2020.md +++ /dev/null @@ -1,93 +0,0 @@ ---- -layout: single -title: "Luau Recap August 2020" ---- - -As everyone knows by now, Luau is our new language stack that you can read more about at [https://roblox.github.io/luau](https://roblox.github.io/luau) and the month following June is August so let’s talk about changes, big and small, that happened since June! - -Many people work on these improvements, with the team slowly growing - thanks @Apakovtac, @EthicalRobot, @fun_enthusiast, @mrow_pizza and @zeuxcg! - -[Originally posted on the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-august-2020/).] - -## Type annotations are safe to use in production! - -When we started the Luau type checking beta, we’ve had a big warning sign in the post saying to not publish the type-annotated scripts to your production games which some of you did anyway. This was because we didn’t want to commit to specific syntax for types, and were afraid that changing the syntax would break your games. - -This restriction is lifted now. All scripts with type annotations that parse & execute will continue to parse & execute forever. Crucially, for this to be true you must not be using old fat arrow syntax for functions, which we warned you about for about a month now: - -![Fat arrow deprecated]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-august-2020-arrow.png) - -… and must not be using the `__meta` property which no longer holds special meaning and we now warn you about that: - -![meta deprecated]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-august-2020-meta.png) - -Part of the syntax finalization also involved changing the precedence on some type annotations and adding support for parentheses; notably, you can now mix unions and intersections if you know what that means (`(A & B) | C` is valid type syntax). Some complex type annotations changed their structure because of this - previously `(number) -> string & (string) -> string` was a correct way to declare an intersection of two function types, but now to keep it parsing the same way you need to put each function type in parentheses: `((number) -> string) & ((string) -> string)`. - -Type checking is not out of beta yet - we still have some work to do on the type checker itself. The items on our list before going out of beta right now include: - - * Better type checking for unary/binary operators - * Improving error messages to make type errors more clear - * Fixing a few remaining crashes for complex scripts - * Fixing conflation of warnings/errors between different scripts with the same path in the tree - * Improving type checking of globals in nonstrict mode (strict mode will continue to frown upon globals) - -Of course this doesn’t mark the end of work on the feature - after type checking goes out of beta we plan to continue working on both syntax and semantics, but that list currently represents the work we believe we have left to do in the first phase - please let us know if there are other significant issues you are seeing with beta beyond future feature requests! - -## Format string analysis - -A few standard functions in Luau are using format strings to dictate the behavior of the code. There’s `string.format` for building strings, `string.gmatch` for pattern matching, `string.gsub`'s replacement string, `string.pack` binary format specification and `os.date` date formatting. - -In all of these cases, it’s important to get the format strings right - typos in the format string can result in unpredictable behavior at runtime including errors. To help with that, we now have a new lint rule that parses the format strings and validates them according to the expected format. - -![String format]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-august-2020-format.png) - -Right now this support is limited to direct library calls (`string.format("%.2f", ...)` and literal strings used in these calls - we may lift some of these limitations later to include e.g. support for constant locals. - -Additionally, if you have type checking beta enabled, string.format will now validate the argument types according to the format string to help you get your `%d`s and `%s`es right. - -![String format]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-august-2020-format2.png) - -## Improvements to string. library - -We’ve upgraded the Luau string library to follow Lua 5.3 implementation; specifically: - - * `string.pack/string.packsize/string.unpack` are available for your byte packing needs - * `string.gmatch` and other pattern matching functions now support `%g` and `\0` in patterns - -This change also [inadvertently] makes `string.gsub` validation rules for replacement string stricter - previously `%` followed by a non-digit character was silently accepted in a replacement string, but now it generates an error. This accidentally broke our own localization script [Purchase Prompt broken in some games (% character in title)](https://devforum.roblox.com/t/purchase-prompt-broken-in-some-games-character-in-title/686237)), but we got no other reports, and this in retrospect is a good change as it makes future extensions to string replacement safe… It was impossible for us to roll the change back and due to a long release window because of an internal company holiday we decided to keep the change as is, although we’ll try to be more careful in the future. - -On a happier note, string.pack may seem daunting but is pretty easy to use to pack binary data to reduce your network traffic (note that binary strings aren’t safe to use in DataStores currently); I’ve posted an example in the release notes thread [Release Notes for 441](https://devforum.roblox.com/t/release-notes-for-441/686773) that allows you to pack a simple character state in 16 bytes like this: -``` -local characterStateFormat = "fffbbbB" - -local characterState = string.pack(characterStateFormat, - posx, posy, posz, dirx * 127, diry * 127, dirz * 127, health) -``` -And unpack it like this after network transmission: -``` -local posx, posy, posz, dirx, diry, dirz, health = - string.unpack(characterStateFormat, characterState) -dirx /= 127 -diry /= 127 -dirz /= 127 -``` - -## Assorted fixes - -As usual we fixed a few small problems discovered through testing. We now have an automated process that generates random Luau code in semi-intelligent ways to try to break different parts of our system, and a few fixes this time are a direct result of that. - - * Fix line debug information for multi-line function calls to make sure errors for code like `foo.Bar(...)` are generated in the appropriate location when foo is nil - * Fix debug information for constant upvalues; this fixes some bugs with watching local variables from the nested functions during debugging - * Fix an off-by-one range check in string.find for init argument that could result in reading uninitialized memory - * Fix type confusion for table.move target table argument that could result in reading or writing arbitrary memory - * Fix type confusion for `debug.getinfo` in some circumstances (we don’t currently expose getinfo but have plans to do so in the future) - * Improve out of memory behavior for large string allocations in string.rep and some other functions like `table.concat` to handle these conditions more gracefully - * Fix a regression with `os.time` from last update, where it erroneously reverted to Lua 5.x behavior of treating the time as a local time. Luau version (intentionally) deviates from this by treating the input table as UTC, which matches `os.time()` behavior with no arguments. - -## Performance improvements - -Only two changes in this category here this time around; some larger scale performance / memory improvements are still pending implementation. - - * Constant locals are now completely eliminated in cases when debugging is not available (so on server/client), making some scripts ~1-2% faster - * Make script compilation ~5% faster by tuning the compiler analysis and code generation more carefully -Oh, also `math.round` is now a thing which didn’t fit into any category above. \ No newline at end of file diff --git a/docs/_posts/2020-10-30-luau-recap-october-2020.md b/docs/_posts/2020-10-30-luau-recap-october-2020.md deleted file mode 100644 index 727fef70a..000000000 --- a/docs/_posts/2020-10-30-luau-recap-october-2020.md +++ /dev/null @@ -1,91 +0,0 @@ ---- -layout: single -title: "Luau Recap: October 2020" ---- - -Luau is our new language that you can read more about at [https://roblox.github.io/luau](https://roblox.github.io/luau); we’ve been so busy working on the current projects that we didn’t do an update in September, so let’s look at changes that happened since August! - -Many people work on these improvements, with the team slowly growing - thanks @Apakovtac, @EthicalRobot, @fun_enthusiast, @machinamentum, @mrow_pizza and @zeuxcg! - -[Originally posted on the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-october-2020/).] - -## Types are very close - -We’ve been in beta for a while now, but we’re steadily marching towards getting the first release of the type checker, what we call “types v0â€, out of the door. It turns out that we’ve substantially underestimated the effort required to make the type system robust, strike the balance between “correct†and “usable†and give quality diagnostics in the event we do find issues with your code 🙂 - -Because of this, we’re changing the original plans for the release a bit. We’re actively working on a host of changes that we consider to be part of the “v0†effort, and when they are all finished - which should happen next month, fingers crossed - we’re going to be out of beta! - -However, by default, on scripts with no annotations, we won’t actually activate type checking. You would have to opt into the type checking by using `--!nonstrict` or `--!strict`, at the top of each script. We are also going to open the second beta, “All scripts use non-strict mode by default†or something along these lines. - -This is important because we found that our non-strict mode still needs some more work to be more tolerant to some code that occurs commonly in Roblox and is correct, but doesn’t type-check. We’re going to evaluate what changes specifically are required to make this happen, but we didn’t want the extra risk of a flood of reports about issues reported in existing code to shift the release date in an unpredictable fashion. - -To that end, we’ve been working on Lots and Lots and Lots and Lots and Lots of changes to finish the first stage. Some of these changes are already live and some are rolling out; the amount of changes is so large that I can’t possibly list the up-to-date status on each one as these recaps are synthesized by the human who is writing this on a Friday night, so here’s just a raw list of changes that may or may not have been enabled: - - * Strict mode is now picky about passing extra arguments to functions, even though they are discarded silently at runtime, as this can hide bugs - * The error message about using a : vs . during type checking is now much more precise - * Recursive type annotations shouldn’t crash the type checker now, and we limit the recursion and iteration depth during type checking in a few cases in general in an effort to make sure type checker always returns in finite time - * Binary relational operators (`<` et al) are now stricter about the argument types and infer the argument types better - * Function argument and return types are now correctly contra- and co-variant; if this looks like gibberish to you, trust me - it’s for the best! - * Fixed a few problems with indexing unions of tables with matching key types - * Fixed issues with tracing types across modules (via require) in non-strict mode - * Error messages for long table types are now trimmed to make the output look nicer - * Improve the interaction between table types of unknown shape (`{ [string]: X }`) and table types of known shape. - * Fix some issues with type checking table assignments - * Fix some issues with variance of table fields - * Improve the legibility of type errors during function calls - errors now point at specific arguments that are incorrect, and mismatch in argument count should clearly highlight the problem - * Fix types for many builtins including `ipairs`, `table.create`, `Color3.fromHSV`, and a few others - * Fix missing callbacks for some instance types like `OnInvoke` for bindables (I think this one is currently disabled while we’re fixing a semi-related bug, but should be enabled soon!) - * Rework the rules under which globals are okay to use in non-strict mode to mostly permit valid scripts to type-check; strict mode will continue to frown upon the use of global variables - * Fix a problem with the beta where two scripts with identical names would share the set of errors/warnings, resulting in confusing error highlights for code that doesn’t exist - * Improve the legibility of type errors when indexing a table without a given key - * Improve the parsing error when trying to return a tuple; `function f(): string, number` is invalid since the type list should be parenthesized because of how our type grammar is currently structured - * Type checker now clearly reports cases where it finds a cyclic dependency between two modules - * Type errors should now be correctly sorted in the Script Analysis widget - * Error messages on mismatches between numbers of values in return statements should now be cleaner, as well as the associated type mismatch errors - * Improve error messages for comparison operators - * Flag attempts to require a non-module script during type checking - * Fix some cases where a type/typeof guard could be misled into inferring a non-sensible type - * Increase the strictness of return type checks in strict mode - functions now must conform to the specified type signature, whereas before we’d allow a function to return no values even in strict mode - * Improve the duplicate definition errors to specify the line of the first definition - * Increase the strictness of binary operators in strict mode to enforce the presence of the given operator as a built-in or as part of the metatable, to make sure that strict mode doesn’t infer types when it can’t guarantee correctness - * Improve the type errors for cyclic types to make them more readable - * Make type checker more friendly by rewording a lot of error messages - * Fix a few crashes in the type checker (although a couple more remain - working on them!) - * … I think that’s it? - * …edit ah, of course I forgot one thing - different enums that are part of the Roblox API now have distinct types and you can refer to the types by name e.g. `Enum.Material`; this should go live next week though. -If you want to pretend that you’ve read and understood the entire list above, just know that we’ve worked on making sure strict mode is more reliably reporting type errors and doesn’t infer types incorrectly, on making sure non-strict mode is more forgiving for code that is probably valid, and on making the type errors more specific, easier to understand, and correct. - -## Type syntax changes - -There’s only two small changes here this time around - the type syntax is now completely stable at this point, and any existing type annotation will continue parsing indefinitely. We of course reserve the right to add new syntax that’s backwards compatible :slight_smile: - -On that note, one of the small changes is that we’ve finally removed support for fat arrows (`=>`); we’ve previously announced that this would happen and that thin arrows (`->`) are the future, and had warnings issued on the legacy syntax for a while. Now it’s gone. - -On a positive note, we’ve added a shorter syntax for array-like table types. Whereas before you had to use a longer `{ [number]: string }` syntax to declare an array-like table that holds strings, or had to define an `Array` type in every. single. module. you. ever. write. ever., now you can simply say `{string}`! This syntax is clean, in line with the value syntax for Lua table literals, and also was chosen by other research projects to add type annotations to Lua. - -(if you’re a monster that uses mixed tables, you’ll have to continue using the longer syntax e.g. `{ [number]: string, n: number }`) - -## Library changes - -There’s only a few small tweaks here this time around on the functionality front: - - * `utf8.charpattern` is now exactly equal to the version from Lua 5.3; this is now possible because we support `\0` in patterns, and was suggested by a user on devforum. We do listen! - * `string.pack` now errors out early when the format specifier is Way Too Large. This was reported on dev forum and subsequently fixed. Note that trying to generate a Moderately Large String (like, 100 MB instead of 100 GB) will still succeed but may take longer than we’d like - we have a plan to accelerate operations on large strings substantially in the coming months. - -## Performance improvements - -We were super focused on other things so this is very short this time around. We have a lot of ideas here but they are waiting for us to finish some other large projects! - - * Method calls on strings via `:` are now ~10% faster than before. We still recommend using fully-qualified calls from string library such as `string.foo(str)`, but extra performance never hurts! - * Speaking of string methods, string.sub is now ~20% faster than before with the help of voodoo magic. - -## Miscellaneous fixes - -There were a few small fixes that didn’t land into any specific category that I wanted to highlight: - - * In some rare cases, debug information on conditions inside loops have been fixed to stop debugger from incorrectly suggesting that the current line is inside a branch that wasn’t taken. As usual, if you ever see debugger misbehaving, please file bugs on this! - * Code following `assert(false)` is now treated as an unreachable destination from the linting and type checking point of view, similarly to error calls. - * Linting support for various format strings has been greatly improved based on fantastic feedback from @Halalaluyafail3 (thanks!). - -Ok, phew, that’s what I get for skipping a month again. Please don’t hesitate to report bugs or suggestions, here or via separate posts. Due to our usual end-of-year code freeze there’s going to be one more recap at the end of the year where we will look back at 2020 and take a small peek into the future. - diff --git a/docs/_posts/2020-11-19-luau-type-checking-release.md b/docs/_posts/2020-11-19-luau-type-checking-release.md deleted file mode 100644 index 35879573d..000000000 --- a/docs/_posts/2020-11-19-luau-type-checking-release.md +++ /dev/null @@ -1,109 +0,0 @@ ---- -layout: single -title: "Luau Type Checking Release" ---- - -10 months ago, we’ve started upon the journey of helping Roblox scripters write robust code by introducing [an early beta of type checking](https://devforum.roblox.com/t/luau-type-checking-release). We’ve received a lot of enthusiastic feedback and worked with the community on trying to make sure critical issues are addressed, usability is improved and the type system is ready for prime time. - -Today I’m incredibly excited to announce that the first release of [Luau](https://roblox.github.io/luau/) type checking is officially released! Thanks a lot to @Apakovtac, @EthicalRobot, @fun_enthusiast, @machinamentum, @mrow_pizza and @zeuxcg! - -[Originally posted on the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-type-checking-release/).] - -## What is type checking? - -When Luau code runs, every value has a certain type at runtime - a kind of value it stores. It could be a number, a string, a table, a Roblox Instance or one of many others. Thing is, some operations work on some types but don’t work on others! - -Consider this: -``` -local p = Instance.new("Part") -p.Positio = Vector3.new(1,2,3) -``` -Is this code correct? No - there’s a typo. The way you get to find this typo is by running your code and eventually seeing an error message. Type checker tries to analyze your code before running, by assigning a type to each value based on what we know about how that value was produced, or based on the type you’ve explicitly told us using a new syntax extension, and can produce an error ahead of time: - -!["Positio not found in class Part"]({{ site.url }}{{ site.baseurl }}/assets/images/luau-type-checking-release-screenshot.png) - -This can require some effort up front, especially if you use strict mode, but it can save you valuable time in the future. It can be especially valuable if you have a large complex code base you need to maintain for years, as is the case with many top Roblox games. - -## How do I use type checking? - -A very important feature of Luau type checking you need to know about is that it has three modes: - - * `nocheck`, where we don’t type check the script in question. - * `nonstrict`, where we type check the script but try to be lenient to allow commonly seen patterns even if they may violate type safety - * `strict`, where we try to make sure that every single line of code you write is correct, and every value has a known type. - -The modes can be selected per script by writing a comment at the top of the script that starts with `--!`, e.g. `--!strict`. - -As of this release, the default mode is nocheck. This means by default you actually won’t see the type checking produce feedback on your code! We had to use nocheck by default because we aren’t fully ready to unleash nonstrict mode on unsuspecting users - we need to do a bit more work to make sure that most cases where we tell you that something is wrong are cases where yes, something is actually wrong. - -However we highly encourage trying at least non-strict mode on your codebase. You can do this by opting into a different default via a Studio beta: - -!["Studio option"]({{ site.url }}{{ site.baseurl }}/assets/images/luau-type-checking-release-studio-option.png) - -This beta only changes the default mode. Another way to change the mode is to prepend a `--!` comment to the script - you can do this manually for now, but if anyone in the community wants to release a plugin that does it automatically on selected scripts (+ descendants), that would be swell! - -If you really want your code to be rock solid, we recommend trying out strict mode. Strict mode will require you to use type annotations. - -## What are type annotations and how do I use them? - -Glad you asked! (please pretend you did) Type annotations are a way to tell the type checker what the type of a variable is. Consider this code in strict mode: -``` -function add(x, y) - return x + y -end -``` -Is this code correct? Well, that depends. `add(2, 3)` will work just fine. `add(Vector3.new(1, 2, 3), Vector3.new(4, 5, 6))` will work as well. But `add({}, nil)` probably isn’t a good idea. - -In strict mode, we will insist that the type checker knows the type of all variables, and you’ll need to help the type checker occasionally - by adding types after variable names separated by `:`: -``` -function add(x: number, y: number) - return x + y -end -``` -If you want to tell the type checker “assume this value can be anything and I will take responsibilityâ€, you can use `any` type which will permit any value of any type. - -If you want to learn more about the type annotation syntax, you should read this [documentation on syntax](https://roblox.github.io/luau/syntax.html#type-annotations). We also have a somewhat more complete guide to type checking than this post can provide, that goes into more details on table types, OOP, Roblox classes and enums, interaction with require and other topics - [read it if you’re curious!](https://roblox.github.io/luau/typecheck.html). - -## What happens when I get a type error? - -One concept that’s very important to understand is that right now type errors do not influence whether the code will run or not. - -If you have a type error, this means that our type checker thinks your code has a bug, or doesn’t have enough information to prove the code works fine. But if you really want to forge ahead and run the code - you should feel free to do so! - -This means that you can gradually convert your code to strict mode by adding type annotations and have the code runnable at all times even if it has type errors. - -This also means that it’s safe to publish scripts even if type checker is not fully happy with them - type issues won’t affect script behavior on server/client, they are only displayed in Studio. - -## Do I have to re-learn Lua now?!? - -This is a question we get often! The answer is “noâ€. - -The way the type system is designed is that it’s completely optional, and you can use as many or as few types as you’d like in your code. - -In non-strict mode, types are meant as a lightweight helper - if your code is likely wrong, we’re going to tell you about it, and it’s up to you on whether to fix the issue, or even disable the type checker on a given problematic file if you really don’t feel like dealing with this. - -In strict mode, types are meant as a power user tool - they will require more time to develop your code, but they will give you a safety net, where changing code will be much less likely to trigger errors at runtime. - -## Is there a performance difference? - -Right now type annotations are ignored by our bytecode compiler; this means that performance of the code you write doesn’t actually depend on whether you use strict, nonstrict or nocheck modes or if you have type annotations. - -This is likely going to change! We have plans for using the type information to generate better bytecode in certain cases, and types are going to be instrumental to just-in-time compilation, something that we’re going to invest time into next year as well. - -Today, however, there’s no difference - type information is completely elided when the bytecode is built, so there is zero runtime impact one way or another. - -## What is next for types? - -This is the first full release of type checking, but it’s by far the last one. We have a lot more ground to cover. Here’s a few things that we’re excited about that will come next: - - * Making nonstrict mode better to the point where we can enable it as a default for all Roblox scripts - - * Adding several features to make strict mode more powerful/friendly, such as typed variadics, type ascription and better generics support - - * Improving type refinements for type/typeof and nil checks - - * Making it possible to view the type of a variable in Studio - - * Reworking autocomplete to use type information instead of the current system - -If you have any feedback on the type system, please don’t hesitate to share it here or in dedicated bug report threads. We’re always happy to fix corner cases that we’ve missed, fix stability issues if they are discovered, improve documentation when it’s not clear or improve error messages when they are hard to understand. \ No newline at end of file diff --git a/docs/_posts/2021-03-01-luau-recap-february-2021.md b/docs/_posts/2021-03-01-luau-recap-february-2021.md deleted file mode 100644 index 1179eaea5..000000000 --- a/docs/_posts/2021-03-01-luau-recap-february-2021.md +++ /dev/null @@ -1,83 +0,0 @@ ---- -layout: single -title: "Luau Recap: February 2021" ---- - -Luau is our new language that you can read more about at [https://roblox.github.io/luau](https://roblox.github.io/luau). It's been a busy few months in Luau! - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-february-2021/).] - -## Infallible parser - -Traditional compilers have focused on tasks that can be performed on complete programs, such as type-checking, static analysis and code generation. This is all good, but most programs under development are incomplete! They may have holes, statements that will be filled in later, and lines that are in the middle of being edited. If we'd like to provide support for developers while they are writing code, we need to provide tools for incomplete programs as well as complete ones. - -The first step in this is an *infallible* parser, that always returns an Abstract Syntax Tree, no matter what input it is given. If the program is syntactically incorrect, there will also be some syntax errors, but the parser keeps going and tries to recover from those errors, rather than just giving up. - -The Luau parser now recovers from errors, which means, for example, we can give hints about programs in an IDE. - -![A type error after a syntax error]({{ site.url }}{{ site.baseurl }}/assets/images/type-error-after-syntax-error.png) - -## Type assertions - -The Luau type checker can't know everything about your code, and sometimes it will produce type errors even when you know the code is correct. For example, sometimes the type checker can't work out the intended types, and gives a message such as "Unknown type used... consider adding a type annotation". - -!["Consider adding a type annotation"]({{ site.url }}{{ site.baseurl }}/assets/images/type-annotation-needed.png) - -Previously the only way to add an annotation was to put it on the *declaration* of the variable, but now you can put it on the *use* too. A use of variable `x` at type `T` can be written `x :: T`. For example the type `any` can be used almost anywhere, so a common usage of type assertions is to switch off the type system by writing `x :: any`. - -!["A type assertion y:any"]({{ site.url }}{{ site.baseurl }}/assets/images/type-annotation-provided.png) - -## Typechecking improvements - -We've made various improvements to the Luau typechecker: - - * We allow duplicate function definitions in non-strict mode. - * Better typechecking of `and`, `(f or g)()`, arrays with properties, and `string:format()`. - * Improved typechecking of infinite loops. - * Better error reporting for function type mismatch, type aliases and cyclic types. - -## Performance improvements - -We are continuing to work on optimizing our VM and libraries to make sure idiomatic code keeps improving in performance. Most of these changes are motivated by our benchmark suite; while some improvements may seem small and insignificant, over time these compound and allow us to reach excellent performance. - - * Table key assignments as well as global assignments have been optimized to play nicer with modern CPUs, yielding ~2% speedup in some benchmarks - * Luau function calls are now ~3% faster in most cases; we also have more call optimizations coming up next month! - * Modulo operation (%) is now a bit faster on Windows, resulting in ~2% performance improvement on some benchmarks - -!["Benchmark vs Lua 5.3"]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-february-2021-benchmark.png) - -## Debugger improvements - -Our Luau VM implementation is focused on performance and provides a different API for implementation of debugger tools. But it does have its caveats and one of them was inability to debug coroutines (breakpoints/stepping). - -The good news is that we have lifted that limitation and coroutines can now be debugged just like any regular function. This can especially help people who use Promise libraries that rely on coroutines internally. - -![Debugging a coroutine]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-february-2021-debugger.png) - -## Library changes - -`table` library now has a new method, `clear`, that removes all keys from the table but keeps the internal table capacity. When working with large arrays, this can be more efficient than assigning a table to `{}` - the performance gains are similar to that of using `table.create` instead of `{}` *when you expect the number of elements to stay more or less the same*. Note that large empty tables still take memory and are a bit slower for garbage collector to process, so use this with caution. - -In addition to that we found a small bug in `string.char` implementation that allowed creating strings from out-of-range character codes (e.g. `string.char(2000)`); the problem has been fixed and these calls now correctly generate an error. - -## Coming soon... - -* _Generic function types_ will soon be allowed! -``` -function id(x: a): a - return x -end -``` - -* _Typed variadics_ will soon allow types to be given to functions with varying numbers of arguments! -``` -function sum(...: number): number - local result = 0 - for i,v in ipairs({...}) do - result += v - end - return result -end -``` - -And there will be more! diff --git a/docs/_posts/2021-03-29-luau-recap-march-2021.md b/docs/_posts/2021-03-29-luau-recap-march-2021.md deleted file mode 100644 index f91c93e8e..000000000 --- a/docs/_posts/2021-03-29-luau-recap-march-2021.md +++ /dev/null @@ -1,93 +0,0 @@ ---- -layout: single -title: "Luau Recap: March 2021" ---- - -Luau is our new language that you can read more about at [https://roblox.github.io/luau](https://roblox.github.io/luau). It's been a busy month in Luau! - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-march-2021/).] - -## Typed variadics - -Luau supports *variadic* functions, meaning ones which can take a variable number of arguments (varargs!) but previously there was no way to specify their type. Now you can! -``` -function f(x: string, ...: number) - print(x) - print(...) -end -f("hi") -f("lo", 5, 27) -``` -This function takes a string, plus as many numbers as you like, but if you try calling it with anything else, you'll get a type error, for example `f("oh", true)` gives an error "Type `boolean` could not be converted into `number`" - -Variadics can be used in function declarations, and function types, for example -``` -type T = { - sum: (...number) -> number -} -function f(x: T) - print(x.sum(1, 2, 3)) -end -``` - -## Generic functions - -**WARNING** Generic functions are currently disabled as we're fixing some critical bugs. - -## Typechecking improvements - -We've made various improvements to the Luau typechecker: - -* Check bodies of methods whose `self` has type `any` -* More precise types for `debug.*` methods -* Mutually dependent type aliases are now handled correctly - -## Performance improvements - -We are continuing to squeeze the performance out of all sorts of possible code; this is an ongoing process and we have many improvements in the pipeline, big and small. These are the changes that are already live: - -* Significantly optimized non-variadic function calls, improving performance by up to 10% on call-heavy benchmarks -* Improve performance of `math.clamp`, `math.sign` and `math.round` by 2.3x, 2x and 1.6x respectively -* Optimized `coroutine.resume` with ~10% gains on coroutine-heavy benchmarks -* Equality comparisons are now a bit faster when comparing to constants, including `nil`; this makes some benchmarks 2-3% faster -* Calls to builtin functions like `math.abs` or `bit32.rrotate` are now significantly faster in some cases, e.g. this makes SHA256 benchmark 25% faster -* `rawset`, `rawget`, `rawequal` and 2-argument `table.insert` are now 40-50% faster; notably, `table.insert(t, v)` is now faster than `t[#t+1]=v` - -Note that we work off a set of benchmarks that we consider representative of the wide gamut of code that runs on Luau. If you have code that you think should be running faster, never hesitate to open a feature request / bug report on Roblox Developer Forum! - -## Debugger improvements - -We continue to improve our Luau debugger and we have added a new feature to help with coroutine call debugging. -The call stack that is being displayed while stopped inside a coroutine frame will display the chain of threads that have called it. - -Before: - -!["Old debugger"]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-march-2021-debug-before.png) - -After: - -!["New debugger"]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-march-2021-debug-after.png) - -We have restored the ability to break on all errors inside the scripts. -This is useful in cases where you need to track the location and state of an error that is triggered inside 'pcall'. -For example, when the error that's triggered is not the one you expected. - -!["Break on all exceptions"]({{ site.url }}{{ site.baseurl }}/assets/images/luau-recap-march-2021-debug-dialog.png) - -## Library changes - -* Added the `debug.info` function which allows retrieving information about stack frames or functions; similarly to `debug.getinfo` from Lua, this accepts an options string that must consist of characters `slnfa`; unlike Lua that returns a table, the function returns all requested values one after another to improve performance. - -## New logo - -Luau now has a shiny new logo! - -!["New logo!"]({{ site.url }}{{ site.baseurl }}/assets/images/luau.png) - -## Coming soon... - -* Generic variadics! -* Native Vector3 math with dramatic performance improvements! -* Better tools for memory analysis! -* Better treatment of cyclic requires during type checking! -* Better type refinements including nil-ability checks, `and`/`or` and `IsA`! diff --git a/docs/_posts/2021-04-30-luau-recap-april-2021.md b/docs/_posts/2021-04-30-luau-recap-april-2021.md deleted file mode 100644 index dc0e284f4..000000000 --- a/docs/_posts/2021-04-30-luau-recap-april-2021.md +++ /dev/null @@ -1,46 +0,0 @@ ---- -layout: single -title: "Luau Recap: April 2021" ---- - -Luau is our new language that you can read more about at [https://roblox.github.io/luau](https://roblox.github.io/luau). Another busy month in Luau with many performance improvements. - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-april-2021/).] - -## Editor features - -Luau implementation now provides an internal API for type-aware autocomplete suggestions. - -Roblox Studio will be the first user of this API and we plan for a new beta feature to come soon in addition to existing Luau-powered beta features like Go To Declaration, Type Hovers and Script Function Filter (you should check those out!) - -## Performance improvements - -Performance is a very important part of Luau implementation and we continue bringing in new performance optimizations: - -* We've finished the work on internal `vector` value type that will be used by `Vector3` type in Roblox. Improvements of up to 10x can be seen for primitive operations and some of our heavy `Vector3` benchmarks have seen 2-3x improvement. You can read more about this feature [on Roblox Developer forums](https://devforum.roblox.com/t/native-luau-vector3-beta/) -* By optimizing the way string buffers are handled internally, we bring improvements to string operations including `string.lower`, `string.upper`, `string.reverse`, `string.rep`, `table.concat` and string concatenation operator `..`. Biggest improvements can be seen on large strings -* Improved performance of `table.insert` and `table.remove`. Operations in the middle of large arrays can be multiple times faster with this change -* Improved performance of internal table resize which brings additional 30% speedup for `table.insert` -* Improved performance of checks for missing table fields - -## Generic functions - -We had to temporarily disable generic function definitions last month after finding critical issues in the implementation. - -While they are still not available, we are making steady progress on fixing those issues and making additional typechecking improvements to bring them back in. - -## Debugger improvements - -Debugging is now supported for parallel Luau Actors in Roblox Studio. - -Read more about the feature [on Roblox Developer forums](https://devforum.roblox.com/t/parallel-lua-beta/) and try it out yourself. - -## Behavior changes - -Backwards compatibility is important for Luau, but sometimes a change is required to fix corner cases in the language / libraries or to improve performance. Even still, we try to keep impact of these changes to a minimum: - -* __eq tag method will always get called for table comparisons even when a table is compared to itself - -## Coming soon... - -* Better type refinements for statements under a condition using a new constraint resolver. Luau will now understand complex conditions combining `and`/`not` and type guards with more improvements to come diff --git a/docs/_posts/2021-05-31-luau-recap-may-2021.md b/docs/_posts/2021-05-31-luau-recap-may-2021.md deleted file mode 100644 index 1965d2b89..000000000 --- a/docs/_posts/2021-05-31-luau-recap-may-2021.md +++ /dev/null @@ -1,68 +0,0 @@ ---- -layout: single -title: "Luau Recap: May 2021" ---- - -Luau is our new language that you can read more about at [https://roblox.github.io/luau](https://roblox.github.io/luau). This month we have added a new small feature to the language and spent a lot of time improving our typechecker. - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-may-2021/).] - -## Named function type arguments - -We've updated Luau syntax to support optional names of arguments inside function types. -The syntax follows the same format as regular function argument declarations: `(a: number, b: string)` - -Names can be provided in any place where function type is used, for example: - -* in type aliases: -``` -type MyCallbackType = (cost: number, name: string) -> string -``` - -* for variables: -``` -local cb: (amount: number) -> number -local function foo(cb: (name: string) -> ()) -``` - -Variadic arguments cannot have an extra name, they are already written as ...: number. - -These names are used for documentation purposes and we also plan to display them in Roblox Studio auto-complete and type hovers. -They do not affect how the typechecking of Luau scripts is performed. - -## Typechecking improvements - -Speaking of typechecking, we've implemented many improvements this month: -* Typechecker will now visit bodies of all member functions, previously it didn't check methods if the self type was unknown -* Made improvements to cyclic module import detection and error reporting -* Fixed incorrect error on modification of table intersection type fields -* When using an 'or' between a nillable type and a value, the resulting type is now inferred to be non-nil -* We have improved error messages that suggest to use ':' for a method call -* Fixed order of types in type mismatch error that was sometimes reversed -* Fixed an issue with `table.insert` function signature -* Fixed a bug which caused spurious unknown global errors - -We've also added new checks to our linter: -* A new check will report uses of deprecated Roblox APIs -* Linter will now suggest replacing globals with locals in more cases -* New warning is generated if array loop starts or ends on index '0', but the array is indexed from '1' -* FormatString lint will now check string patterns for `find`/`match` calls via `:` when object type is known to be a string - -We also fixed one of the sources for "Free types leaked into this module's public interface" error message and we are working to fix the remaining ones. - -As usual, typechecking improvements will not break execution of your games even if new errors get reported. - -## Editor features - -We continue to improve our built-in support for auto-complete that will be used in future Roblox Studio updates and will make it easier to implement custom extensions for applications that support Language Server Protocol. - -As part of this work we will improve the type information provided by Roblox APIs to match actual arguments and results. - -## Behavior changes - -When a relational comparison fails at runtime, the error message now specifies the comparison direction (e.g. `attempt to compare nil <= number`) - -## Performance improvements - -* Improved performance of table lookup with an index operator and a literal string: `t["name"]` -* Bytecode compilation is now ~5% faster which can improve server startup time for games with lots of scripts diff --git a/docs/_posts/2021-06-30-luau-recap-june-2021.md b/docs/_posts/2021-06-30-luau-recap-june-2021.md deleted file mode 100644 index 51dcb20a5..000000000 --- a/docs/_posts/2021-06-30-luau-recap-june-2021.md +++ /dev/null @@ -1,83 +0,0 @@ ---- -layout: single -title: "Luau Recap: June 2021" ---- - -Luau is our new language that you can read more about at [https://roblox.github.io/luau](https://roblox.github.io/luau). Most of our team was busy working on improving Luau interaction with Roblox Studio for an upcoming feature this month, but we were able to add typechecking and performance improvements as well! - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-june-2021/).] - -## Constraint Resolver - -To improve type inference under conditional expressions and other dynamic type changes (like assignments) we have introduced a new constraint resolver framework into Luau type checker. - -This framework allows us to handle more complex expressions that combine `and`/`not` operators and type guards. - -Type guards support include expressions like: - -* `if instance:IsA("ClassName") then` -* `if enum:IsA("EnumName") then` -* `if type(v) == "string" then` - -This framework is extensible and we have plans for future improvements with `a == b`/`a ~= b` equality constraints and handling of table field assignments. - -It is now also possible to get better type information inside `else` blocks of an `if` statement. - -A few examples to see the constraint resolver in action: -```lua -function say_hello(name: string?) - -- extra parentheses were enough to trip the old typechecker - if (name) then - print("Hello " .. name .. "!") - else - print("Hello mysterious stranger!") - end -end -``` -```lua -function say_hello(name: string?, surname: string?) - -- but now we handle that and more complex expressions as well - if not (name and surname) then - print("Hello mysterious stranger!") - else - print("Hello " .. name .. " " .. surname .. "!") - end -end -``` - -Please note that constraints are currently placed only on local and global variables. -One of our goals is to include support for table members in the future. - -## Typechecking improvements - -We have improved the way we handled module `require` calls. Previously, we had a simple pattern match on the `local m = require(...)` statement, but now we have replaced it with a general handling of the function call in any context. - -Handling of union types in equality operators was fixed to remove incorrect error reports. - -A new `IsA` method was introduced to EnumItem to check the type of a Roblox Enum. -This is intended to replace the `enumItem.EnumType == Enum.NormalId` pattern in the code for a construct that allows our constraint resolver to infer better types. - -Additional fixes include: -* `table.pack` return type was fixed -* A limit was added for deeply nested code blocks to avoid a crash -* We have improved the type names that are presented in error messages and Roblox Studio -* Error recovery was added to field access of a `table?` type. While you add a check for `nil`, typechecking can continue with better type information in other expressions. -* We handled a few internal compiler errors and rare crashes - -## Editor features - -If you have Luau-Powered Type Hover beta feature enabled in Roblox Studio, you will see more function argument names inside function type hovers. - -## Behavior changes - -We no longer allow referencing a function by name inside argument list of that function: - -`local function f(a: number, b: typeof(f)) -- 'f' is no longer visible here` - -## Performance improvements - -As always, we look for ways to improve performance of your scripts: -* We have fixed memory use of Roblox Actor scripts in Parallel Luau beta feature -* Performance of table clone through `table.move` has been greatly improved -* Table length lookup has been optimized, which also brings improvement to table element insertion speed -* Built-in Vector3 type support that we mentioned in [April](https://devforum.roblox.com/t/native-luau-vector3-beta/) is now enabled for everyone diff --git a/docs/_posts/2021-07-30-luau-recap-july-2021.md b/docs/_posts/2021-07-30-luau-recap-july-2021.md deleted file mode 100644 index 622733421..000000000 --- a/docs/_posts/2021-07-30-luau-recap-july-2021.md +++ /dev/null @@ -1,117 +0,0 @@ ---- -layout: single -title: "Luau Recap: July 2021" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). Our team was still busy working on upcoming Studio Beta feature for script editor, but we did fit in multiple typechecking improvements. - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-july-2021/).] - -## Typechecking improvements - -A common complaint that we've received was a false-positive error when table with an optional or union element type is defined: -```lua ---!strict -type Foo = {x: number | string} -local foos: {Foo} = { - {x = 1234567}, - {x = "hello"} -- Type 'string' could not be converted into 'number' -} -``` -This case is now handled and skipping optional fields is allowed as well: -```lua ---!strict -type Foo = { - a: number, - b: number? -} -local foos: {Foo} = { - { a = 1 }, - { a = 2, b = 3 } -- now ok -} -``` -Current fix only handles table element type in assignments, but we plan to extend that to function call arguments and individual table fields. - -Like we've mentioned last time, we will continue working on our new type constraint resolver and this month it learned to handle more complex expressions (including type guards) inside `assert` conditions: -```lua ---!strict -local part = script.Parent:WaitForChild("Part") -assert(part:IsA("BasePart")) -local basepart: BasePart = part -- no longer an error -``` - -And speaking of assertions, we applied a minor fix so that the type of the `assert` function correctly defines a second optional `string?` parameter. - -We have also fixed the type of `string.gmatch` function reported by one of the community members. -We know about issues in a few additional library functions and we'll work to fix them as well. - -Hopefully, you didn't see 'free type leak' errors that underline your whole script, but some of you did and reported them to us. -We read those reports and two additional cases have been fixed this month. -We now track only a single one that should be fixed next month. - -Another false positive error that was fixed involves tables with __call metatable function. -We no longer report a type error when this method is invoked and we'll also make sure that given arguments match the function definition: -```lua ---!strict -local t = { x = 2 } - -local x = setmetatable(t, { - __call = function(self, a: number) - return a * self.x - end -}) -local a = x(2) -- no longer an error -``` -Please note that while call operator on a table is now handled, function types in Luau are distinct from table types and you'll still get an error if you try to assign this table to a variable of a function type. - -## Linter improvements - -A new 'TableOperations' lint check was added that will detect common correctness or performance issues with `table.insert` and `table.remove`: -```lua --- table.insert will insert the value before the last element, which is likely a bug; consider removing the second argument or wrap it in parentheses to silence -table.insert(t, #t, 42) - --- table.insert will append the value to the table; consider removing the second argument for efficiency -table.insert(t, #t + 1, 42) - --- table.insert uses index 0 but arrays are 1-based; did you mean 1 instead? -table.insert(t, 0, 42) - --- table.remove uses index 0 but arrays are 1-based; did you mean 1 instead? -table.remove(t, 0) - --- table.remove will remove the value before the last element, which is likely a bug; consider removing the second argument or wrap it in parentheses to silence -table.remove(t, #t - 1) - --- table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument -table.insert(t, string.find("hello", "h")) -``` - -Another new check is 'DuplicateConditions'. The name speaks for itself, `if` statement chains with duplicate conditions and expressions containing `and`/`or` operations with redundant parts will now be detected: -```lua -if x then - -- ... -elseif not x then - -- ... -elseif x̳ then -- Condition has already been checked on line 1 - -- ... -end - -local success = a and a̳ -- Condition has already been checked on column 17 - -local good = (a or b) or a̳ -- Condition has already been checked on column 15 -``` - -We've also fixed an incorrect lint warning when `typeof` is used to check for `EnumItem`. - -## Editor features - -An issue was fixed that prevented the debugger from displaying values inside Roblox callback functions when an error was reported inside of it. - -## Behavior changes - -`table.insert` will no longer move elements forward 1 spot when index is negative or 0. - -This change also fixed a performance issue when `table.insert` was called with a large negative index. - -The 'TableOperations' lint mentioned earlier will flag cases where insertion at index 0 is performed. diff --git a/docs/_posts/2021-08-31-luau-recap-august-2021.md b/docs/_posts/2021-08-31-luau-recap-august-2021.md deleted file mode 100644 index 1b1213280..000000000 --- a/docs/_posts/2021-08-31-luau-recap-august-2021.md +++ /dev/null @@ -1,116 +0,0 @@ ---- -layout: single -title: "Luau Recap: August 2021" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-august-2021/).] - -## Editor features - -The Roblox Studio [Luau-Powered Autocomplete & Language Features Beta](https://devforum.roblox.com/t/script-editor-luau-powered-autocomplete-language-features-beta) that our team has been working on has finally been released! -Be sure to check that out and leave your feedback for things we can improve. - -To support that feature, a lot of work went into: -* Improving fault-tolerant parser recovery scenarios -* Storing additional information in the AST, including comments, better location information and partial syntax data -* Tracking additional information about types and their fields, including tracking definition locations, function argument names, deprecation state and custom Roblox-specific tags -* Updating reflection information to provide more specific `Instance` types and correct previously missing or wrong type annotations -* Hybrid typechecking mode which tries to infer types even in scripts with no typechecking enabled -* Support for types that are attached to the `DataModel` tree elements to provide instance member information -* Placing limits to finish typechecking in a finite space/time -* Adding Autocomplete API for the Roblox Studio to get location-based entity information and appropriate suggestions -* Additional type inference engine improvements and fixes - -While our work continues to respond to the feedback we receive, our team members are shifting focus to add generic functions, improve type refinements in conditionals, extend Parallel Luau, improve Lua VM performance and provide documentation. - -## Typechecking improvements - -Type constraint resolver now remembers constraints placed on individual table fields. - -This should fix false-positive errors reported after making sure the optional table field is present: -```lua ---!strict -local t: {value: number?} = {value = 2} - -if t.value then - local v: number = t.value -- ok -end -``` - -And it can also refine field type to a more specific one: -```lua ---!strict -local t: {value: string|number} = {value = 2} - -if type(t.value) == "number" then - return t.value * 2 -- ok -end -``` - -Like before, combining multiple conditions using 'and' and 'not' is also supported. - ---- - -Constructing arrays with different values for optional/union types are now also supported for individual table fields and in functions call arguments: -```lua ---!strict -type Foo = {x: number | string, b: number?} - -local function foo(l: {Foo}) end - -foo({ - {x = 1234567}, - {x = "hello"}, -- now ok -}) - -type Bar = {a: {Foo}} - -local foos: Bar = {a = { - {x = 1234567}, - {x = "hello", b = 2}, -- now ok -}} -``` - ---- - -Finally, we have fixed an issue with Roblox class field access using indexing like `part["Anchored"] = true`. - -## Linter improvements - -We have added a new linter check for duplicate local variable definitions. - -It is created to find duplicate names in cases like these: -```lua -local function foo(a1, a2, a2) -- Function argument 'a2' already defined on column 24 -local a1, a2, a2 = f() -- Variable 'a2' already defined on column 11 - -local bar = {} -function bar:test(self) -- Function argument 'self' already defined implicitly -``` - -Our UnknownType linter warning was extended to check for correct class names passed into `FindFirstChildOfClass`, `FindFirstChildWhichIsA`, `FindFirstAncestorOfClass` and `FindFirstAncestorWhichIsA` functions. - -## Performance improvements - -We have added an optimization to 'table.unpack' for 2x performance improvement. - -We've also implemented an extra optimization for tables to predict required table capacity based on fields that are assigned to it in the code after construction. This can reduce the need to reallocate tables. - -Variadic call performance was fine-tuned and is now ~10% faster. - -Construction of array literals was optimized for a ~7% improvement. - -Another optimization this month changes the location and rate of garbage collection invocations. -We now try to avoid calling GC during the script execution and perform all the work in the GcJob part of the frame (it could be seen in the performance profiler). When possible, we can now skip that job in the frame completely, if we have some memory budget available. - -## Other improvements - -For general stability improvements we fixed a crash when strange types like '`nil??`' are used and when users have their own global functions named '`require`'. - -Indexing a table with an incompatible type will now show squiggly error lines under the index instead of the whole expression, which was a bit misleading. - -An issue with debug information that caused `repeat ... until` condition line to be skipped when stepping was fixed. - -Type output was improved to replace display of types like '`{(g405) -> g405}`' with '`{(a) -> a}`'. diff --git a/docs/_posts/2021-09-30-luau-recap-september-2021.md b/docs/_posts/2021-09-30-luau-recap-september-2021.md deleted file mode 100644 index 87793ebe9..000000000 --- a/docs/_posts/2021-09-30-luau-recap-september-2021.md +++ /dev/null @@ -1,126 +0,0 @@ ---- -layout: single -title: "Luau Recap: September 2021" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-september-2021/).] - -## Generic functions - -The big news this month is that generic functions are back! - -Luau has always supported type inference for generic functions, for example: -```lua -type Point = { x: X, y: Y } -function swap(p) - return { x = p.y, y = p.x } -end -local p : Point = swap({ x = "hi", y = 37 }) -local q : Point = swap({ x = "hi", y = true }) -``` -but up until now, there's been no way to write the type of `swap`, since Luau didn't have type parameters to functions (just regular old data parameters). Well, now you can: -```lua -function swap(p : Point): Point - return { x = p.y, y = p.x } -end -``` -Generic functions can be used in function declarations, and function types too, for example -```lua -type Swapper = { swap : (Point) -> Point } -``` - -People may remember that back in -[April](https://devforum.roblox.com/t/luau-recap-april-2021/) we -announced generic functions, but then had to disable them. That was -because [DataBrain](https://devforum.roblox.com/u/databrain) discovered a [nasty -interaction](https://devforum.roblox.com/t/recent-type-system-regressions-for-generic-parametered-functions/) -between `typeof` and generics, which meant that it was possible to -write code that needed nested generic functions, which weren't -supported back then. - -Well, now we do support nested generic functions, so you can write code like -```lua -function mkPoint(x) - return function(y) - return { x = x, y = y } - end -end -``` -and have Luau infer a type where a generic function returns a generic function -```lua -function mkPoint(x : X) : (Y) -> Point - return function(y : Y) : Point - return { x = x, y = y } - end -end -``` -For people who like jargon, Luau now supports *Rank N Types*, where -previously it only supported Rank 1 Types. - -## Bidirectional typechecking - -Up until now, Luau has used *bottom-up* typechecking. For example, for -a function call `f(x)` we first find the type of `f` (say it's -`(T)->U`) and the type for `x` (say it's `V`), make sure that `V` is -a subtype of `T`, so the type of `f(x)` is `U`. - -This works in many cases, but has problems with examples like registering -callback event handlers. In code like -```lua -part.Touched:Connect(function (other) ... end) -``` -if we try to typecheck this bottom-up, we have a problem because -we don't know the type of `other` when we typecheck the body of the function. - -What we want in this case is a mix of bottom-up and *top-down* typechecking. -In this case, from the type of `part.Touched:Connect` we know that `other` must -have type `BasePart`. - -This mix of top-down and bottom-up typechecking is called -*bidirectional typechecking*, and means that tools like type-directed -autocomplete can provide better suggestions. - -## Editor features - -We have made some improvements to the Luau-powered autocomplete beta feature in Roblox Studio: - - * We no longer give autocomplete suggestions for client-only APIs in server-side scripts, - or vice versa. - * For table literals with known shape, we provide autocomplete suggestions for properties. - * We provide autocomplete suggestions for `Player.PlayerGui`. - * Keywords such as `then` and `else` are autocompleted better. - * Autocompletion is disabled inside a comment span (a comment starting `--[[`). - -## Typechecking improvements - -In other typechecking news: - - * The Luau constraint resolver can now refine the operands of equality expressions. - * Luau type guard refinements now support more arbitrary cases, for instance `typeof(foo) ~= "Instance"` eliminates anything not a subclass of `Instance`. - * We fixed some crashes caused by use-after-free during type inference. - * We do a better job of tracking updates when script is moved inside the data model. - * We fixed one of the ways that [recursive types could cause free types to leak](https://devforum.roblox.com/t/free-types-leaked-into-this-modules-public-interface/1459070). - * We improved the way that `return` statements interact with mutually recursive - function declarations. - * We improved parser recovery from code which looks like a function call (but isn't) such as -```lua -local x = y -(expr)[smth] = z -``` - * We consistently report parse errors before type errors. - * We display more types as `*unknown*` rather than as an internal type name like `error####`. - * Luau now infers the result of `Instance:Clone()` much more accurately. - -## Performance improvements - - * `Vector3.new` constructor has been optimized and is now ~2x faster - * A previously implemented optimization for table size prediction has been enhanced to predict final table size when `setmetatable` is used, such as `local self = setmetatable({}, Klass)` - * Method calls for user-specified objects have been optimized and are now 2-4% faster - * `debug.traceback` is now 1.5x faster, although `debug.info` is likely still superior for performance-conscious code - * Creating table literals with explicit numeric indices, such as `{ [1] = 42 }`, is now noticeably faster, although list-style construction is still recommended. - -## Other improvements - - * The existing 'TableLiteral' lint now flags cases when table literals have duplicate numeric indices, such as `{ [1] = 1, [1] = 2 }` diff --git a/docs/_posts/2021-10-31-luau-recap-october-2021.md b/docs/_posts/2021-10-31-luau-recap-october-2021.md deleted file mode 100644 index 2e0564246..000000000 --- a/docs/_posts/2021-10-31-luau-recap-october-2021.md +++ /dev/null @@ -1,151 +0,0 @@ ---- -layout: single -title: "Luau Recap: October 2021" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-october-2021/).] - -## if-then-else expression - -In addition to supporting standard if *statements*, Luau adds support for if *expressions*. -Syntactically, `if-then-else` expressions look very similar to if statements. -However instead of conditionally executing blocks of code, if expressions conditionally evaluate expressions and return the value produced as a result. -Also, unlike if statements, if expressions do not terminate with the `end` keyword. - -Here is a simple example of an `if-then-else` expression: -```lua -local maxValue = if a > b then a else b -``` - -`if-then-else` expressions may occur in any place a regular expression is used. -The `if-then-else` expression must match `if then else `; -it can also contain an arbitrary number of `elseif` clauses, like `if then elseif then else `. -Note that in either case, `else` is mandatory. - -Here's is an example demonstrating `elseif`: -```lua -local sign = if x < 0 then -1 elseif x > 0 then 1 else 0 -``` - -**Note:** In Luau, the `if-then-else` expression is preferred vs the standard Lua idiom of writing `a and b or c` (which roughly simulates a ternary operator). However, the Lua idiom may return an unexpected result if `b` evaluates to false. -The `if-then-else` expression will behave as expected in all situations. - -## Library improvements - -New additions to the `table` library have arrived: - -```lua -function table.freeze(t) -``` - -Given a non-frozen table, freezes it such that all subsequent attempts to modify the table or assign its metatable raise an error. -If the input table is already frozen or has a protected metatable, the function raises an error; otherwise it returns the input table. -Note that the table is frozen in-place and is not being copied. -Additionally, only `t` is frozen, and keys/values/metatable of `t` don't change their state and need to be frozen separately if desired. - -```lua -function table.isfrozen(t): boolean -``` - -Returns `true` if and only if the input table is frozen. - -## Typechecking improvements - -We continue work on our type constraint resolver and have multiple improvements this month. - -We now resolve constraints that are created by `or` expressions. -In the following example, by checking against multiple type alternatives, we learn that value is a union of those types: -```lua ---!strict -local function f(x: any) - if type(x) == "number" or type(x) == "string" then - local foo = x -- 'foo' type is known to be 'number | string' here - -- ... - end -end -``` - -Support for `or` constraints allowed us to handle additional scenarios with `and` and `not` expressions to reduce false positives after specific type guards. - -And speaking of type guards, we now correctly handle sub-class relationships in those checks: -```lua ---!strict -local function f(x: Part | Folder | string) - if typeof(x) == "Instance" then - local foo = x -- 'foo' type is known to be 'Part | Folder' here - else - local bar = x -- 'bar' type is known to be 'string' here - end -end -``` - -One more fix handles the `a and b or c` expression when 'b' depends on 'a': -```lua ---!strict -function f(t: {x: number}?) - local a = t and t.x or 5 -- 'a' is a 'number', no false positive errors here -end -``` - -Of course, our new if-then-else expressions handle this case as well. -```lua ---!strict -function f(t: {x: number}?) - local a = if t then t.x else 5 -- 'a' is a 'number', no false positive errors here -end -``` - ---- -We have extended bidirectional typechecking that was announced last month to propagate types in additional statements and expressions. -```lua ---!strict -function getSortFunction(): (number, number) -> boolean - return function(a, b) return a > b end -- a and b are now known to be 'number' here -end - -local comp = getSortFunction() - -comp = function(a, b) return a < b end -- a and b are now known to be 'number' here as well -``` - ---- -We've also improved some of our messages with union types and optional types (unions types with `nil`). - -When optional types are used incorrectly, you get better messages. For example: -```lua ---!strict -function f(a: {number}?) - return a[1] -- "Value of type '{number}?' could be nil" instead of "'{number}?' is not a table' -end -``` - -When a property of a union type is accessed, but is missing from some of the options, we will report which options are not valid: -```lua ---!strict -type A = { x: number, y: number } -type B = { x: number } -local a: A | B -local b = a.y -- Key 'y' is missing from 'B' in the type 'A | B' -``` - ---- -When we enabled generic functions last month, some users might have seen a strange error about generic functions not being compatible with regular ones. - -This was caused by undefined behaviour of recursive types. -We have now added a restriction on how generic type parameters can be used in recursive types: [RFC: Recursive type restriction](https://github.com/Roblox/luau/blob/master/rfcs/recursive-type-restriction.md) - -## Performance improvements - -An improvement to the Stop-The-World (atomic in Lua terms) stage of the garbage collector was made to reduce time taken by that step by 4x factor. -While this step only happens once during a GC cycle, it cannot be split into small parts and long times were visible as frame time spikes. - -Table construction and resize was optimized further; as a result, many instances of table construction see 10-20% improvements -for smaller tables on all platforms and 20%+ improvements on Windows. - -Bytecode compiler has been optimized for giant table literals, resulting in 3x higher compilation throughput for certain files on AMD Zen architecture. - -Coroutine resumption has been optimized and is now ~10% faster for coroutine-heavy code. - -Array reads and writes are also now a bit faster resulting in 1-3% lift in array-heavy benchmarks. diff --git a/docs/_posts/2021-11-03-luau-goes-open-source.md b/docs/_posts/2021-11-03-luau-goes-open-source.md deleted file mode 100644 index f85e071e4..000000000 --- a/docs/_posts/2021-11-03-luau-goes-open-source.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -layout: single -title: "Luau Goes Open-Source" ---- - -When Roblox was created 15 years ago, we chose Lua as the scripting language. Lua was small, fast, easy to embed and learn and opened up enormous possibilities for our developers. - -A lot in Roblox was built on Lua including hundreds of thousands of lines of internally-developed code that powers Roblox App and Roblox Studio to this day, and the millions of experiences that developers have created. For many of them, it was the first programming language they’ve learned. - -A few years ago, we started looking into how we can evolve Lua to be even faster, have better ergonomics, make it easier to write robust code and to unlock an ecosystem of rich tooling—from better static analysis to IDE integrations. - -This is how Luau was born. - -Luau is a new language that started from Lua 5.1 and kept evolving while keeping backwards compatibility and preserving the original design goals: simplicity, performance, embeddability. - -We’re incredibly grateful for the foundation that Lua has been—it’s been a joy to build on top of! So now we want to give back to the community at large. - -Starting today, [Luau](https://luau-lang.org) is no longer an inseparable part of Roblox platform; it’s a separate, open-source language. - -Luau is available at [https://github.com/Roblox/luau](https://github.com/Roblox/luau) and comes with the source code for the language runtime and all associated tooling: compiler, type checker, linter. The code is available to anyone, free of charge, under the terms of MIT License. We’re happy to accept contributions to the language, whether that’s documentation or source code. - -The language evolution is driven by an RFC process that is also open to the public. - -We are committed to improving Luau going forward—it remains a central piece of technology at Roblox. The team that works on the language keeps growing, and we have lots of ideas! The language will become even faster, even nicer to work with, even more powerful. - -We can’t wait to see what we can build, together. diff --git a/docs/_posts/2021-11-29-luau-recap-november-2021.md b/docs/_posts/2021-11-29-luau-recap-november-2021.md deleted file mode 100644 index 381b2455f..000000000 --- a/docs/_posts/2021-11-29-luau-recap-november-2021.md +++ /dev/null @@ -1,77 +0,0 @@ ---- -layout: single -title: "Luau Recap: November 2021" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-november-2021/).] - -## Type packs in type aliases - -Type packs are the construct Luau uses to represent a sequence of types. We've had syntax for generic type packs for a while now, and it sees use in generic functions, but it hasn't been available in type aliases. That has changed, and it is now syntactically legal to write the following type alias: -```lua -type X = () -> A... -type Y = X -``` - -We've also added support for explicit type packs. Previously, it was impossible to instantiate a generic with two or more type pack parameters, because it wasn't clear where the first pack ended and the second one began. We have introduced a new syntax for this use case: -``` -type Fn = (P...) -> R... -type X = Fn<(number, string), (string, number)> -``` - -For more information, check out [the documentation](https://luau-lang.org/typecheck#type-packs) or [the RFC](https://github.com/Roblox/luau/blob/f86d4c6995418e489a55be0100159009492778ff/rfcs/syntax-type-alias-type-packs.md) for this feature. - -## Luau is open-source! - -We announced this in early November but it deserves repeating: Luau is now an open-source project! You can use Luau outside of Roblox, subject to MIT License, and - importantly - we accept contributions. - -Many changes contributed by community, both Roblox and external, have been merged since we've made Luau open source. Of note are two visible changes that shipped on Roblox platform: - -- The type error "Expected to return X values, but Y values are returned here" actually had X and Y swapped! This is now fixed. -- Luau compiler dutifully computed the length of the string when using `#` operator on a string literal; this is now fixed and `#"foo"` compiles to 3. - -You might think that C++ is a scary language and you can't contribute to Luau. If so, you'd be happy to know that the contents of https://luau-lang.org, where we host our documentation, is also hosted on GitHub in the same repository (https://github.com/Roblox/luau/tree/master/docs) and that we'd love the community to contribute improvements to documentation among other changes! For example see [issues in this list](https://github.com/Roblox/luau/issues?q=is%3Aissue+is%3Aopen+label%3A%22pr+welcome%22) that start with "Documentation", but all other changes and additions to documentation are also welcome. - -## Library improvements - -```lua -function bit32.countlz(n: number): number -function bit32.countrz(n: number): number -``` -Given a number, returns the number of preceding left or trailing right-hand bits that are `0`. - -See [the RFC for these functions](https://github.com/Roblox/luau/blob/f86d4c6995418e489a55be0100159009492778ff/rfcs/function-bit32-countlz-countrz.md) for more information. - -## Type checking improvements - -We have enabled a rewrite of how Luau handles `require` tracing. This has two main effects: firstly, in strict mode, `require` statements that Luau can't resolve will trigger type errors; secondly, Luau now understands the `FindFirstAncestor` method in `require` expressions. - -Luau now warns when the index to `table.move` is 0, as this is non-idiomatic and performs poorly. If this behavior is intentional, wrap the index in parentheses to suppress the warning. - -Luau now provides additional context in table and class type mismatch errors. - -## Performance improvements - -We have enabled several changes that aim to avoid allocating a new closure object in cases where it's not necessary to. This is helpful in cases where many closures are being allocated; in our benchmark suite, the two benchmarks that allocate a large number of closures improved by 15% and 5%, respectively. - -When checking union types, we now try possibilities whose synthetic names match. This will speed up type checking unions in cases where synthetic names are populated. - -We have also enabled an optimization that shares state in a hot path on the type checker. This will improve type checking performance. - -The Luau VM now attempts to cache the length of tables' array portion. This change showed a small performance improvement in benchmarks, and should speed up `#` expressions. - -The Luau type checker now caches a specific category of table unification results. This can improve type checking performance significantly when the same set of types is used frequently. - -When Luau is not retaining type graphs, the type checker now discards more of a module's type surface after type checking it. This improves memory usage significantly. - -## Bug fixes - -We've fixed a bug where on ARM systems (mobile), packing negative numbers using unsigned formats in `string.pack` would produce the wrong result. - -We've fixed an issue with type aliases that reuse generic type names that caused them to be instantiated incorrectly. - -We've corrected a subtle bug that could cause free types to leak into a table type when a free table is bound to that table. - -We've fixed an issue that could cause Luau to report an infinitely recursive type error when the type was not infinitely recursive. diff --git a/docs/_posts/2022-01-27-luau-recap-january-2022.md b/docs/_posts/2022-01-27-luau-recap-january-2022.md deleted file mode 100644 index 5ee100226..000000000 --- a/docs/_posts/2022-01-27-luau-recap-january-2022.md +++ /dev/null @@ -1,122 +0,0 @@ ---- -layout: single -title: "Luau Recap: January 2022" ---- - -Luau is our programming language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Find us on GitHub](https://github.com/Roblox/luau)! - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-january-2022/).] - -## Performance improvements - -The implementation of `tostring` has been rewritten. This change replaces the default number->string conversion with a -new algorithm called Schubfach, which allows us to produce the shortest precise round-trippable representation of any -input number very quickly. - -While performance is not the main driving factor, this also happens to be significantly faster than our old -implementation (up to 10x depending on the number and the platform). - ---- - -Make `tonumber(x)` ~2x faster by avoiding reparsing string arguments. - ---- - -The Luau compiler now optimizes table literals where keys are constant variables the same way as if they were constants, eg - -```lua -local r, g, b = 1, 2, 3 -local col = { [r] = 255, [g] = 0, [b] = 255 } -``` - -## Improvements to type assertions - -The `::` type assertion operator can now be used to coerce a value between any two related types. Previously, it could -only be used for downcasts or casts to `any`. The following used to be invalid, but is now valid: - -```lua -local t = {x=0, y=0} -local a = t :: {x: number} -``` - -## Typechecking improvements - -An issue surrounding table literals and indexers has been fixed: - -```lua -type RecolorMap = {[string]: RecolorMap | Color3} - -local hatRecolorMap: RecolorMap = { - Brim = Color3.fromRGB(255, 0, 0), -- We used to report an error here - Top = Color3.fromRGB(255, 0, 0) -} -``` - ---- -Accessing a property whose base expression was previously refined will now return the correct result. - -## Linter improvements - -`table.create(N, {})` will now produce a static analysis warning since the element is going to be shared for all table entries. - -## Error reporting improvements - -When a type error involves a union (or an option), we now provide more context in the error message. - -For instance, given the following code: - -```lua ---!strict - -type T = {x: number} - -local x: T? = {w=4} -``` - -We now report the following: - -``` -Type 'x' could not be converted into 'T?' -caused by: - None of the union options are compatible. For example: Table type 'x' not compatible with type 'T' because the former is missing field 'x' -``` - ---- -Luau now gives up and reports an `*unknown*` type in far fewer cases when typechecking programs that have type errors. - -## New APIs - -We have brought in the [`coroutine.close`](https://luau-lang.org/library#coroutine-library) function from Lua 5.4. It accepts a suspended coroutine and marks it as non-runnable. In Roblox, this can be useful in combination with `task.defer` to implement cancellation. - -## REPL improvements - -The `luau` REPL application can be compiled from source or downloaded from [releases page](https://github.com/Roblox/luau/releases). It has grown some new features: - -* Added `--interactive` option to run the REPL after running the last script file. -* Allowed the compiler optimization level to be specified. -* Allowed methods to be tab completed -* Allowed methods on string instances to be completed -* Improved Luau REPL argument parsing and error reporting -* Input history is now saved/loaded - -## Thanks - -A special thanks to all the fine folks who contributed PRs over the last few months! - -* [Halalaluyafail3](https://github.com/Halalaluyafail3) -* [JohnnyMorganz](https://github.com/JohnnyMorganz) -* [Kampfkarren](https://github.com/Kampfkarren) -* [kunitoki](https://github.com/kunitoki) -* [MathematicalDessert](https://github.com/MathematicalDessert) -* [metatablecat](https://github.com/metatablecat) -* [petrihakkinen](https://github.com/petrihakkinen) -* [rafa_br34](https://github.com/rafa_br34) -* [Rerumu](https://github.com/Rerumu) -* [Slappy826](https://github.com/Slappy826) -* [SnowyShiro](https://github.com/SnowyShiro) -* [vladmarica](https://github.com/vladmarica) -* [xgladius](https://github.com/xgladius) - -[Contribution guide](https://github.com/Roblox/luau/blob/2f989fc049772f36de1a4281834c375858507bda/CONTRIBUTING.md) diff --git a/docs/_posts/2022-02-28-luau-recap-february-2022.md b/docs/_posts/2022-02-28-luau-recap-february-2022.md deleted file mode 100644 index 412217dfb..000000000 --- a/docs/_posts/2022-02-28-luau-recap-february-2022.md +++ /dev/null @@ -1,164 +0,0 @@ ---- -layout: single -title: "Luau Recap: February 2022" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-february-2022/).] - -## Default type alias type parameters - -We have introduced a syntax to provide default type arguments inside the type alias type parameter list. - -It is now possible to have type functions where the instantiation can omit some type arguments. - -You can provide concrete types: - -```lua ---!strict -type FieldResolver = (T, Data) -> number - -local a: FieldResolver = ... -local b: FieldResolver = ... -``` - -Or reference parameters defined earlier in the list: - -```lua ---!strict -type EqComp = (l: T, r: U) -> boolean - -local a: EqComp = ... -- (l: number, r: number) -> boolean -local b: EqComp = ... -- (l: number, r: string) -> boolean -``` - -Type pack parameters can also have a default type pack: - -```lua ---!strict -type Process = (T) -> U... - -local a: Process = ... -- (number) -> ...string -local b: Process = ... -- (number) -> (boolean, string) -``` - -If all type parameters have a default type, it is now possible to reference that without providing any type arguments: - -```lua ---!strict -type All = (T) -> U - -local a: All -- ok -local b: All<> -- ok as well -``` - -For more details, you can read the original [RFC proposal](https://github.com/Roblox/luau/blob/master/rfcs/syntax-default-type-alias-type-parameters.md). - -## Typechecking improvements - -This month we had many fixes to improve our type inference and reduce false positive errors. - -if-then-else expression can now have different types in each branch: - -```lua ---!strict -local a = if x then 5 else nil -- 'a' will have type 'number?' -local b = if x then 1 else '2' -- 'b' will have type 'number | string' -``` - -And if the expected result type is known, you will not get an error in cases like these: - -```lua ---!strict -type T = {number | string} --- different array element types don't give an error if that is expected -local c: T = if x then {1, "x", 2, "y"} else {0} -``` - ---- - -`assert` result is now known to not be 'falsy' (`false` or `nil`): - -```lua ---!strict -local function f(x: number?): number - return assert(x) -- no longer an error -end -``` - ---- - -We fixed cases where length operator `#` reported an error when used on a compatible type: - -```lua ---!strict -local union: {number} | {string} -local a = #union -- no longer an error -``` - ---- - -Functions with different variadic argument/return types are no longer compatible: - -```lua ---!strict -local function f(): (number, ...string) - return 2, "a", "b" -end - -local g: () -> (number, ...boolean) = f -- error -``` - ---- - -We have also fixed: - -* false positive errors caused by incorrect reuse of generic types across different function declarations -* issues with forward-declared intersection types -* wrong return type annotation for table.move -* various crashes reported by developers - -## Linter improvements - -A new static analysis warning was introduced to mark incorrect use of a '`a and b or c`' pattern. When 'b' is 'falsy' (`false` or `nil`), result will always be 'c', even if the expression 'a' was true: - -```lua -local function f(x: number) - -- The and-or expression always evaluates to the second alternative because the first alternative is false; consider using if-then-else expression instead - return x < 0.5 and false or 42 -end -``` - -Like we say in the warning, new if-then-else expression doesn't have this pitfall: - -```lua -local function g(x: number) - return if x < 0.5 then false else 42 -end -``` - ---- - -We have also introduced a check for misspelled comment directives: - -```lua ---!non-strict --- ^ Unknown comment directive 'non-strict'; did you mean 'nonstrict'? -``` - -## Performance improvements - -For performance, we have changed how our Garbage Collector collects unreachable memory. -This rework makes it possible to free memory 2.5x faster and also comes with a small change to how we store Luau objects in memory. For example, each table now uses 16 fewer bytes on 64-bit platforms. - -Another optimization was made for `select(_, ...)` call. -It is now using a special fast path that has constant-time complexity in number of arguments (~3x faster with 10 arguments). - -## Thanks - -A special thanks to all the fine folks who contributed PRs this month! - -* [mikejsavage](https://github.com/mikejsavage) -* [TheGreatSageEqualToHeaven](https://github.com/TheGreatSageEqualToHeaven) -* [petrihakkinen](https://github.com/petrihakkinen) diff --git a/docs/_posts/2022-03-31-luau-recap-march-2022.md b/docs/_posts/2022-03-31-luau-recap-march-2022.md deleted file mode 100644 index ff3a4d0fc..000000000 --- a/docs/_posts/2022-03-31-luau-recap-march-2022.md +++ /dev/null @@ -1,109 +0,0 @@ ---- -layout: single -title: "Luau Recap: March 2022" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-march-2022/).] - -## Singleton types - -We added support for singleton types! These allow you to use string or -boolean literals in types. These types are only inhabited by the -literal, for example if a variable `x` has type `"foo"`, then `x == -"foo"` is guaranteed to be true. - -Singleton types are particularly useful when combined with union types, -for example: - -```lua -type Animals = "Dog" | "Cat" | "Bird" -``` - -or: - -```lua -type Falsey = false | nil -``` - -In particular, singleton types play well with unions of tables, -allowing tagged unions (also known as discriminated unions): - -```lua -type Ok = { type: "ok", value: T } -type Err = { type: "error", error: E } -type Result = Ok | Err - -local result: Result = ... -if result.type == "ok" then - -- result :: Ok - print(result.value) -elseif result.type == "error" then - -- result :: Err - error(result.error) -end -``` - -The RFC for singleton types is https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md - -## Width subtyping - -A common idiom for programming with tables is to provide a public interface type, but to keep some of the concrete implementation private, for example: - -```lua -type Interface = { - name: string, -} - -type Concrete = { - name: string, - id: number, -} -``` - -Within a module, a developer might use the concrete type, but export functions using the interface type: - -```lua -local x: Concrete = { - name = "foo", - id = 123, -} - -local function get(): Interface - return x -end -``` - -Previously examples like this did not typecheck but now they do! - -This language feature is called *width subtyping* (it allows tables to get *wider*, that is to have more properties). - -The RFC for width subtyping is https://github.com/Roblox/luau/blob/master/rfcs/sealed-table-subtyping.md - -## Typechecking improvements - - * Generic function type inference now works the same for generic types and generic type packs. - * We improved some error messages. - * There are now fewer crashes (hopefully none!) due to mutating types inside the Luau typechecker. - * We fixed a bug that could cause two incompatible copies of the same class to be created. - * Luau now copes better with cyclic metatable types (it gives a type error rather than hanging). - * Fixed a case where types are not properly bound to all of the subtype when the subtype is a union. - * We fixed a bug that confused union and intersection types of table properties. - * Functions declared as `function f(x : any)` can now be called as `f()` without a type error. - -## API improvements - - * Implement `table.clone` which takes a table and returns a new table that has the same keys/values/metatable. The cloning is shallow - if some keys refer to tables that need to be cloned, that can be done manually by modifying the resulting table. - -## Debugger improvements - - * Use the property name as the name of methods in the debugger. - -## Performance improvements - - * Optimize table rehashing (~15% faster dictionary table resize on average) - * Improve performance of freeing tables (~5% lift on some GC benchmarks) - * Improve gathering performance metrics for GC. - * Reduce stack memory reallocation. - diff --git a/docs/_posts/2022-05-02-luau-recap-april-2022.md b/docs/_posts/2022-05-02-luau-recap-april-2022.md deleted file mode 100644 index dd6b2c0c5..000000000 --- a/docs/_posts/2022-05-02-luau-recap-april-2022.md +++ /dev/null @@ -1,51 +0,0 @@ ---- -layout: single -title: "Luau Recap: April 2022" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-april-2022/).] - -It's been a bit of a quiet month. We mostly have small optimizations and bugfixes for you. - -It is now allowed to define functions on sealed tables that have string indexers. These functions will be typechecked against the indexer type. For example, the following is now valid: - -```lua -local a : {[string]: () -> number} = {} - -function a.y() return 4 end -- OK -``` - -Autocomplete will now provide string literal suggestions for singleton types. eg - -```lua -local function f(x: "a" | "b") end -f("_") -- suggest "a" and "b" -``` - -Improve error recovery in the case where we encounter a type pack variable in a place where one is not allowed. eg `type Foo = { value: A... }` - -When code does not pass enough arguments to a variadic function, the error feedback is now better. - -For example, the following script now produces a much nicer error message: -```lua -type A = { [number]: number } -type B = { [number]: string } - -local a: A = { 1, 2, 3 } - --- ERROR: Type 'A' could not be converted into 'B' --- caused by: --- Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' -local b: B = a -``` - -If the following code were to error because `Hello` was undefined, we would erroneously include the comment in the span of the error. This is now fixed. -```lua -type Foo = Hello -- some comment over here -``` - -Fix a crash that could occur when strict scripts have cyclic require() dependencies. - -Add an option to autocomplete to cause it to abort processing after a certain amount of time has elapsed. diff --git a/docs/_posts/2022-06-01-luau-recap-may-2022.md b/docs/_posts/2022-06-01-luau-recap-may-2022.md deleted file mode 100644 index 500e6e4af..000000000 --- a/docs/_posts/2022-06-01-luau-recap-may-2022.md +++ /dev/null @@ -1,97 +0,0 @@ ---- -layout: single -title: "Luau Recap: May 2022" ---- - -This month Luau team has worked to bring you a new language feature together with more typechecking improvements and bugfixes! - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-may-2022/).] - -## Generalized iteration - -We have extended the semantics of standard Lua syntax for iterating through containers, `for vars in values` with support for generalized iteration. -In Lua, to iterate over a table you need to use an iterator like `next` or a function that returns one like `pairs` or `ipairs`. In Luau, you can now simply iterate over a table: - -```lua -for k, v in {1, 4, 9} do - assert(k * k == v) -end -``` - -This works for tables but can also be customized for tables or userdata by implementing `__iter` metamethod. It is called before the iteration begins, and should return an iterator function like `next` (or a custom one): - -```lua -local obj = { items = {1, 4, 9} } -setmetatable(obj, { __iter = function(o) return next, o.items end }) - -for k, v in obj do - assert(k * k == v) -end -``` - -The default iteration order for tables is specified to be consecutive for elements `1..#t` and unordered after that, visiting every element. -Similar to iteration using `pairs`, modifying the table entries for keys other than the current one results in unspecified behavior. - -## Typechecking improvements - -We have added a missing check to compare implicit table keys against the key type of the table indexer: - -```lua --- error is correctly reported, implicit keys (1,2,3) are not compatible with [string] -local t: { [string]: boolean } = { true, true, false } -``` - -Rules for `==` and `~=` have been relaxed for union types, if any of the union parts can be compared, operation succeeds: - -```lua ---!strict -local function compare(v1: Vector3, v2: Vector3?) - return v1 == v2 -- no longer an error -end -``` - -Table value type propagation now correctly works with `[any]` key type: - -```lua ---!strict -type X = {[any]: string | boolean} -local x: X = { key = "str" } -- no longer gives an incorrect error -``` - -If a generic function doesn't provide type annotations for all arguments and the return value, additional generic type parameters might be added automatically: - -```lua --- previously it was foo, now it's foo, because second argument is also generic -function foo(x: T, y) end -``` - -We have also fixed various issues that have caused crashes, with many of them coming from your bug reports. - -## Linter improvements - -`GlobalUsedAsLocal` lint warning has been extended to notice when global variable writes always happen before their use in a local scope, suggesting that they can be replaced with a local variable: - -```lua -function bar() - foo = 6 -- Global 'foo' is never read before being written. Consider changing it to local - return foo -end -function baz() - foo = 10 - return foo -end -``` - -## Performance improvements - -Garbage collection CPU utilization has been tuned to further reduce frame time spikes of individual collection steps and to bring different GC stages to the same level of CPU utilization. - -Returning a type-cast local (`return a :: type`) as well as returning multiple local variables (`return a, b, c`) is now a little bit more efficient. - -### Function inlining and loop unrolling - -In the open-source release of Luau, when optimization level 2 is enabled, the compiler will now perform function inlining and loop unrolling. - -Only loops with loop bounds known at compile time, such as `for i=1,4 do`, can be unrolled. The loop body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if unrolling should be performed. - -Only local functions (defined either as `local function foo` or `local foo = function`) can be inlined. The function body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if each call to the function should be inlined instead. Additionally recursive invocations of a function can't be inlined at this time, and inlining is completely disabled for modules that use `getfenv`/`setfenv` functions. diff --git a/docs/_posts/2022-07-07-luau-recap-june-2022.md b/docs/_posts/2022-07-07-luau-recap-june-2022.md deleted file mode 100644 index 1f58d8920..000000000 --- a/docs/_posts/2022-07-07-luau-recap-june-2022.md +++ /dev/null @@ -1,88 +0,0 @@ ---- -layout: single -title: "Luau Recap: June 2022" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-june-2022/).] - -# Lower bounds calculation - -A common problem that Luau has is that it primarily works by inspecting expressions in your program and narrowing the _upper bounds_ of the values that can inhabit particular variables. In other words, each time we see a variable used, we eliminate possible sets of values from that variable's domain. - -There are some important cases where this doesn't produce a helpful result. Take this function for instance: - -```lua -function find_first_if(vec, f) - for i, e in ipairs(vec) do - if f(e) then - return i - end - end - - return nil -end -``` - -Luau scans the function from top to bottom and first sees the line `return i`. It draws from this the inference that `find_first_if` must return the type of `i`, namely `number`. - -This is fine, but things go sour when we see the line `return nil`. Since we are always narrowing, we take from this line the judgement that the return type of the function is `nil`. Since we have already concluded that the function must return `number`, Luau reports an error. - -What we actually want to do in this case is to take these `return` statements as inferences about the _lower_ bound of the function's return type. Instead of saying "this function must return values of type `nil`," we should instead say "this function may _also_ return values of type `nil`." - -Lower bounds calculation does precisely this. Moving forward, Luau will instead infer the type `number?` for the above function. - -This does have one unfortunate consequence: If a function has no return type annotation, we will no longer ever report a type error on a `return` statement. We think this is the right balance but we'll be keeping an eye on things just to be sure. - -Lower-bounds calculation is larger and a little bit riskier than other things we've been working on so we've set up a beta feature in Roblox Studio to enable them. It is called "Experimental Luau language features." - -Please try it out and let us know what you think! - -## Known bug - -We have a known bug with certain kinds of cyclic types when lower-bounds calculation is enabled. The following, for instance, is known to be problematic. - -```lua -type T = {T?}? -- spuriously reduces to {nil}? -``` - -We hope to have this fixed soon. - -# All table literals now result in unsealed tables - -Previously, the only way to create a sealed table was by with a literal empty table. We have relaxed this somewhat: Any table created by a `{}` expression is considered to be unsealed within the scope where it was created: - -```lua -local T = {} -T.x = 5 -- OK - -local V = {x=5} -V.y = 2 -- previously disallowed. Now OK. - -function mkTable() - return {x = 5} -end - -local U = mkTable() -U.y = 2 -- Still disallowed: U is sealed -``` - -# Other fixes - -* Adjust indentation and whitespace when creating multiline string representations of types, resulting in types that are easier to read. -* Some small bugfixes to autocomplete -* Fix a case where accessing a nonexistent property of a table would not result in an error being reported. -* Improve parser recovery for the incorrect code `function foo() -> ReturnType` (the correct syntax is `function foo(): ReturnType`) -* Improve the parse error offered for code that improperly uses the `function` keyword to start a type eg `type T = function` -* Some small crash fixes and performance improvements - -# Thanks! - -A very special thanks to all of our open source contributors: - -* [Allan N Jeremy](https://github.com/AllanJeremy) -* [Daniel Nachun](https://github.com/danielnachun) -* [JohnnyMorganz](https://github.com/JohnnyMorganz/) -* [Petri Häkkinen](https://github.com/petrihakkinen) -* [Qualadore](https://github.com/Qualadore) diff --git a/docs/_posts/2022-08-29-luau-recap-august-2022.md b/docs/_posts/2022-08-29-luau-recap-august-2022.md deleted file mode 100644 index e43f0f34e..000000000 --- a/docs/_posts/2022-08-29-luau-recap-august-2022.md +++ /dev/null @@ -1,187 +0,0 @@ ---- -layout: single -title: "Luau Recap: July & August 2022" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-july-august-2022/).] - -## Tables now support `__len` metamethod - -See the RFC [Support `__len` metamethod for tables and `rawlen` function](https://github.com/Roblox/luau/blob/master/rfcs/len-metamethod-rawlen.md) for more details. - -With generalized iteration released in May, custom containers are easier than ever to use. The only thing missing was the fact that tables didn't respect `__len`. - -Simply, tables now honor the `__len` metamethod, and `rawlen` is also added with similar semantics as `rawget` and `rawset`: - -```lua -local my_cool_container = setmetatable({ items = { 1, 2 } }, { - __len = function(self) return #self.items end -}) - -print(#my_cool_container) --> 2 -print(rawlen(my_cool_container)) --> 0 -``` - -## `never` and `unknown` types - -See the RFC [`never` and `unknown` types](https://github.com/Roblox/luau/blob/master/rfcs/never-and-unknown-types.md) for more details. - -We've added two new types, `never` and `unknown`. These two types are the opposites of each other by the fact that there's no value that inhabits the type `never`, and the dual of that is every value inhabits the type `unknown`. - -Type inference may infer a variable to have the type `never` if and only if the set of possible types becomes empty, for example through type refinements. - -```lua -function f(x: string | number) - if typeof(x) == "string" and typeof(x) == "number" then - -- x: never - end -end -``` - -This is useful because we still needed to ascribe a type to `x` here, but the type we used previously had unsound semantics. For example, it was possible to be able to _expand_ the domain of a variable once the user had proved it impossible. With `never`, narrowing a type from `never` yields `never`. - -Conversely, `unknown` can be used to enforce a stronger contract than `any`. That is, `unknown` and `any` are similar in terms of allowing every type to inhabit them, and other than `unknown` or `any`, `any` allows itself to inhabit into a different type, whereas `unknown` does not. - -```lua -function any(): any return 5 end -function unknown(): unknown return 5 end - --- no type error, but assigns a number to x which expects string -local x: string = any() - --- has type error, unknown cannot be converted into string -local y: string = unknown() -``` - -To be able to do this soundly, you must apply type refinements on a variable of type `unknown`. - -```lua -local u = unknown() - -if typeof(u) == "string" then - local y: string = u -- no type error -end -``` - -A use case of `unknown` is to enforce type safety at implementation sites for data that do not originate in code, but from over the wire. - -## Argument names in type packs when instantiating a type - -We had a bug in the parser which erroneously allowed argument names in type packs that didn't fold into a function type. That is, the below syntax did not generate a parse error when it should have. - -```lua -Foo<(a: number, b: string)> -``` - -## New IntegerParsing lint - -See [the announcement](https://devforum.roblox.com/t/improving-binary-and-hexadecimal-integer-literal-parsing-rules-in-luau/) for more details. We include this here for posterity. - -We've introduced a new lint called IntegerParsing. Right now, it lints three classes of errors: - -1. Truncation of binary literals that resolves to a value over 64 bits, -2. Truncation of hexadecimal literals that resolves to a value over 64 bits, and -3. Double hexadecimal prefix. - -For 1.) and 2.), they are currently not planned to become a parse error, so action is not strictly required here. - -For 3.), this will be a breaking change! See [the rollout plan](https://devforum.roblox.com/t/improving-binary-and-hexadecimal-integer-literal-parsing-rules-in-luau/#rollout-5) for details. - -## New ComparisonPrecedence lint - -We've also introduced a new lint called `ComparisonPrecedence`. It fires in two particular cases: - -1. `not X op Y` where `op` is `==` or `~=`, or -2. `X op Y op Z` where `op` is any of the comparison or equality operators. - -In languages that uses `!` to negate the boolean i.e. `!x == y` looks fine because `!x` _visually_ binds more tightly than Lua's equivalent, `not x`. Unfortunately, the precedences here are identical, that is `!x == y` is `(!x) == y` in the same way that `not x == y` is `(not x) == y`. We also apply this on other operators e.g. `x <= y == y`. - -```lua --- not X == Y is equivalent to (not X) == Y; consider using X ~= Y, or wrap one of the expressions in parentheses to silence -if not x == y then end - --- not X ~= Y is equivalent to (not X) ~= Y; consider using X == Y, or wrap one of the expressions in parentheses to silence -if not x ~= y then end - --- not X <= Y is equivalent to (not X) <= Y; wrap one of the expressions in parentheses to silence -if not x <= y then end - --- X <= Y == Z is equivalent to (X <= Y) == Z; wrap one of the expressions in parentheses to silence -if x <= y == 0 then end -``` - -As a special exception, this lint pass will not warn for cases like `x == not y` or `not x == not y`, which both looks intentional as it is written and interpreted. - -## Function calls returning singleton types incorrectly widened - -Fix a bug where widening was a little too happy to fire in the case of function calls returning singleton types or union thereof. This was an artifact of the logic that knows not to infer singleton types in cases that makes no sense to. - -```lua -function f(): "abc" | "def" - return if math.random() > 0.5 then "abc" else "def" -end - --- previously reported that 'string' could not be converted into '"abc" | "def"' -local x: "abc" | "def" = f() -``` - -## `string` can be a subtype of a table with a shape similar to `string` - -The function `my_cool_lower` is a function `(t: t1) -> a... where t1 = {+ lower: (t1) -> a... +}`. - -```lua -function my_cool_lower(t) - return t:lower() -end -``` - -Even though `t1` is a table type, we know `string` is a subtype of `t1` because `string` also has `lower` which is a subtype of `t1`'s `lower`, so this call site now type checks. - -```lua -local s: string = my_cool_lower("HI") -``` - -## Other analysis improvements - -* `string.gmatch`/`string.match`/`string.find` may now return more precise type depending on the patterns used -* Fix a bug where type arena ownership invariant could be violated, causing stability issues -* Fix a bug where internal type error could be presented to the user -* Fix a false positive with optionals & nested tables -* Fix a false positive in non-strict mode when using generalized iteration -* Improve autocomplete behavior in certain cases for `:` calls -* Fix minor inconsistencies in synthesized names for types with metatables -* Fix autocomplete not suggesting globals defined after the cursor -* Fix DeprecatedGlobal warning text in cases when the global is deprecated without a suggested alternative -* Fix an off-by-one error in type error text for incorrect use of `string.format` - -## Other runtime improvements - -* Comparisons with constants are now significantly faster when using clang as a compiler (10-50% gains on internal benchmarks) -* When calling non-existent methods on tables or strings, `foo:bar` now produces a more precise error message -* Improve performance for iteration of tables -* Fix a bug with negative zero in vector components when using vectors as table keys -* Compiler can now constant fold builtins under -O2, for example `string.byte("A")` is compiled to a constant -* Compiler can model the cost of builtins for the purpose of inlining/unrolling -* Local reassignment i.e. `local x = y :: T` is free iff neither `x` nor `y` is mutated/captured -* Improve `debug.traceback` performance by 1.15-1.75x depending on the platform -* Fix a corner case with table assignment semantics when key didn't exist in the table and `__newindex` was defined: we now use Lua 5.2 semantics and call `__newindex`, which results in less wasted space, support for NaN keys in `__newindex` path and correct support for frozen tables -* Reduce parser C stack consumption which fixes some stack overflow crashes on deeply nested sources -* Improve performance of `bit32.extract`/`replace` when width is implied (~3% faster chess) -* Improve performance of `bit32.extract` when field/width are constants (~10% faster base64) -* `string.format` now supports a new format specifier, `%*`, that accepts any value type and formats it using `tostring` rules - -## Thanks - -Thanks for all the contributions! - -* [natteko](https://github.com/natteko) -* [JohnnyMorganz](https://github.com/JohnnyMorganz) -* [khvzak](https://github.com/khvzak) -* [Anaminus](https://github.com/Anaminus) -* [memery-rbx](https://github.com/memery-rbx) -* [jaykru](https://github.com/jaykru) -* [Kampfkarren](https://github.com/Kampfkarren) -* [XmiliaH](https://github.com/XmiliaH) -* [Mactavsin](https://github.com/Mactavsin) diff --git a/docs/_posts/2022-10-31-luau-semantic-subtyping.md b/docs/_posts/2022-10-31-luau-semantic-subtyping.md deleted file mode 100644 index 68622a679..000000000 --- a/docs/_posts/2022-10-31-luau-semantic-subtyping.md +++ /dev/null @@ -1,292 +0,0 @@ ---- -layout: single -title: "Semantic Subtyping in Luau" -author: Alan Jeffrey ---- - -Luau is the first programming language to put the power of semantic subtyping in the hands of millions of creators. - -## Minimizing false positives - -One of the issues with type error reporting in tools like the Script Analysis widget in Roblox Studio is *false positives*. These are warnings that are artifacts of the analysis, and don’t correspond to errors which can occur at runtime. For example, the program -```lua - local x = CFrame.new() - local y - if (math.random()) then - y = CFrame.new() - else - y = Vector3.new() - end - local z = x * y -``` -reports a type error which cannot happen at runtime, since `CFrame` supports multiplication by both `Vector3` and `CFrame`. (Its type is `((CFrame, CFrame) -> CFrame) & ((CFrame, Vector3) -> Vector3)`.) - -False positives are especially poor for onboarding new users. If a type-curious creator switches on typechecking and is immediately faced with a wall of spurious red squiggles, there is a strong incentive to immediately switch it off again. - -Inaccuracies in type errors are inevitable, since it is impossible to decide ahead of time whether a runtime error will be triggered. Type system designers have to choose whether to live with false positives or false negatives. In Luau this is determined by the mode: `strict` mode errs on the side of false positives, and `nonstrict` mode errs on the side of false negatives. - -While inaccuracies are inevitable, we try to remove them whenever possible, since they result in spurious errors, and imprecision in type-driven tooling like autocomplete or API documentation. - -## Subtyping as a source of false positives - -One of the sources of false positives in Luau (and many other similar languages like TypeScript or Flow) is *subtyping*. Subtyping is used whenever a variable is initialized or assigned to, and whenever a function is called: the type system checks that the type of the expression is a subtype of the type of the variable. For example, if we add types to the above program -```lua - local x : CFrame = CFrame.new() - local y : Vector3 | CFrame - if (math.random()) then - y = CFrame.new() - else - y = Vector3.new() - end - local z : Vector3 | CFrame = x * y -``` -then the type system checks that the type of `CFrame` multiplication is a subtype of `(CFrame, Vector3 | CFrame) -> (Vector3 | CFrame)`. - -Subtyping is a very useful feature, and it supports rich type constructs like type union (`T | U`) and intersection (`T & U`). For example, `number?` is implemented as a union type `(number | nil)`, inhabited by values that are either numbers or `nil`. - -Unfortunately, the interaction of subtyping with intersection and union types can have odd results. A simple (but rather artificial) case in older Luau was: -```lua - local x : (number?) & (string?) = nil - local y : nil = nil - y = x -- Type '(number?) & (string?)' could not be converted into 'nil' - x = y -``` -This error is caused by a failure of subtyping, the old subtyping algorithm reports that `(number?) & (string?)` is not a subtype of `nil`. This is a false positive, since `number & string` is uninhabited, so the only possible inhabitant of `(number?) & (string?)` is `nil`. - -This is an artificial example, but there are real issues raised by creators caused by the problems, for example . Currently, these issues mostly affect creators making use of sophisticated type system features, but as we make type inference more accurate, union and intersection types will become more common, even in code with no type annotations. - -This class of false positives no longer occurs in Luau, as we have moved from our old approach of *syntactic subtyping* to an alternative called *semantic subtyping*. - -## Syntactic subtyping - -AKA “what we did before.†- -Syntactic subtyping is a syntax-directed recursive algorithm. The interesting cases to deal with intersection and union types are: - -* Reflexivity: `T` is a subtype of `T` -* Intersection L: `(Tâ‚ & … & Tâ±¼)` is a subtype of `U` whenever some of the `Táµ¢` are subtypes of `U` -* Union L: `(Tâ‚ | … | Tâ±¼)` is a subtype of `U` whenever all of the `Táµ¢` are subtypes of `U` -* Intersection R: `T` is a subtype of `(Uâ‚ & … & Uâ±¼)` whenever `T` is a subtype of all of the `Uáµ¢` -* Union R: `T` is a subtype of `(Uâ‚ | … | Uâ±¼)` whenever `T` is a subtype of some of the `Uáµ¢`. - -For example: - -* By Reflexivity: `nil` is a subtype of `nil` -* so by Union R: `nil` is a subtype of `number?` -* and: `nil` is a subtype of `string?` -* so by Intersection R: `nil` is a subtype of `(number?) & (string?)`. - -Yay! Unfortunately, using these rules: - -* `number` isn’t a subtype of `nil` -* so by Union L: `(number?)` isn’t a subtype of `nil` -* and: `string` isn’t a subtype of `nil` -* so by Union L: `(string?)` isn’t a subtype of `nil` -* so by Intersection L: `(number?) & (string?)` isn’t a subtype of `nil`. - -This is typical of syntactic subtyping: when it returns a “yes†result, it is correct, but when it returns a “no†result, it might be wrong. The algorithm is a *conservative approximation*, and since a “no†result can lead to type errors, this is a source of false positives. - -## Semantic subtyping - -AKA “what we do now.†- -Rather than thinking of subtyping as being syntax-directed, we first consider its semantics, and later return to how the semantics is implemented. For this, we adopt semantic subtyping: - - * The semantics of a type is a set of values. - * Intersection types are thought of as intersections of sets. - * Union types are thought of as unions of sets. - * Subtyping is thought of as set inclusion. - -For example: - -| Type | Semantics | -|------|-----------| -| `number` | { 1, 2, 3, … } | -| `string` | { “fooâ€, “barâ€, … } | -| `nil` | { nil } | -| `number?` | { nil, 1, 2, 3, … } | -| `string?` | { nil, “fooâ€, “barâ€, … } | -| `(number?) & (string?)` | { nil, 1, 2, 3, … } ∩ { nil, “fooâ€, “barâ€, … } = { nil } | - - -and since subtypes are interpreted as set inclusions: - -| Subtype | Supertype | Because | -|---------|-----------|---------| -| `nil` | `number?` | { nil } ⊆ { nil, 1, 2, 3, … } | -| `nil` | `string?`| { nil } ⊆ { nil, “fooâ€, “barâ€, … } | -| `nil` | `(number?) & (string?)` | { nil } ⊆ { nil } | -| `(number?) & (string?)` | `nil` | { nil } ⊆ { nil } | - - -So according to semantic subtyping, `(number?) & (string?)` is equivalent to `nil`, but syntactic subtyping only supports one direction. - -This is all fine and good, but if we want to use semantic subtyping in tools, we need an algorithm, and it turns out checking semantic subtyping is non-trivial. - -## Semantic subtyping is hard - -NP-hard to be precise. - -We can reduce graph coloring to semantic subtyping by coding up a graph as a Luau type such that checking subtyping on types has the same result as checking for the impossibility of coloring the graph - -For example, coloring a three-node, two color graph can be done using types: - -```lua -type Red = "red" -type Blue = "blue" -type Color = Red | Blue -type Coloring = (Color) -> (Color) -> (Color) -> boolean -type Uncolorable = (Color) -> (Color) -> (Color) -> false -``` - -Then a graph can be encoded as an overload function type with -subtype `Uncolorable` and supertype `Coloring`, as an overloaded -function which returns `false` when a constraint is violated. Each -overload encodes one constraint. For example a line has constraints -saying that adjacent nodes cannot have the same color: - -```lua -type Line = Coloring - & ((Red) -> (Red) -> (Color) -> false) - & ((Blue) -> (Blue) -> (Color) -> false) - & ((Color) -> (Red) -> (Red) -> false) - & ((Color) -> (Blue) -> (Blue) -> false) -``` - -A triangle is similar, but the end points also cannot have the same color: - -```lua -type Triangle = Line - & ((Red) -> (Color) -> (Red) -> false) - & ((Blue) -> (Color) -> (Blue) -> false) -``` - -Now, `Triangle` is a subtype of `Uncolorable`, but `Line` is not, since the line can be 2-colored. -This can be generalized to any finite graph with any finite number of colors, and so subtype checking is NP-hard. - -We deal with this in two ways: - -* we cache types to reduce memory footprint, and -* give up with a “Code Too Complex†error if the cache of types gets too large. - -Hopefully this doesn’t come up in practice much. There is good evidence that issues like this don’t arise in practice from experience with type systems like that of Standard ML, which is [EXPTIME-complete](https://dl.acm.org/doi/abs/10.1145/96709.96748), but in practice you have to go out of your way to code up Turing Machine tapes as types. - -## Type normalization - -The algorithm used to decide semantic subtyping is *type normalization*. -Rather than being directed by syntax, we first rewrite types to be normalized, then check subtyping on normalized types. - -A normalized type is a union of: - -* a normalized nil type (either `never` or `nil`) -* a normalized number type (either `never` or `number`) -* a normalized boolean type (either `never` or `true` or `false` or `boolean`) -* a normalized function type (either `never` or an intersection of function types) -etc - -Once types are normalized, it is straightforward to check semantic subtyping. - -Every type can be normalized (sigh, with some technical restrictions around generic type packs). The important steps are: - -* removing intersections of mismatched primitives, e.g. `number & bool` is replaced by `never`, and -* removing unions of functions, e.g. `((number?) -> number) | ((string?) -> string)` is replaced by `(nil) -> (number | string)`. - -For example, normalizing `(number?) & (string?)` removes `number & string`, so all that is left is `nil`. - -Our first attempt at implementing type normalization applied it liberally, but this resulted in dreadful performance (complex code went from typechecking in less than a minute to running overnight). The reason for this is annoyingly simple: there is an optimization in Luau’s subtyping algorithm to handle reflexivity (`T` is a subtype of `T`) that performs a cheap pointer equality check. Type normalization can convert pointer-identical types into semantically-equivalent (but not pointer-identical) types, which significantly degrades performance. - -Because of these performance issues, we still use syntactic subtyping as our first check for subtyping, and only perform type normalization if the syntactic algorithm fails. This is sound, because syntactic subtyping is a conservative approximation to semantic subtyping. - -## Pragmatic semantic subtyping - -Off-the-shelf semantic subtyping is slightly different from what is implemented in Luau, because it requires models to be *set-theoretic*, which requires that inhabitants of function types “act like functions.†There are two reasons why we drop this requirement. - -**Firstly**, we normalize function types to an intersection of functions, for example a horrible mess of unions and intersections of functions: -``` -((number?) -> number?) | (((number) -> number) & ((string?) -> string?)) -``` -normalizes to an overloaded function: -``` -((number) -> number?) & ((nil) -> (number | string)?) -``` -Set-theoretic semantic subtyping does not support this normalization, and instead normalizes functions to *disjunctive normal form* (unions of intersections of functions). We do not do this for ergonomic reasons: overloaded functions are idiomatic in Luau, but DNF is not, and we do not want to present users with such non-idiomatic types. - -Our normalization relies on rewriting away unions of function types: -``` -((A) -> B) | ((C) -> D) → (A & C) -> (B | D) -``` -This normalization is sound in our model, but not in set-theoretic models. - -**Secondly**, in Luau, the type of a function application `f(x)` is `B` if `f` has type `(A) -> B` and `x` has type `A`. Unexpectedly, this is not always true in set-theoretic models, due to uninhabited types. In set-theoretic models, if `x` has type `never` then `f(x)` has type `never`. We do not want to burden users with the idea that function application has a special corner case, especially since that corner case can only arise in dead code. - -In set-theoretic models, `(never) -> A` is a subtype of `(never) -> B`, no matter what `A` and `B` are. This is not true in Luau. - -For these two reasons (which are largely about ergonomics rather than anything technical) we drop the set-theoretic requirement, and use *pragmatic* semantic subtyping. - -## Negation types - -The other difference between Luau’s type system and off-the-shelf semantic subtyping is that Luau does not support all negated types. - -The common case for wanting negated types is in typechecking conditionals: -```lua --- initially x has type T -if (type(x) == "string") then - -- in this branch x has type T & string -else - -- in this branch x has type T & ~string -end -``` -This uses a negated type `~string` inhabited by values that are not strings. - -In Luau, we only allow this kind of typing refinement on *test types* like `string`, `function`, `Part` and so on, and *not* on structural types like `(A) -> B`, which avoids the common case of general negated types. - -## Prototyping and verification - -During the design of Luau’s semantic subtyping algorithm, there were changes made (for example initially we thought we were going to be able to use set-theoretic subtyping). During this time of rapid change, it was important to be able to iterate quickly, so we initially implemented a [prototype](https://github.com/luau-lang/agda-typeck) rather than jumping straight to a production implementation. - -Validating the prototype was important, since subtyping algorithms can have unexpected corner cases. For this reason, we adopted Agda as the prototyping language. As well as supporting unit testing, Agda supports mechanized verification, so we are confident in the design. - -The prototype does not implement all of Luau, just the functional subset, but this was enough to discover subtle feature interactions that would probably have surfaced as difficult-to-fix bugs in production. - -Prototyping is not perfect, for example the main issues that we hit in production were about performance and the C++ standard library, which are never going to be caught by a prototype. But the production implementation was otherwise fairly straightforward (or at least as straightforward as a 3kLOC change can be). - -## Next steps - -Semantic subtyping has removed one source of false positives, but we still have others to track down: - -* overloaded function applications and operators, -* property access on expressions of complex type, -* read-only properties of tables, -* variables that change type over time (aka typestates), -* … - -The quest to remove spurious red squiggles continues! - -## Acknowledgments - -Thanks to Giuseppe Castagna and Ben Greenman for helpful comments on drafts of this post. - -## Further reading - -If you want to find out more about Luau and semantic subtyping, you might want to check out… - -* Luau. -* Lily Brown, Andy Friesen and Alan Jeffrey, *Goals of the Luau Type System*, Human Aspects of Types and Reasoning Assistants (HATRA), 2021. -* Luau Typechecker Prototype. -* Agda. -* Andrew M. Kent. *Down and Dirty with Semantic Set-theoretic Types*, 2021. -* Giuseppe Castagna, *Covariance and Contravariance*, Logical Methods in Computer Science 16(1), 2022. -* Giuseppe Castagna and Alain Frisch, *A gentle introduction to semantic subtyping*, Proc. Principles and practice of declarative programming (PPDP), pp 198–208, 2005. -* Giuseppe Castagna, Mickaël Laurent, Kim Nguyá»…n, Matthew Lutze, *On Type-Cases, Union Elimination, and Occurrence Typing*, Principles of Programming Languages (POPL), 2022. -* Giuseppe Castagna, *Programming with union, intersection, and negation types*, 2022. -* Sam Tobin-Hochstadt and Matthias Felleisen, *Logical types for untyped languages*. International Conference on Functional Programming (ICFP), 2010. -* José Valim, *My Future with Elixir: set-theoretic types*, 2022. - -Some other languages which support semantic subtyping… - -* â„‚Duce -* Ballerina -* Elixir -* eqWAlizer - -And if you want to see the production code, it's in the C++ definitions of [tryUnifyNormalizedTypes](https://github.com/Roblox/luau/blob/d6aa35583e4be14304d2a17c7d11c8819756beb6/Analysis/src/Unifier.cpp#L868) and [NormalizedType](https://github.com/Roblox/luau/blob/d6aa35583e4be14304d2a17c7d11c8819756beb6/Analysis/include/Luau/Normalize.h#L134) in the [open source Luau repo](https://github.com/Roblox/luau). diff --git a/docs/_posts/2022-11-01-luau-recap-september-october-2022.md b/docs/_posts/2022-11-01-luau-recap-september-october-2022.md deleted file mode 100644 index 7d99babb8..000000000 --- a/docs/_posts/2022-11-01-luau-recap-september-october-2022.md +++ /dev/null @@ -1,82 +0,0 @@ ---- -layout: single -title: "Luau Recap: September & October 2022" ---- - -Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-september-october-2022/).] - -## Semantic subtyping - -One of the most important goals for Luau is to avoid *false -positives*, that is cases where Script Analysis reports a type error, -but in fact the code is correct. This is very frustrating, especially -for beginners. Spending time chasing down a gnarly type error only to -discover that it was the type system that's wrong is nobody's idea of fun! - -We are pleased to announce that a major component of minimizing false -positives has landed, *semantic subtyping*, which removes a class of false positives caused -by failures of subtyping. For example, in the program - -```lua - local x : CFrame = CFrame.new() - local y : Vector3 | CFrame - if (math.random()) then - y = CFrame.new() - else - y = Vector3.new() - end - local z : Vector3 | CFrame = x * y -- Type Error! -``` - -an error is reported, even though there is no problem at runtime. This -is because `CFrame`'s multiplication has two overloads: - -```lua - ((CFrame, CFrame) -> CFrame) - & ((CFrame, Vector3) -> Vector3) -``` - -The current syntax-driven algorithm for subtyping is not sophisticated -enough to realize that this is a subtype of the desired type: - -```lua - (CFrame, Vector3 | CFrame) -> (Vector3 | CFrame) -``` - -Our new algorithm is driven by the semantics of subtyping, not the syntax of types, -and eliminates this class of false positives. - -If you want to know more about semantic subtyping in Luau, check out our -[technical blog post](https://luau-lang.org/2022/10/31/luau-semantic-subtyping.html) -on the subject. - -## Other analysis improvements - -* Improve stringification of function types. -* Improve parse error warnings in the case of missing tokens after a comma. -* Improve typechecking of expressions involving variadics such as `{ ... }`. -* Make sure modules don't return unbound generic types. -* Improve cycle detection in stringifying types. -* Improve type inference of combinations of intersections and generic functions. -* Improve typechecking when calling a function which returns a variadic e.g. `() -> (number...)`. -* Improve typechecking when passing a function expression as a parameter to a function. -* Improve error reporting locations. -* Remove some sources of memory corruption and crashes. - -## Other runtime and debugger improvements - -* Improve performance of accessing debug info. -* Improve performance of `getmetatable` and `setmetatable`. -* Remove a source of freezes in the debugger. -* Improve GC accuracy and performance. - -## Thanks - -Thanks for all the contributions! - -* [AllanJeremy](https://github.com/AllanJeremy) -* [JohnnyMorganz](https://github.com/JohnnyMorganz) -* [jujhar16](https://github.com/jujhar16) -* [petrihakkinen](https://github.com/petrihakkinen) diff --git a/docs/_posts/2022-11-04-luau-origins-and-evolution.md b/docs/_posts/2022-11-04-luau-origins-and-evolution.md deleted file mode 100644 index 707c5d41a..000000000 --- a/docs/_posts/2022-11-04-luau-origins-and-evolution.md +++ /dev/null @@ -1,104 +0,0 @@ ---- -layout: single -title: "Luau origins and evolution" -author: Arseny Kapoulkine ---- - -At the heart of Roblox technology lies [Luau](https://luau-lang.org), a scripting language derived from Lua 5.1 that is being [developed](https://github.com/Roblox/luau) by an internal team of programming language experts with the help of open source contributors. - -It powers all user-generated content on Roblox, providing access to a very rich set of APIs that allows manipulation of objects in the 3D world, backend API access, UI interaction and more. Hundreds of thousands of developers write code in Luau every month, with top experiences using hundreds of thousands of lines of code, adding up to hundreds of millions of lines of code across the platform. For many of them, it is the first programming language they learn, and one they spend the majority of their time programming in. Using a set of extended APIs developers also customize their workflows by writing plugins to Roblox Studio, where they work on their experiences, using an extended API surface to interact with all aspects of the editor. - -It also powers a lot of application code that Roblox engineers are writing: Universal App, the gateway to the worlds of Roblox that is used by tens of millions of people every day, has 95% of its functionality implemented in Luau, and Roblox Studio has a lot of builtin critical functionality such as part and terrain editors, marketplace browser, avatar and animation editors, material manager and more, implemented in Luau as a plugin, mostly using the same APIs that developers have access to. Every week, updates to this internal codebase that is now over 2 million lines large, are shipped to all Roblox users. - -In addition to Roblox use cases, Luau is also open-source and is seeing an increased adoption in other projects and applications. - -But why did we use Lua in the first place, and why did we decide to pursue building a new language on top of it? - -# Early beginnings - -Around 2006, when a very early version of the Roblox platform was developed, the question of user generated behaviors emerged. Before that, users were able to build non-interactive content on Roblox, and the only form of interaction was physics simulation. While this provided rich emergent behavior, it was hard to build gameplay on top of this: for example, to build a Capture The Flag game, you need to handle collision between players and flags spread throughout the map with a bit of logic that dictates how to adjust team points and when to remove or recreate the objects. - -After an early and brief misstep when we decided to add a few gameplay objects to the core definition of Roblox worlds (some developers may recognize FlagStand as a class name...), the Roblox co-founder Erik Cassel realized that an approach like this is fundamentally limiting the power of user generated content. It’s not enough to give creators the basic blocks on top of which to build their creations, it’s critical to expose the power of a full Turing-complete programming language. Without this, the expressive capability and the reach of the platform would have been restricted far too much. - -But which programming language to choose? This is where [Lua](https://lua.org/), which was, and still is, one of the dominant programming languages used in video games, comes in. - -In addition to its simplicity, which made the language easy to learn and get productive in, Lua was the fastest scripting language compared to popular alternatives like Python or JavaScript at the time[^1], designed to be embedded which meant an easy ability to expose APIs from the host application to the scripts as well as high degree of execution control from the host, and implemented coroutines, a very powerful concurrency primitive that allowed to easily and intuitively script behaviors for independent actors in game using linear control flow. - -Instead of having a large standard library, the expectation was that the embedding application would define a set of APIs that that application needed, as well as establish policies of running the code - which gave us a lot of freedom in how to structure the APIs and when the scripts would get triggered during the simulation of a single frame. - -# Power of simplicity - -Lua is a simple language. What does simplicity mean for us? - -Being a simple language means having a small set of features. Lua has all the fundamental features but doesn’t have a lot of syntax sugar - this means the language is easier to teach and learn, and you rarely run into code that’s difficult to understand syntactically because it uses an unfamiliar construct. Of course, this also means that some programs in Lua are longer than equivalent programs in languages that have more dedicated constructs to solve specific problems, such as list comprehensions in Python. - -Being a simple language means having a minimal set of rules for every feature. Lua does deviate from this in certain respects (which is to say, the language could have been even simpler!), but notably for a dynamic language the behavior of fundamental operators is generally easy to explain and unsurprising - for example, two values in Lua are equal iff they have the same type and the same value, as such `0 == “0â€` is `false`; as another example, `for` loops introduce unique variable bindings on every iteration, as such capturing the iteration variable in a closure produces unique values. These decisions lead to more concise and efficient implementation and eliminate a class of bugs in programs. - -Being a simple language means having a small implementation. This may be immaterial to people writing code in the language, but it leads to an implementation that can be of higher quality; simpler implementations can also be easier to optimize for memory or performance, and are easier to build upon. - -Developers on the Roblox platform have very diverse programming backgrounds. Some are writing their first line of code in Roblox Studio, while others have computer science degrees and experience working in multiple different programming languages. While it’s always possible to support two different programming languages that target different segments of the audience, that fragments the ecosystem and makes the programming story less consistent (impacting documentation, tutorials, code reuse, ability for community members to help each other write code, presents challenges with interaction between different languages in the same experience and more). A better outcome is one where a single language can serve both audiences - this requires a language that strikes a balance between simplicity and generality, and while Lua isn’t perfect here, it’s great as a foundation for a language like this[^2]. - -In many ways, Lua is simultaneously simple and pragmatic: many parts of the language are difficult to make much better without a lot of added complexity, but at the same time it requires little in the way of extra functionality to be able to solve problems efficiently. That said, no language is perfect, and within several areas of Lua we felt that the tradeoffs weren’t quite right for our use case. - -# Respectful evolution - -In 2019, we decided to build [Luau](https://luau-lang.org) - a language derived from Lua and compatible with Lua 5.1, which is the version we’ve been using all these years. At the time we evaluated other routes, but ultimately settled on this as the most optimal long-term. - -On one hand, we loved a lot of things about Lua - both design wise and implementation wise, while there were some decisions we felt were suboptimal, by and large it was an almost perfect foundation for what we’ve set out to achieve. - -On the other hand, we’ve been running into the limitations of Lua on large code bases in absence of type checking, performance was good but not great, and some missing features would have been great to have. - -Some of the things we’ve been missing have been added in later versions of Lua, yet we were still using Lua 5.1. While we would have loved to use a later version of the language standard, Lua 5.x releases are not backwards compatible, and some releases remove support for features that are in wide use at Roblox. For Roblox, backwards compatibility is an essential feature of the platform - while we don’t have a guarantee that content created 10 years ago still works, to the extent that we can achieve that without restricting the platform evolution too much, we try. - -What we’ve realized is that Lua is a great foundation for a perfect language that we can build for Roblox. - -We would maintain backwards compatibility with Lua 5.1 but evolve the language from there; sometimes this means taking later features from Lua that don’t conflict with the existing language or our design values, sometimes this means innovating beyond what Lua has done. Crucially, we must maintain the balance between simplicity and power - we still value simplicity, we still need to avoid a feature explosion to ensure that the features compose and are of high quality, and we still need the language to be a good fit for beginners. - -One of the largest limitations that we’ve seen is the lack of type checking making it easy to make mistakes in large code bases, as such [support for type checking](https://luau-lang.org/typecheck) was a requirement for Luau. However, it’s important that the type checker is mostly transparent to the developers who don’t want to invest the time to learn it - anything else would change the learning curve too much for the language to be suitable for beginners. As such, we’ve investing in gradual typing, and our type checker is learning to strike a balance between inferring useful types for completely untyped programs (which, among other things, greatly enhances editing experience through type-aware autocomplete), and the lack of false positive diagnostics that can be confusing and distracting. - -While we did need to introduce [extra syntax](https://luau-lang.org/syntax) to the language - most notably, to support optional type annotations - it was important for us to maintain the cohesion of the overall syntax. We aren’t seeking to make a new language with a syntax alien to Lua programmers - Luau programs are still recognizably Lua, and to the extent possible we try to avoid new syntactic features. In a sense, we still want the syntax, semantics, and the runtime to be simple and minimal - but at the same time we have important problems to solve with respect to ergonomics, robustness and performance of the language, and solving some of them requires having slightly more complex syntax, semantics, or implementation. - -So in finding ways to evolve Luau, we strive to design features that feel like they would be at home in Lua. At the same time, we’ve adopted a more open evolution process - the language development is driven [through RFCs](https://github.com/Roblox/luau/blob/master/rfcs/README.md) that are designs open to the public that anyone can contribute to - this is in contrast with Lua, which has a very closed development process, and is one of the reasons why it would have been difficult for us to keep using Lua as we wouldn’t get a say in its development. At the same time, to ensure the design criterias are met, it’s important that the Luau development team at Roblox maintains a final say over design and implementation of the language[^3], while taking the community's proposals and input into consideration. - -# Importance of co-design - -Luau language is developed in concert with the language compiler, runtime, type checker and other analysis tools, autocomplete engine and other tooling, and that development is guided by the vast volume of existing Luau code, both internal and external. - -This is one of the key principles behind our evolution philosophy - neither layer is developed in isolation, and instead concerns at every level inform all other aspects of the language design and implementation. - -This means that when designing language features, we make sure that they can be implemented efficiently, type checked properly, can be supported well in editing and analysis tools and have a positive impact on the code internal and external engineers write. When we find issues in any component, we can always ask, what changes to other components or even language design would make for a better overall solution. - -This avoids some classes of design problems, for example we won’t specify a language feature that has a prohibitively high implementation cost, as it violates our simplicity criteria, or that is impractical to implement efficiently, as that would create a performance hazard. This also means that when implementing various components of the language we cross-check the concerns and applicability of these across the entire stack - for example, we’ve reworked our auto-complete system to use the same type inference engine that the type checking / analysis tools use, which had immense benefits for the experience of editing code, but also applied significant back pressure on the type inference itself, forcing us to improve it substantially and fix a lot of corner cases that would otherwise have lingered unnoticed. - -Whenever we develop features, optimizations, improve our analysis engine or enhance the standard libraries, we also heavily rely on code written in Luau to validate our hypotheses. When working on new features we find motivation in the real problems that we see our developers face. For example, we implemented the [new ternary operator](https://luau-lang.org/syntax#if-then-else-expressions) after seeing a large set of cases where existing Lua’s `a and b or c` pattern was error-prone for boolean values, which made it easy to accidentally introduce a mistake that was hard to identify automatically. All optimizations and new analysis features are validated on our internal 2M LOC codebase before being added to Luau, which allows us to quickly get initial validation of ideas, or invalidate some approaches as infeasible / unhelpful. - -In addition to that, while we don’t have direct access to community-developed source code for privacy reasons, we can run experiments and collect telemetry[^4], which also helps us make decisions regarding backwards compatibility. Due to [Hyrum’s law](https://www.hyrumslaw.com/), technically any change in the language or libraries, no matter how small, would be backwards incompatible - instead we adopt the notion of pragmatic balance between strict backwards compatibility[^5] and pragmatic compatibility concerns. For example, later versions of Lua make some library functions like `table.insert`/`table.remove` more strict with how they handle out of range indices. We have evaluated this change for compatibility by collecting telemetry on the use of out of range indices in these functions on the Roblox platform and concluded that applying the stricter checking would break existing programs, and instead had to slightly adjust the rules for out of range behavior in ways that was benign for existing code but prevented catastrophic performance degradation for large out of range indices. Because we couldn’t afford to introduce new runtime errors in this case, we also added a set of linting rules to our analysis engine to flag potential misuse of `table.insert`/`table.remove` before the code ever gets to run - this diagnostics is informational and as such doesn’t affect backwards compatibility, but does help prevent mistakes. - -There are also cases where this co-design approach prevents introduction of features that can lead to easy misuse, which can be difficult to see in the design of the feature itself, but becomes more apparent when you consider features in context of the entire ecosystem. This is a good thing - it means co-design acts as a forcing function on the language simplicity and makes it easier to flag potential bad interactions between different language features, or language features and tooling, or language features and existing programming patterns that are in widespread use in real-world code. By making sure that all features are validated for their impact across the stack and on code written in Luau, we ultimately get a better, simpler and more cohesive language. - -# Efficient execution - -One of the critical goals in front of Luau is efficiency, both from the performance and memory perspective. There’s only so many milliseconds in a frame, and we simultaneously see the need to increase the scale and complexity of simulated experiences, which requires more memory and computation, as well as the need to fit more comfortably into smaller budgets of performance memory for better experience on smaller devices. In fact, one of the motivations for Luau in 2019 has been improved performance, as we saw many opportunities to go beyond Lua with a redesigned implementation. - -Crucially, our performance needs are somewhat unique and require somewhat unique solutions. - -We need Luau to run on many platforms where native code generation is either prohibited by the platform vendor or impractical due to tight memory constraints. As such, in terms of execution performance it’s critical that we have a very fast interpreter[^6]. However, we have freedom in terms of high level design of the entire stack - for example, clients never see the source code of the scripts as all compilation to bytecode happens on the server; this gives us an opportunity to perform more involved and expensive optimizations during that process as well as have the smallest possible startup time on the client without complex pre-parse steps. Notably, our bytecode compiler performs a series of high level optimizations including function inlining and loop unrolling that in other dynamic languages is often left to the just-in-time compiler. - -Another area where performance is critical is garbage collection. Garbage collection is crucial for the language’s simplicity as it makes memory management easier to reason about, but it does require a substantial amount of implementation effort to keep it efficient. For Roblox and for any other game engine or interactive simulation, latency is critical and so our collector is heavily optimized for that - to the extent possible collection is incremental and stop-the-world pauses are very brief. Another part of the performance story here however is the language and data structure design - by making sure that core data types are efficient in how they are laid out in memory we reduce the amount of work garbage collector takes to trace the heap, and, as another example of co-design, we try to make sure that language features are conscious of the impact they have on memory and garbage collection efficiency. - -However, from a whole-platform standpoint there’s a lot of performance aspects that go beyond single-threaded execution. This is an active area of research and development for the team, as to really leverage the hardware the code is running on we need to think about SIMD, hardware thread utilization as well as running code in a cluster of nodes. These considerations inform current and future development of the runtime and the language (for example, our runtime now supports efficient operations on short SIMD vectors even in interpreted mode, and the VM is fairly lightweight to instantiate which makes running many VMs per core practical, with message passing or access to shared Roblox data model used to make gameplay features feasible to implement), but we’re definitely in the early days here - our first implementation of parallel script execution in Roblox just [shipped earlier this year](https://devforum.roblox.com/t/full-release-of-parallel-luau-v1/1836187). This is likely the area where a lot of future innovations will happen as well. - -# Future - -We’re very happy with the success of Luau - in several years we’ve established consistent processes for evolving the language and so far we found a good balance between simplicity, ease of use, performance and robustness of the language, its implementation and the tooling surrounding it. The language keeps continuously evolving but at a pace that is easy to stay on top of - in 2022 we shipped a few syntactic extensions for type annotations but no changes to the syntax of the language outside of types, and only one major [semantic change to the for loop iteration](https://luau-lang.org/syntax#generalized-iteration) that actually made the language easier to use by avoiding the need to specify the table traversal style via `pairs`/`ipairs`. We try to make sure that the features are general and provide enough extensibility so that libraries can be built on top of the language to make it easier to write code, while also making it practical to use the language without complex supporting frameworks. - -There’s still a lot of ground to cover, and we’ll be working on Luau for years to come. We’re in the process of building the next version of our type inference / checking engine to make sure that all users of the language regardless of their expertise benefit from it, we’ve started investing in native code generation as we’re reaching the limits of interpreted performance (although some exciting opportunities for compiler optimization are still on the horizon), and there’s still a lot of hard design and implementation work ahead of us for some important language features and standard libraries. And as mentioned, our execution model will likely see a lot of innovation as we push the boundaries of hardware utilization across cores and nodes. - -Overall, Luau is like an iceberg - the surface is simple to learn and use, but it hides the tremendous amount of careful design, engineering and attention to detail, and we plan to continue to invest in it while trying to keep the outer surface comparatively small. We're excited to see how far we can take it! - -[^1]: High-performance JavaScript engines didn’t exist at the time! LuaJIT was around the corner and redefined the performance expectations of dynamic languages. -[^2]: In fact, scaling to large teams of expert programmers is one of the core motivations behind our creating Luau, while a requirement to still be suitable for beginner programmers guides our evolution direction. -[^3]: This would have been difficult to drive in any existing large established language like JavaScript or Python. -[^4]: This is limited to Roblox platform and doesn't exist in open-source releases. -[^5]: Which we do follow in some areas, such as syntactic compatibility - all existing programs that parse must continue to parse the same way as the language evolves. -[^6]: Some design decisions and implementation techniques are documented on our [performance page](https://luau-lang.org/performance). diff --git a/docs/_posts/2022-11-30-luau-recap-november-2022.md b/docs/_posts/2022-11-30-luau-recap-november-2022.md deleted file mode 100644 index 29b4c6dca..000000000 --- a/docs/_posts/2022-11-30-luau-recap-november-2022.md +++ /dev/null @@ -1,96 +0,0 @@ ---- -layout: single -title: "Luau Recap: November 2022" ---- - -While the team is busy to bring some bigger things in the future, we have made some small improvements this month. - -[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-november-2022/).] - -## Analysis improvements - -We have improved tagged union type refinements to only include unhandled type cases in the `else` branch of the `if` statement: - -```lua -type Ok = { tag: "ok", value: T } -type Err = { tag: "error", msg: string } -type Result = Ok | Err - -function unwrap(r: Result): T? - if r.tag == "ok" then - return r.value - else - -- Luau now understands that 'r' here can only be the 'Err' part - print(r.msg) - return nil - end -end -``` - -For better inference, we updated the definition of `Enum.SomeType:GetEnumItems()` to return `{Enum.SomeType}` instead of common `{EnumItem}` and the return type of `next` function now includes the possibility of key being `nil`. - -Finally, if you use `and` operator on non-boolean values, `boolean` type will no longer be added by the type inference: - -```lua -local function f1(a: number?) - -- 'x' is still a 'number?' and doesn't become 'boolean | number' - local x = a and 5 -end -``` - -## Error message improvements - -We now give an error when built-in types are being redefined: - -```lua -type string = number -- Now an error: Redefinition of type 'string' -``` - -We also had a parse error missing in case you forgot your default type pack parameter value. We accepted the following code silently without raising an issue: - -```lua -type Foo = nil -- Now an error: Expected type, got '>' -``` - -Error about function argument count mismatch no longer points at the last argument, but instead at the function in question. -So, instead of: - -```lua -function myfunction(a: number, b:number) end -myfunction(123) - ~~~ -``` - -We now highlight this: - -```lua -function myfunction(a: number, b:number) end -myfunction(123) -~~~~~~~~~~ -``` - -If you iterate over a table value that could also be `nil`, you get a better explanation in the error message: - -```lua -local function f(t: {number}?) - for i,v in t do -- Value of type {number}? could be nil - --... - end -end -``` -Previously it was `Cannot call non-function {number}?` which was confusing. - -And speaking of confusing, some of you might have seen an error like `Type 'string' could not be converted into 'string'`. - -This was caused by Luau having both a primitive type `string` and a table type coming from `string` library. Since the way you can get the type of the `string` library table is by using `typeof(string)`, the updated error message will mirror that and report `Type 'string' could not be converted into 'typeof(string)'`. - - -Parsing now recovers with a more precise error message if you forget a comma in table constructor spanning multiple lines: - -```lua -local t = { - a = 1 - b = 2 -- Expected ',' after table constructor element - c = 3 -- Expected ',' after table constructor element -} -``` diff --git a/docs/_posts/2023-02-02-luau-string-interpolation.md b/docs/_posts/2023-02-02-luau-string-interpolation.md deleted file mode 100644 index 0bfa33a23..000000000 --- a/docs/_posts/2023-02-02-luau-string-interpolation.md +++ /dev/null @@ -1,33 +0,0 @@ ---- -layout: single -title: "String Interpolation" ---- - -String interpolation is the new syntax introduced to Luau that allows you to create a string literal with expressions inside of that string literal. - -In short, it's a safer and more ergonomic alternative over `string.format`. - -Here's a quick example of a string interpolation: - -```lua -local combos = {2, 7, 1, 8, 5} -print(`The lock combination is {table.concat(combos)}. Again, {table.concat(combos, ", ")}.`) ---> The lock combination is 27185. Again, 2, 7, 1, 8, 5. -``` - -String interpolation also composes well with the `__tostring` metamethod. - -```lua -local balance = setmetatable({ value = 500 }, { - __tostring = function(self) - return "$" .. tostring(self.value) - end -}) - -print(`You have {balance}!`) ---> You have $500! -``` - -To find out more details about this feature, check out [Luau Syntax page](/syntax#string-interpolation). - -This is also the first major language feature implemented in a [contribution](https://github.com/Roblox/luau/pull/614) from the open-source community. Thanks [Kampfkarren](https://github.com/Kampfkarren)! diff --git a/docs/assets/images/chess-profile.svg b/docs/assets/images/chess-profile.svg deleted file mode 100644 index 742dc6f2d..000000000 --- a/docs/assets/images/chess-profile.svg +++ /dev/null @@ -1,1356 +0,0 @@ - - - - - - - - - - - - - -Flame Graph -Reset Zoom -Search -ic - - - - - - -
Function: [:0] (31,070 usec, 100.0%); self: 0 usec
- - - -
- - -chess.lua:3 -
Function: [chess.lua:3] (31,070 usec, 100.0%); self: 770 usec
- - - -
- -test -chess.lua:510 -
Function: test [chess.lua:510] (30,300 usec, 97.5%); self: 0 usec
- -test -test -
- -moveList -chess.lua:453 -
Function: moveList [chess.lua:453] (30,300 usec, 97.5%); self: 0 usec
- -moveList -moveList -
- -pmoves -chess.lua:310 -
Function: pmoves [chess.lua:310] (500 usec, 1.6%); self: 0 usec
- - -pmoves -
- -illegalyChecked -chess.lua:476 -
Function: illegalyChecked [chess.lua:476] (28,700 usec, 92.4%); self: 300 usec
- -illegalyChecked -illegalyChecked -
- -applyMove -chess.lua:490 -
Function: applyMove [chess.lua:490] (1,100 usec, 3.5%); self: 200 usec
- -app.. -applyMove -
- -generate -chess.lua:319 -
Function: generate [chess.lua:319] (500 usec, 1.6%); self: 0 usec
- - -generate -
- -pmoves -chess.lua:310 -
Function: pmoves [chess.lua:310] (27,400 usec, 88.2%); self: 100 usec
- -pmoves -pmoves -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (200 usec, 0.6%); self: 0 usec
- - -band -
- -set -chess.lua:195 -
Function: set [chess.lua:195] (300 usec, 1.0%); self: 0 usec
- - -set -
- -ctz -chess.lua:141 -
Function: ctz [chess.lua:141] (400 usec, 1.3%); self: 400 usec
- - -ctz -
- -empty -chess.lua:137 -
Function: empty [chess.lua:137] (100 usec, 0.3%); self: 100 usec
- - -empty -
- -updateCache -chess.lua:283 -
Function: updateCache [chess.lua:283] (600 usec, 1.9%); self: 0 usec
- -u.. -updateCache -
- -set -chess.lua:195 -
Function: set [chess.lua:195] (100 usec, 0.3%); self: 0 usec
- - -set -
- -index -chess.lua:274 -
Function: index [chess.lua:274] (100 usec, 0.3%); self: 100 usec
- - -index -
- -new -chess.lua:228 -
Function: new [chess.lua:228] (100 usec, 0.3%); self: 0 usec
- - -new -
- -move -chess.lua:109 -
Function: move [chess.lua:109] (100 usec, 0.3%); self: 0 usec
- - -move -
- -isolate -chess.lua:304 -
Function: isolate [chess.lua:304] (400 usec, 1.3%); self: 0 usec
- - -isolate -
- -generate -chess.lua:319 -
Function: generate [chess.lua:319] (27,300 usec, 87.9%); self: 2,700 usec
- -generate -generate -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (200 usec, 0.6%); self: 200 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (300 usec, 1.0%); self: 200 usec
- - -from -
- -bor -chess.lua:129 -
Function: bor [chess.lua:129] (600 usec, 1.9%); self: 200 usec
- -bor -bor -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 100 usec
- - -from -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -right -chess.lua:101 -
Function: right [chess.lua:101] (100 usec, 0.3%); self: 0 usec
- - -right -
- -map -chess.lua:295 -
Function: map [chess.lua:295] (400 usec, 1.3%); self: 100 usec
- - -map -
- -move -chess.lua:109 -
Function: move [chess.lua:109] (4,400 usec, 14.2%); self: 800 usec
- -move -move -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (1,900 usec, 6.1%); self: 400 usec
- -band -band -
- -isolate -chess.lua:304 -
Function: isolate [chess.lua:304] (13,000 usec, 41.8%); self: 0 usec
- -isolate -isolate -
- -index -chess.lua:274 -
Function: index [chess.lua:274] (700 usec, 2.3%); self: 200 usec
- -i.. -index -
- -down -chess.lua:97 -
Function: down [chess.lua:97] (100 usec, 0.3%); self: 0 usec
- - -down -
- -left -chess.lua:105 -
Function: left [chess.lua:105] (1,000 usec, 3.2%); self: 100 usec
- -left -left -
- -up -chess.lua:93 -
Function: up [chess.lua:93] (600 usec, 1.9%); self: 100 usec
- -up -up -
- -right -chess.lua:101 -
Function: right [chess.lua:101] (800 usec, 2.6%); self: 100 usec
- -ri.. -right -
- -GC -
Function: GC [GC:0] (500 usec, 1.6%); self: 500 usec
- - -GC -
- -bor -chess.lua:129 -
Function: bor [chess.lua:129] (900 usec, 2.9%); self: 400 usec
- -bor -bor -
- -empty -chess.lua:137 -
Function: empty [chess.lua:137] (400 usec, 1.3%); self: 400 usec
- - -empty -
- -some -chess.lua:207 -
Function: some [chess.lua:207] (300 usec, 1.0%); self: 200 usec
- - -some -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (400 usec, 1.3%); self: 100 usec
- - -from -
- -lshift -chess.lua:164 -
Function: lshift [chess.lua:164] (100 usec, 0.3%); self: 0 usec
- - -lshift -
- -updateCache -chess.lua:283 -
Function: updateCache [chess.lua:283] (100 usec, 0.3%); self: 100 usec
- - -updateCache -
- - -chess.lua:305 -
Function: [chess.lua:305] (200 usec, 0.6%); self: 0 usec
- - - -
- -right -chess.lua:101 -
Function: right [chess.lua:101] (1,800 usec, 5.8%); self: 300 usec
- -right -right -
- -up -chess.lua:93 -
Function: up [chess.lua:93] (400 usec, 1.3%); self: 200 usec
- - -up -
- -left -chess.lua:105 -
Function: left [chess.lua:105] (1,100 usec, 3.5%); self: 100 usec
- -left -left -
- -down -chess.lua:97 -
Function: down [chess.lua:97] (300 usec, 1.0%); self: 200 usec
- - -down -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (1,500 usec, 4.8%); self: 700 usec
- -from -from -
- -map -chess.lua:295 -
Function: map [chess.lua:295] (13,000 usec, 41.8%); self: 300 usec
- -map -map -
- -index -chess.lua:187 -
Function: index [chess.lua:187] (500 usec, 1.6%); self: 500 usec
- - -index -
- -rshift -chess.lua:176 -
Function: rshift [chess.lua:176] (100 usec, 0.3%); self: 100 usec
- - -rshift -
- -rshift -chess.lua:176 -
Function: rshift [chess.lua:176] (300 usec, 1.0%); self: 100 usec
- - -rshift -
- -inverse -chess.lua:133 -
Function: inverse [chess.lua:133] (400 usec, 1.3%); self: 100 usec
- - -inverse -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (200 usec, 0.6%); self: 100 usec
- - -band -
- -lshift -chess.lua:164 -
Function: lshift [chess.lua:164] (500 usec, 1.6%); self: 200 usec
- - -lshift -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (300 usec, 1.0%); self: 200 usec
- - -band -
- -inverse -chess.lua:133 -
Function: inverse [chess.lua:133] (300 usec, 1.0%); self: 0 usec
- - -inverse -
- -lshift -chess.lua:164 -
Function: lshift [chess.lua:164] (100 usec, 0.3%); self: 100 usec
- - -lshift -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (500 usec, 1.6%); self: 300 usec
- - -from -
- -set -chess.lua:195 -
Function: set [chess.lua:195] (100 usec, 0.3%); self: 0 usec
- - -set -
- -GC -
Function: GC [GC:0] (300 usec, 1.0%); self: 300 usec
- - -GC -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -isolate -chess.lua:203 -
Function: isolate [chess.lua:203] (200 usec, 0.6%); self: 0 usec
- - -isolate -
- -inverse -chess.lua:133 -
Function: inverse [chess.lua:133] (200 usec, 0.6%); self: 200 usec
- - -inverse -
- -lshift -chess.lua:164 -
Function: lshift [chess.lua:164] (900 usec, 2.9%); self: 300 usec
- -ls.. -lshift -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (400 usec, 1.3%); self: 300 usec
- - -band -
- -lshift -chess.lua:164 -
Function: lshift [chess.lua:164] (200 usec, 0.6%); self: 100 usec
- - -lshift -
- -rshift -chess.lua:176 -
Function: rshift [chess.lua:176] (400 usec, 1.3%); self: 100 usec
- - -rshift -
- -inverse -chess.lua:133 -
Function: inverse [chess.lua:133] (500 usec, 1.6%); self: 0 usec
- - -inverse -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (100 usec, 0.3%); self: 0 usec
- - -band -
- -rshift -chess.lua:176 -
Function: rshift [chess.lua:176] (100 usec, 0.3%); self: 0 usec
- - -rshift -
- -GC -
Function: GC [GC:0] (800 usec, 2.6%); self: 800 usec
- -GC -GC -
- -updateCache -chess.lua:283 -
Function: updateCache [chess.lua:283] (3,800 usec, 12.2%); self: 500 usec
- -updateCache -updateCache -
- - -chess.lua:305 -
Function: [chess.lua:305] (7,700 usec, 24.8%); self: 300 usec
- - - -
- -new -chess.lua:228 -
Function: new [chess.lua:228] (1,200 usec, 3.9%); self: 1,000 usec
- -new -new -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (200 usec, 0.6%); self: 100 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (300 usec, 1.0%); self: 200 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (300 usec, 1.0%); self: 200 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (300 usec, 1.0%); self: 200 usec
- - -from -
- -GC -
Function: GC [GC:0] (200 usec, 0.6%); self: 200 usec
- - -GC -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -some -chess.lua:207 -
Function: some [chess.lua:207] (100 usec, 0.3%); self: 0 usec
- - -some -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (100 usec, 0.3%); self: 0 usec
- - -band -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (600 usec, 1.9%); self: 200 usec
- -f.. -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 100 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (300 usec, 1.0%); self: 100 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (500 usec, 1.6%); self: 100 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -bor -chess.lua:129 -
Function: bor [chess.lua:129] (2,900 usec, 9.3%); self: 900 usec
- -bor -bor -
- -inverse -chess.lua:133 -
Function: inverse [chess.lua:133] (400 usec, 1.3%); self: 100 usec
- - -inverse -
- -isolate -chess.lua:203 -
Function: isolate [chess.lua:203] (7,400 usec, 23.8%); self: 700 usec
- -isolate -isolate -
- -GC -
Function: GC [GC:0] (200 usec, 0.6%); self: 200 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -set -chess.lua:195 -
Function: set [chess.lua:195] (100 usec, 0.3%); self: 0 usec
- - -set -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 0 usec
- - -from -
- -GC -
Function: GC [GC:0] (400 usec, 1.3%); self: 400 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (200 usec, 0.6%); self: 200 usec
- - -GC -
- -GC -
Function: GC [GC:0] (400 usec, 1.3%); self: 400 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (2,000 usec, 6.4%); self: 1,100 usec
- -from -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (300 usec, 1.0%); self: 0 usec
- - -from -
- -some -chess.lua:207 -
Function: some [chess.lua:207] (4,100 usec, 13.2%); self: 1,000 usec
- -some -some -
- -band -chess.lua:125 -
Function: band [chess.lua:125] (2,600 usec, 8.4%); self: 900 usec
- -band -band -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (100 usec, 0.3%); self: 100 usec
- - -from -
- -GC -
Function: GC [GC:0] (100 usec, 0.3%); self: 100 usec
- - -GC -
- -GC -
Function: GC [GC:0] (900 usec, 2.9%); self: 900 usec
- -GC -GC -
- -GC -
Function: GC [GC:0] (300 usec, 1.0%); self: 300 usec
- - -GC -
- -set -chess.lua:195 -
Function: set [chess.lua:195] (3,100 usec, 10.0%); self: 1,500 usec
- -set -set -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (1,700 usec, 5.5%); self: 800 usec
- -from -from -
- -from -chess.lua:75 -
Function: from [chess.lua:75] (1,600 usec, 5.1%); self: 1,000 usec
- -from -from -
- -GC -
Function: GC [GC:0] (900 usec, 2.9%); self: 900 usec
- -GC -GC -
- -GC -
Function: GC [GC:0] (600 usec, 1.9%); self: 600 usec
- -GC -GC -
-
-
- diff --git a/docs/assets/images/create-new-place.png b/docs/assets/images/create-new-place.png deleted file mode 100644 index 63a29242d..000000000 Binary files a/docs/assets/images/create-new-place.png and /dev/null differ diff --git a/docs/assets/images/create-script.png b/docs/assets/images/create-script.png deleted file mode 100644 index a0481ef0c..000000000 Binary files a/docs/assets/images/create-script.png and /dev/null differ diff --git a/docs/assets/images/error-isfoo.png b/docs/assets/images/error-isfoo.png deleted file mode 100644 index aa33ca6f2..000000000 Binary files a/docs/assets/images/error-isfoo.png and /dev/null differ diff --git a/docs/assets/images/error-ispositive-boolean.png b/docs/assets/images/error-ispositive-boolean.png deleted file mode 100644 index 643c4da3d..000000000 Binary files a/docs/assets/images/error-ispositive-boolean.png and /dev/null differ diff --git a/docs/assets/images/error-ispositive-string.png b/docs/assets/images/error-ispositive-string.png deleted file mode 100644 index 32cce31bf..000000000 Binary files a/docs/assets/images/error-ispositive-string.png and /dev/null differ diff --git a/docs/assets/images/error-ispositive.png b/docs/assets/images/error-ispositive.png deleted file mode 100644 index 183856c3c..000000000 Binary files a/docs/assets/images/error-ispositive.png and /dev/null differ diff --git a/docs/assets/images/example.png b/docs/assets/images/example.png deleted file mode 100644 index 2f4eedeff..000000000 Binary files a/docs/assets/images/example.png and /dev/null differ diff --git a/docs/assets/images/luau-88.png b/docs/assets/images/luau-88.png deleted file mode 100644 index f9571b8a0..000000000 Binary files a/docs/assets/images/luau-88.png and /dev/null differ diff --git a/docs/assets/images/luau-header.png b/docs/assets/images/luau-header.png deleted file mode 100644 index 5e88da325..000000000 Binary files a/docs/assets/images/luau-header.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-august-2020-arrow.png b/docs/assets/images/luau-recap-august-2020-arrow.png deleted file mode 100644 index 282bc5a0f..000000000 Binary files a/docs/assets/images/luau-recap-august-2020-arrow.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-august-2020-format.png b/docs/assets/images/luau-recap-august-2020-format.png deleted file mode 100644 index d180b442a..000000000 Binary files a/docs/assets/images/luau-recap-august-2020-format.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-august-2020-format2.png b/docs/assets/images/luau-recap-august-2020-format2.png deleted file mode 100644 index c98fbf3d9..000000000 Binary files a/docs/assets/images/luau-recap-august-2020-format2.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-august-2020-meta.png b/docs/assets/images/luau-recap-august-2020-meta.png deleted file mode 100644 index c7396f92f..000000000 Binary files a/docs/assets/images/luau-recap-august-2020-meta.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-february-2021-benchmark.png b/docs/assets/images/luau-recap-february-2021-benchmark.png deleted file mode 100644 index a68332523..000000000 Binary files a/docs/assets/images/luau-recap-february-2021-benchmark.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-february-2021-debugger.png b/docs/assets/images/luau-recap-february-2021-debugger.png deleted file mode 100644 index aca246450..000000000 Binary files a/docs/assets/images/luau-recap-february-2021-debugger.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-june-2020-xkcd.png b/docs/assets/images/luau-recap-june-2020-xkcd.png deleted file mode 100644 index b7e2ed8ef..000000000 Binary files a/docs/assets/images/luau-recap-june-2020-xkcd.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-march-2021-debug-after.png b/docs/assets/images/luau-recap-march-2021-debug-after.png deleted file mode 100644 index 4af00ba70..000000000 Binary files a/docs/assets/images/luau-recap-march-2021-debug-after.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-march-2021-debug-before.png b/docs/assets/images/luau-recap-march-2021-debug-before.png deleted file mode 100644 index b5b6e0145..000000000 Binary files a/docs/assets/images/luau-recap-march-2021-debug-before.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-march-2021-debug-dialog.png b/docs/assets/images/luau-recap-march-2021-debug-dialog.png deleted file mode 100644 index 0049b0bd6..000000000 Binary files a/docs/assets/images/luau-recap-march-2021-debug-dialog.png and /dev/null differ diff --git a/docs/assets/images/luau-recap-november-2019-option.png b/docs/assets/images/luau-recap-november-2019-option.png deleted file mode 100644 index 3b2b93c14..000000000 Binary files a/docs/assets/images/luau-recap-november-2019-option.png and /dev/null differ diff --git a/docs/assets/images/luau-type-checking-release-screenshot.png b/docs/assets/images/luau-type-checking-release-screenshot.png deleted file mode 100644 index 5afc78ef3..000000000 Binary files a/docs/assets/images/luau-type-checking-release-screenshot.png and /dev/null differ diff --git a/docs/assets/images/luau-type-checking-release-studio-option.png b/docs/assets/images/luau-type-checking-release-studio-option.png deleted file mode 100644 index e79aa8311..000000000 Binary files a/docs/assets/images/luau-type-checking-release-studio-option.png and /dev/null differ diff --git a/docs/assets/images/luau.png b/docs/assets/images/luau.png deleted file mode 100644 index c0ecbbb86..000000000 Binary files a/docs/assets/images/luau.png and /dev/null differ diff --git a/docs/assets/images/type-annotation-needed.png b/docs/assets/images/type-annotation-needed.png deleted file mode 100644 index 25307327b..000000000 Binary files a/docs/assets/images/type-annotation-needed.png and /dev/null differ diff --git a/docs/assets/images/type-annotation-provided.png b/docs/assets/images/type-annotation-provided.png deleted file mode 100644 index 7bac0340f..000000000 Binary files a/docs/assets/images/type-annotation-provided.png and /dev/null differ diff --git a/docs/assets/images/type-error-after-syntax-error.png b/docs/assets/images/type-error-after-syntax-error.png deleted file mode 100644 index 6d8b23a92..000000000 Binary files a/docs/assets/images/type-error-after-syntax-error.png and /dev/null differ diff --git a/docs/assets/images/type-refinement-in-action.png b/docs/assets/images/type-refinement-in-action.png deleted file mode 100644 index e3acf8eda..000000000 Binary files a/docs/assets/images/type-refinement-in-action.png and /dev/null differ diff --git a/docs/assets/js/luau_mode.js b/docs/assets/js/luau_mode.js deleted file mode 100644 index 0a0b933a6..000000000 --- a/docs/assets/js/luau_mode.js +++ /dev/null @@ -1,167 +0,0 @@ -// CodeMirror, copyright (c) by Marijn Haverbeke and others -// Distributed under an MIT license: https://codemirror.net/LICENSE - -// Luau mode. Based on Lua mode from CodeMirror and Franciszek Wawrzak (https://codemirror.net/mode/lua/lua.js) - -(function(mod) { - if (typeof exports == "object" && typeof module == "object") // CommonJS - mod(require("../../lib/codemirror")); - else if (typeof define == "function" && define.amd) // AMD - define(["../../lib/codemirror"], mod); - else // Plain browser env - mod(CodeMirror); -})(function(CodeMirror) { - "use strict"; - - CodeMirror.defineMode("luau", function(_, parserConfig) { - var indentUnit = 4; - - function prefixRE(words) { - return new RegExp("^(?:" + words.join("|") + ")", "i"); - } - function wordRE(words) { - return new RegExp("^(?:" + words.join("|") + ")$", "i"); - } - var specials = wordRE(parserConfig.specials || ["type"]); - - // long list of standard functions from lua manual - var builtins = wordRE([ - "_G","_VERSION","assert","error","getfenv","getmetatable","ipairs","load", "loadstring","next","pairs","pcall", - "print","rawequal","rawget","rawset","require","select","setfenv","setmetatable","tonumber","tostring","type", - "unpack","xpcall", - - "coroutine.create","coroutine.resume","coroutine.running","coroutine.status","coroutine.wrap","coroutine.yield", - - "debug.debug","debug.getfenv","debug.gethook","debug.getinfo","debug.getlocal","debug.getmetatable", - "debug.getregistry","debug.getupvalue","debug.setfenv","debug.sethook","debug.setlocal","debug.setmetatable", - "debug.setupvalue","debug.traceback", - - "math.abs","math.acos","math.asin","math.atan","math.atan2","math.ceil","math.cos","math.cosh","math.deg", - "math.exp","math.floor","math.fmod","math.frexp","math.huge","math.ldexp","math.log","math.log10","math.max", - "math.min","math.modf","math.pi","math.pow","math.rad","math.random","math.randomseed","math.sin","math.sinh", - "math.sqrt","math.tan","math.tanh", - - "os.clock","os.date","os.difftime","os.execute","os.exit","os.getenv","os.remove","os.rename","os.setlocale", - "os.time","os.tmpname", - - "string.byte","string.char","string.dump","string.find","string.format","string.gmatch","string.gsub", - "string.len","string.lower","string.match","string.rep","string.reverse","string.sub","string.upper", - - "table.concat","table.insert","table.maxn","table.remove","table.sort" - ]); - var keywords = wordRE(["and","break","elseif","false","nil","not","or","return", - "true","function", "end", "if", "then", "else", "do", - "while", "repeat", "until", "for", "in", "local", "continue" ]); - - var indentTokens = wordRE(["function", "if","repeat","do", "\\(", "{"]); - var dedentTokens = wordRE(["end", "until", "\\)", "}"]); - var dedentPartial = prefixRE(["end", "until", "\\)", "}", "else", "elseif"]); - - function readBracket(stream) { - var level = 0; - while (stream.eat("=")) ++level; - stream.eat("["); - return level; - } - - function normal(stream, state) { - var ch = stream.next(); - if (ch == "-" && stream.eat("-")) { - if (stream.eat("[") && stream.eat("[")) - return (state.cur = bracketed(readBracket(stream), "comment"))(stream, state); - stream.skipToEnd(); - return "comment"; - } - if (ch == "\"" || ch == "'") - return (state.cur = string(ch))(stream, state); - if (ch == "[" && /[\[=]/.test(stream.peek())) - return (state.cur = bracketed(readBracket(stream), "string"))(stream, state); - if (/\d/.test(ch)) { - stream.eatWhile(/[\w.%]/); - return "number"; - } - if (/[\w_]/.test(ch)) { - stream.eatWhile(/[\w\\\-_.]/); - return "variable"; - } - return null; - } - - function bracketed(level, style) { - return function(stream, state) { - var curlev = null, ch; - while ((ch = stream.next()) != null) { - if (curlev == null) { - if (ch == "]") curlev = 0; - } else if (ch == "=") { - ++curlev; - } else if (ch == "]" && curlev == level) { - state.cur = normal; - break; - } else { - curlev = null; - } - } - return style; - }; - } - - function string(quote) { - return function(stream, state) { - var escaped = false, ch; - while ((ch = stream.next()) != null) { - if (ch == quote && !escaped) { - break; - } - escaped = !escaped && ch == "\\"; - } - - if (!escaped) { - state.cur = normal; - } - return "string"; - }; - } - - return { - startState: function(basecol) { - return {basecol: basecol || 0, indentDepth: 0, cur: normal}; - }, - - token: function(stream, state) { - if (stream.eatSpace()) { - return null; - } - var style = state.cur(stream, state); - var word = stream.current(); - if (style == "variable") { - if (keywords.test(word)) { - style = "keyword"; - } else if (builtins.test(word)) { - style = "builtin"; - } else if (specials.test(word)) { - style = "variable-2"; - } - } - if ((style != "comment") && (style != "string")) { - if (indentTokens.test(word)) { - ++state.indentDepth; - } else if (dedentTokens.test(word)) { - --state.indentDepth; - } - } - return style; - }, - - indent: function(state, textAfter) { - var closing = dedentPartial.test(textAfter); - return state.basecol + indentUnit * (state.indentDepth - (closing ? 1 : 0)); - }, - - electricInput: /^\s*(?:end|until|else|\)|\})$/, - lineComment: "--", - blockCommentStart: "--[[", - blockCommentEnd: "]]" - }}); - CodeMirror.defineMIME("text/x-luau", "luau"); -}); \ No newline at end of file diff --git a/docs/favicon.ico b/docs/favicon.ico deleted file mode 100644 index 5b0660965..000000000 Binary files a/docs/favicon.ico and /dev/null differ diff --git a/docs/index.md b/docs/index.md deleted file mode 100644 index 3bea3b153..000000000 --- a/docs/index.md +++ /dev/null @@ -1,59 +0,0 @@ ---- -title: Lua*u* -layout: splash -permalink: / - -header: - overlay_color: #000 - overlay_filter: 0.8 - overlay_image: /assets/images/luau-header.png - -excerpt: > - Lua*u* (lowercase *u*, /ˈlu.aʊ/) is a fast, small, safe, gradually typed embeddable scripting language derived from Lua. - -feature_row1: - - - title: Motivation - excerpt: > - Around 2006, [Roblox](https://www.roblox.com) started using Lua 5.1 as a scripting language for games. Over the years we ended up substantially evolving the implementation and the language; to support growing sophistication of games on the Roblox platform, growing team sizes and large internal teams writing a lot of code for application/editor (1+MLOC as of 2020), we had to invest in performance, ease of use and language tooling, and introduce a gradual type system to the language. [More...](/why) - - - - title: Sandboxing - excerpt: > - Luau limits the set of standard libraries exposed to the users and implements extra sandboxing features to be able to run unprivileged code (written by our game developers) side by side with privileged code (written by us). This results in an execution environment that is different from what is commonplace in Lua. [More...](/sandbox) - - - - title: Compatibility - excerpt: > - Whenever possible, Luau aims to be backwards-compatible with Lua 5.1 and at the same time to incorporate features from later revisions of Lua. However, Luau is not a full superset of later versions of Lua - we do not always agree with Lua design decisions, and have different use cases and constraints. All post-5.1 Lua features, along with their support status in Luau, [are documented here](compatibility). - -feature_row2: - - - title: Syntax - image_path: /assets/images/example.png - excerpt: > - Luau is syntactically backwards-compatible with Lua 5.1 (code that is valid Lua 5.1 is also valid Luau); however, we have extended the language with a set of syntactical features that make the language more familiar and ergonomic. The syntax [is described here](syntax). - -feature_row3: - - - title: Analysis - excerpt: > - To make it easier to write correct code, Luau comes with a set of analysis tools that can surface common mistakes. These consist of a linter and a type checker, colloquially known as script analysis, and are integrated into `luau-analyze` command line executable. The linting passes are [described here](lint), and the type checking user guide can [be found here](typecheck). - - - - title: Performance - excerpt: > - In addition to a completely custom front end that implements parsing, linting and type checking, Luau runtime features new bytecode, interpreter and compiler that are heavily tuned for performance. Luau currently does not implement Just-In-Time compilation, but its interpreter can be competitive with LuaJIT interpreter depending on the program. We continue to optimize the runtime and rewrite portions of it to be even more efficient. While our overall goal is to minimize the amount of time programmers spend tuning performance, some details about the performance characteristics are [provided for inquisitive minds](performance). - - - - title: Libraries - excerpt: > - As a language, Luau is a full superset of Lua 5.1. As far as standard library is concerned, some functions had to be removed from the builtin libraries, and some functions had to be added; refer to [full documentation](/library) for details. When Luau is embedded into an application, the scripts normally get access to extra library features that are application-specific. - ---- - -{% include feature_row id="feature_row1" %} - -{% include feature_row id="feature_row2" type="left" %} - -{% include feature_row id="feature_row3" %} diff --git a/docs/logo.svg b/docs/logo.svg deleted file mode 100644 index 55253947a..000000000 --- a/docs/logo.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/extern/doctest.h b/extern/doctest.h index aa2724c73..42d93c7f3 100644 --- a/extern/doctest.h +++ b/extern/doctest.h @@ -3139,7 +3139,11 @@ DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN #include #include #include + +#if !defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) #include +#endif + #include #include #include @@ -5667,6 +5671,8 @@ namespace { std::tm timeInfo; #ifdef DOCTEST_PLATFORM_WINDOWS gmtime_s(&timeInfo, &rawtime); +#elif defined(DOCTEST_CONFIG_USE_GMTIME_S) + gmtime_s(&rawtime, &timeInfo); #else // DOCTEST_PLATFORM_WINDOWS gmtime_r(&rawtime, &timeInfo); #endif // DOCTEST_PLATFORM_WINDOWS diff --git a/fuzz/CMakeLists.txt b/fuzz/CMakeLists.txt new file mode 100644 index 000000000..be40b811e --- /dev/null +++ b/fuzz/CMakeLists.txt @@ -0,0 +1,108 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +if(${CMAKE_VERSION} VERSION_LESS "3.26") + message(WARNING "Building the Luau fuzzer requires CMake version 3.26 or higher.") + return() +endif() + +include(FetchContent) + +cmake_policy(SET CMP0054 NEW) +cmake_policy(SET CMP0058 NEW) +cmake_policy(SET CMP0074 NEW) +cmake_policy(SET CMP0077 NEW) +cmake_policy(SET CMP0091 NEW) + +if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + message(WARNING "Building the Luau fuzzer requires Clang to be used. AppleClang is not sufficient.") + return() +endif() + +if(NOT CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64") + message(WARNING "Building the Luau fuzzer for ARM64 is currently unsupported.") + return() +endif() + +# protobuf / std integer types vary based on platform; disable sign-compare +# warnings for portability. +set(FUZZ_COMPILE_OPTIONS ${LUAU_OPTIONS} -fsanitize=address,fuzzer -g2 -Wno-sign-compare) +set(FUZZ_LINK_OPTIONS ${LUAU_OPTIONS} -fsanitize=address,fuzzer) + +FetchContent_Declare( + ProtobufMutator + GIT_REPOSITORY https://github.com/google/libprotobuf-mutator + GIT_TAG 212a7be1eb08e7f9c79732d2aab9b2097085d936 + # libprotobuf-mutator unconditionally configures its examples, but this + # doesn't actually work with how we're building Protobuf from source. This + # patch disables configuration of the examples. + PATCH_COMMAND + git apply + --reverse + --check + --ignore-space-change + --ignore-whitespace + "${CMAKE_CURRENT_SOURCE_DIR}/libprotobuf-mutator-patch.patch" + || + git apply + --ignore-space-change + --ignore-whitespace + "${CMAKE_CURRENT_SOURCE_DIR}/libprotobuf-mutator-patch.patch" +) + +FetchContent_Declare( + Protobuf + GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git + # Needs to match the Protobuf version that libprotobuf-mutator is written for, roughly. + GIT_TAG v22.3 + GIT_SHALLOW ON + + # libprotobuf-mutator will need to be able to find this at configuration + # time. + OVERRIDE_FIND_PACKAGE +) + +set(protobuf_BUILD_TESTS OFF) +set(protobuf_BUILD_SHARED_LIBS OFF) +# libprotobuf-mutator relies on older module support. +set(protobuf_MODULE_COMPATIBLE ON) + +find_package(Protobuf CONFIG REQUIRED) + +# libprotobuf-mutator happily ignores CMP0077 because of its minimum version +# requirement. To override that, we set the policy default here. +set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) +set(LIB_PROTO_MUTATOR_TESTING OFF) + +FetchContent_MakeAvailable(ProtobufMutator) + +# This patches around the fact that find_package isn't going to set the right +# values for libprotobuf-mutator to link against protobuf libraries. +target_link_libraries(protobuf-mutator-libfuzzer protobuf::libprotobuf) +target_link_libraries(protobuf-mutator protobuf::libprotobuf) + +set(LUAU_PB_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf) +set(LUAU_PB_SOURCES ${LUAU_PB_DIR}/luau.pb.cc ${LUAU_PB_DIR}/luau.pb.h) + +add_custom_command( + OUTPUT ${LUAU_PB_SOURCES} + COMMAND ${CMAKE_COMMAND} -E make_directory ${LUAU_PB_DIR} + COMMAND $ ${CMAKE_CURRENT_SOURCE_DIR}/luau.proto --proto_path=${CMAKE_CURRENT_SOURCE_DIR} --cpp_out=${LUAU_PB_DIR} + DEPENDS protobuf::protoc ${CMAKE_CURRENT_SOURCE_DIR}/luau.proto +) + +add_executable(Luau.Fuzz.Proto) +target_compile_options(Luau.Fuzz.Proto PRIVATE ${FUZZ_COMPILE_OPTIONS}) +target_link_options(Luau.Fuzz.Proto PRIVATE ${FUZZ_LINK_OPTIONS}) +target_compile_features(Luau.Fuzz.Proto PRIVATE cxx_std_17) +target_include_directories(Luau.Fuzz.Proto PRIVATE ${LUAU_PB_DIR} ${protobufmutator_SOURCE_DIR}) +target_sources(Luau.Fuzz.Proto PRIVATE ${LUAU_PB_SOURCES} proto.cpp protoprint.cpp) +target_link_libraries(Luau.Fuzz.Proto PRIVATE protobuf::libprotobuf protobuf-mutator-libfuzzer protobuf-mutator Luau.Analysis Luau.Compiler Luau.Ast Luau.Config Luau.VM Luau.CodeGen) +set_target_properties(Luau.Fuzz.Proto PROPERTIES CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF OUTPUT_NAME fuzz-proto) + +add_executable(Luau.Fuzz.ProtoTest) +target_compile_options(Luau.Fuzz.ProtoTest PRIVATE ${FUZZ_COMPILE_OPTIONS}) +target_link_options(Luau.Fuzz.ProtoTest PRIVATE ${FUZZ_LINK_OPTIONS}) +target_compile_features(Luau.Fuzz.ProtoTest PRIVATE cxx_std_17) +target_include_directories(Luau.Fuzz.ProtoTest PRIVATE ${LUAU_PB_DIR} ${protobufmutator_SOURCE_DIR}) +target_sources(Luau.Fuzz.ProtoTest PRIVATE ${LUAU_PB_SOURCES} prototest.cpp protoprint.cpp) +target_link_libraries(Luau.Fuzz.ProtoTest PRIVATE protobuf::libprotobuf protobuf-mutator-libfuzzer protobuf-mutator) +set_target_properties(Luau.Fuzz.ProtoTest PROPERTIES CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF OUTPUT_NAME fuzz-prototest) diff --git a/fuzz/format.cpp b/fuzz/format.cpp index 3ad3912f3..4b943bf1b 100644 --- a/fuzz/format.cpp +++ b/fuzz/format.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Common.h" +#include #include #include diff --git a/fuzz/libprotobuf-mutator-patch.patch b/fuzz/libprotobuf-mutator-patch.patch new file mode 100644 index 000000000..ee41a5403 --- /dev/null +++ b/fuzz/libprotobuf-mutator-patch.patch @@ -0,0 +1,12 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 4805c82..9f0df5c 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -149,7 +149,6 @@ add_subdirectory(src) + + if (NOT "${LIB_PROTO_MUTATOR_FUZZER_LIBRARIES}" STREQUAL "" OR + NOT "${FUZZING_FLAGS}" STREQUAL "") +- add_subdirectory(examples EXCLUDE_FROM_ALL) + endif() + + install(EXPORT libprotobuf-mutatorTargets FILE libprotobuf-mutatorTargets.cmake diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 66ca5bb14..8efd42469 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -3,10 +3,10 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/Frontend.h" #include "Luau/Linter.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) { @@ -18,18 +18,17 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); // "static" here is to accelerate fuzzing process by only creating and populating the type environment once - static Luau::NullModuleResolver moduleResolver; - static Luau::InternalErrorReporter iceHandler; - static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1); + static Luau::NullFileResolver fileResolver; + static Luau::NullConfigResolver configResolver; + static Luau::Frontend frontend{&fileResolver, &configResolver}; + static int once = (Luau::registerBuiltinGlobals(frontend, frontend.globals, false), 1); (void)once; - static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); + static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1); (void)once2; if (parseResult.errors.empty()) { - Luau::TypeChecker typeck(&moduleResolver, &iceHandler); - typeck.globalScope = sharedEnv.globalScope; + Luau::TypeChecker typeck(frontend.globals.globalScope, &frontend.moduleResolver, frontend.builtinTypes, &frontend.iceHandler); Luau::LintOptions lintOptions; lintOptions.warningMask = ~0ull; diff --git a/fuzz/luau.proto b/fuzz/luau.proto index e51d687bd..5fed9ddc9 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -135,17 +135,18 @@ message ExprBinary { Sub = 1; Mul = 2; Div = 3; - Mod = 4; - Pow = 5; - Concat = 6; - CompareNe = 7; - CompareEq = 8; - CompareLt = 9; - CompareLe = 10; - CompareGt = 11; - CompareGe = 12; - And = 13; - Or = 14; + FloorDiv = 4; + Mod = 5; + Pow = 6; + Concat = 7; + CompareNe = 8; + CompareEq = 9; + CompareLt = 10; + CompareLe = 11; + CompareGt = 12; + CompareGe = 13; + And = 14; + Or = 15; } required Op op = 1; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 5ca165765..9c2ab35c2 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeGen.h" #include "Luau/Common.h" #include "Luau/Compiler.h" +#include "Luau/Config.h" #include "Luau/Frontend.h" #include "Luau/Linter.h" #include "Luau/ModuleResolver.h" @@ -19,20 +20,31 @@ #include "lualib.h" #include +#include + +static bool getEnvParam(const char* name, bool def) +{ + char* val = getenv(name); + if (val == nullptr) + return def; + else + return strcmp(val, "0") != 0; +} // Select components to fuzz -const bool kFuzzCompiler = true; -const bool kFuzzLinter = true; -const bool kFuzzTypeck = true; -const bool kFuzzVM = true; -const bool kFuzzTranspile = true; -const bool kFuzzCodegen = true; +const bool kFuzzCompiler = getEnvParam("LUAU_FUZZ_COMPILER", true); +const bool kFuzzLinter = getEnvParam("LUAU_FUZZ_LINTER", true); +const bool kFuzzTypeck = getEnvParam("LUAU_FUZZ_TYPE_CHECK", true); +const bool kFuzzVM = getEnvParam("LUAU_FUZZ_VM", true); +const bool kFuzzTranspile = getEnvParam("LUAU_FUZZ_TRANSPILE", true); +const bool kFuzzCodegenVM = getEnvParam("LUAU_FUZZ_CODEGEN_VM", true); +const bool kFuzzCodegenAssembly = getEnvParam("LUAU_FUZZ_CODEGEN_ASM", true); +const bool kFuzzUseNewSolver = getEnvParam("LUAU_FUZZ_NEW_SOLVER", false); // Should we generate type annotations? -const bool kFuzzTypes = true; +const bool kFuzzTypes = getEnvParam("LUAU_FUZZ_GEN_TYPES", true); -static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); -static_assert(!(kFuzzCodegen && !kFuzzVM), "Codegen requires the VM!"); +const Luau::CodeGen::AssemblyOptions::Target kFuzzCodegenTarget = Luau::CodeGen::AssemblyOptions::A64; std::vector protoprint(const luau::ModuleSet& stat, bool types); @@ -43,6 +55,8 @@ LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauAbortingChecks) +LUAU_FASTFLAG(LuauSolverV2) std::chrono::milliseconds kInterruptTimeout(10); std::chrono::time_point interruptDeadline; @@ -86,7 +100,7 @@ lua_State* createGlobalState() { lua_State* L = lua_newstate(allocate, NULL); - if (kFuzzCodegen && Luau::CodeGen::isSupported()) + if (kFuzzCodegenVM && Luau::CodeGen::isSupported()) Luau::CodeGen::create(L); lua_callbacks(L)->interrupt = interrupt; @@ -97,48 +111,49 @@ lua_State* createGlobalState() return L; } -int registerTypes(Luau::TypeChecker& env) +int registerTypes(Luau::Frontend& frontend, Luau::GlobalTypes& globals, bool forAutocomplete) { using namespace Luau; using std::nullopt; - Luau::registerBuiltinGlobals(env); + Luau::registerBuiltinGlobals(frontend, globals, forAutocomplete); - TypeArena& arena = env.globalTypes; + TypeArena& arena = globals.globalTypes; + BuiltinTypes& builtinTypes = *globals.builtinTypes; // Vector3 stub TypeId vector3MetaType = arena.addType(TableType{}); - TypeId vector3InstanceType = arena.addType(ClassType{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); + TypeId vector3InstanceType = arena.addType(ClassType{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test", {}}); getMutable(vector3InstanceType)->props = { - {"X", {env.numberType}}, - {"Y", {env.numberType}}, - {"Z", {env.numberType}}, + {"X", {builtinTypes.numberType}}, + {"Y", {builtinTypes.numberType}}, + {"Z", {builtinTypes.numberType}}, }; getMutable(vector3MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector3InstanceType, vector3InstanceType}, {vector3InstanceType})}}, }; - env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; + globals.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub - TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(instanceType)->props = { - {"Name", {env.stringType}}, + {"Name", {builtinTypes.stringType}}, }; - env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub - TypeId partType = arena.addType(ClassType{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); + TypeId partType = arena.addType(ClassType{"Part", {}, instanceType, nullopt, {}, {}, "Test", {}}); getMutable(partType)->props = { {"Position", {vector3InstanceType}}, }; - env.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, partType}; + globals.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, partType}; - for (const auto& [_, fun] : env.globalScope->exportedTypeBindings) + for (const auto& [_, fun] : globals.globalScope->exportedTypeBindings) persist(fun.type); return 0; @@ -146,13 +161,14 @@ int registerTypes(Luau::TypeChecker& env) static void setupFrontend(Luau::Frontend& frontend) { - registerTypes(frontend.typeChecker); - Luau::freeze(frontend.typeChecker.globalTypes); + registerTypes(frontend, frontend.globals, false); + Luau::freeze(frontend.globals.globalTypes); - registerTypes(frontend.typeCheckerForAutocomplete); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + registerTypes(frontend, frontend.globalsForAutocomplete, true); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); - frontend.iceHandler.onInternalError = [](const char* error) { + frontend.iceHandler.onInternalError = [](const char* error) + { printf("ICE: %s\n", error); LUAU_ASSERT(!"ICE"); }; @@ -211,6 +227,13 @@ static std::vector debugsources; DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { + if (!kFuzzCompiler && (kFuzzCodegenAssembly || kFuzzCodegenVM || kFuzzVM)) + { + printf("Compiler is required in order to fuzz codegen or the VM\n"); + LUAU_ASSERT(false); + return; + } + FInt::LuauTypeInferRecursionLimit.value = 100; FInt::LuauTypeInferTypePackLoopLimit.value = 100; FInt::LuauCheckRecursionLimit.value = 100; @@ -223,6 +246,8 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) flag->value = true; FFlag::DebugLuauFreezeArena.value = true; + FFlag::DebugLuauAbortingChecks.value = true; + FFlag::LuauSolverV2.value = kFuzzUseNewSolver; std::vector sources = protoprint(message, kFuzzTypes); @@ -260,10 +285,11 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { static FuzzFileResolver fileResolver; static FuzzConfigResolver configResolver; - static Luau::FrontendOptions options{true, true}; - static Luau::Frontend frontend(&fileResolver, &configResolver, options); + static Luau::FrontendOptions defaultOptions{/*retainFullTypeGraphs*/ true, /*forAutocomplete*/ false, /*runLintChecks*/ kFuzzLinter}; + static Luau::Frontend frontend(&fileResolver, &configResolver, defaultOptions); static int once = (setupFrontend(frontend), 0); + (void)once; // restart frontend.clear(); @@ -283,16 +309,12 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) try { - Luau::CheckResult result = frontend.check(name, std::nullopt); - - // lint (note that we need access to types so we need to do this with typeck in scope) - if (kFuzzLinter && result.errors.empty()) - frontend.lint(name, std::nullopt); + frontend.check(name); // Second pass in strict mode (forced by auto-complete) - Luau::FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(name, opts); + Luau::FrontendOptions options = defaultOptions; + options.forAutocomplete = true; + frontend.check(name, options); } catch (std::exception&) { @@ -302,7 +324,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down // note: it's important for typeck to be destroyed at this point! - for (auto& p : frontend.typeChecker.globalScope->bindings) + for (auto& p : frontend.globals.globalScope->bindings) { Luau::ToStringOptions opts; opts.exhaustive = true; @@ -350,19 +372,38 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) } } + // run codegen on resulting bytecode (in separate state) + if (kFuzzCodegenAssembly && bytecode.size()) + { + static lua_State* globalState = luaL_newstate(); + + if (luau_load(globalState, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) + { + Luau::CodeGen::AssemblyOptions options; + options.compilationOptions.flags = Luau::CodeGen::CodeGen_ColdFunctions; + options.outputBinary = true; + options.target = kFuzzCodegenTarget; + Luau::CodeGen::getAssembly(globalState, -1, options); + } + + lua_pop(globalState, 1); + lua_gc(globalState, LUA_GCCOLLECT, 0); + } + // run resulting bytecode (from last successfully compiler module) - if (kFuzzVM && bytecode.size()) + if ((kFuzzVM || kFuzzCodegenVM) && bytecode.size()) { static lua_State* globalState = createGlobalState(); - auto runCode = [](const std::string& bytecode, bool useCodegen) { + auto runCode = [](const std::string& bytecode, bool useCodegen) + { lua_State* L = lua_newthread(globalState); luaL_sandboxthread(L); if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) { if (useCodegen) - Luau::CodeGen::compile(L, -1); + Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions); interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; @@ -376,9 +417,10 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) LUAU_ASSERT(heapSize < 256 * 1024); }; - runCode(bytecode, false); + if (kFuzzVM) + runCode(bytecode, false); - if (kFuzzCodegen && Luau::CodeGen::isSupported()) + if (kFuzzCodegenVM && Luau::CodeGen::isSupported()) runCode(bytecode, true); } } diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 5c7c5bf60..0adc09681 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -2,172 +2,38 @@ #include "luau.pb.h" static const std::string kNames[] = { - "_G", - "_LOADED", - "_VERSION", - "__add", - "__call", - "__concat", - "__div", - "__eq", - "__index", - "__iter", - "__le", - "__len", - "__lt", - "__mod", - "__mode", - "__mul", - "__namecall", - "__newindex", - "__pow", - "__sub", - "__type", - "__unm", - "abs", - "acos", - "arshift", - "asin", - "assert", - "atan", - "atan2", - "band", - "bit32", - "bnot", - "boolean", - "bor", - "btest", - "bxor", - "byte", - "ceil", - "char", - "charpattern", - "clamp", - "clock", - "clone", - "close", - "codepoint", - "codes", - "concat", - "coroutine", - "cos", - "cosh", - "countlz", - "countrz", - "create", - "date", - "debug", - "deg", - "difftime", - "error", - "exp", - "extract", - "find", - "floor", - "fmod", - "foreach", - "foreachi", - "format", - "frexp", - "freeze", - "function", - "gcinfo", - "getfenv", - "getinfo", - "getmetatable", - "getn", - "gmatch", - "gsub", - "huge", - "info", - "insert", - "ipairs", - "isfrozen", - "isyieldable", - "ldexp", - "len", - "loadstring", - "log", - "log10", - "lower", - "lrotate", - "lshift", - "match", - "math", - "max", - "maxn", - "min", - "modf", - "move", - "newproxy", - "next", - "nil", - "noise", - "number", - "offset", - "os", - "pack", - "packsize", - "pairs", - "pcall", - "pi", - "pow", - "print", - "rad", - "random", - "randomseed", - "rawequal", - "rawget", - "rawset", - "remove", - "rep", - "replace", - "require", - "resume", - "reverse", - "rrotate", - "rshift", - "running", - "select", - "setfenv", - "setmetatable", - "sign", - "sin", - "sinh", - "sort", - "split", - "sqrt", - "status", - "stdin", - "string", - "sub", - "table", - "tan", - "tanh", - "thread", - "time", - "tonumber", - "tostring", - "traceback", - "type", - "typeof", - "unpack", - "upper", - "userdata", - "utf8", - "vector", - "wrap", - "xpcall", - "yield", + "_G", "_VERSION", "__add", "__call", "__concat", "__div", "__eq", "__idiv", "__index", + "__iter", "__le", "__len", "__lt", "__mod", "__mode", "__mul", "__namecall", "__newindex", + "__pow", "__sub", "__type", "__unm", "abs", "acos", "arshift", "asin", "assert", + "atan", "atan2", "band", "bit32", "bnot", "boolean", "bor", "btest", "buffer", + "bxor", "byte", "ceil", "char", "charpattern", "clamp", "clear", "clock", "clone", + "close", "codepoint", "codes", "collectgarbage", "concat", "copy", "coroutine", "cos", "cosh", + "countlz", "countrz", "create", "date", "debug", "deg", "difftime", "error", "exp", + "extract", "fill", "find", "floor", "fmod", "foreach", "foreachi", "format", "freeze", + "frexp", "fromstring", "function", "gcinfo", "getfenv", "getmetatable", "getn", "gmatch", "gsub", + "huge", "info", "insert", "ipairs", "isfrozen", "isyieldable", "ldexp", "len", "loadstring", + "log", "log10", "lower", "lrotate", "lshift", "match", "math", "max", "maxn", + "min", "modf", "move", "newproxy", "next", "nil", "noise", "number", "offset", + "os", "pack", "packsize", "pairs", "pcall", "pi", "pow", "print", "rad", + "random", "randomseed", "rawequal", "rawget", "rawlen", "rawset", "readf32", "readf64", "readi16", + "readi32", "readi8", "readstring", "readu16", "readu32", "readu8", "remove", "rep", "replace", + "require", "resume", "reverse", "round", "rrotate", "rshift", "running", "select", "setfenv", + "setmetatable", "sign", "sin", "sinh", "sort", "split", "sqrt", "status", "string", + "sub", "table", "tan", "tanh", "thread", "time", "tonumber", "tostring", "tostring", + "traceback", "type", "typeof", "unpack", "upper", "userdata", "utf8", "vector", "wrap", + "writef32", "writef64", "writei16", "writei32", "writei8", "writestring", "writeu16", "writeu32", "writeu8", + "xpcall", "yield", }; static const std::string kTypes[] = { "any", + "boolean", + "buffer", "nil", "number", "string", - "boolean", "thread", + "vector", }; static const std::string kClasses[] = { @@ -495,6 +361,8 @@ struct ProtoToLuau source += " * "; else if (expr.op() == luau::ExprBinary::Div) source += " / "; + else if (expr.op() == luau::ExprBinary::FloorDiv) + source += " // "; else if (expr.op() == luau::ExprBinary::Mod) source += " % "; else if (expr.op() == luau::ExprBinary::Pow) diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp index a6c9ae284..87a882717 100644 --- a/fuzz/typeck.cpp +++ b/fuzz/typeck.cpp @@ -3,9 +3,9 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/Frontend.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) @@ -23,23 +23,22 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); // "static" here is to accelerate fuzzing process by only creating and populating the type environment once - static Luau::NullModuleResolver moduleResolver; - static Luau::InternalErrorReporter iceHandler; - static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1); + static Luau::NullFileResolver fileResolver; + static Luau::NullConfigResolver configResolver; + static Luau::Frontend frontend{&fileResolver, &configResolver}; + static int once = (Luau::registerBuiltinGlobals(frontend, frontend.globals, false), 1); (void)once; - static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); + static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1); (void)once2; if (parseResult.errors.empty()) { + Luau::TypeChecker typeck(frontend.globals.globalScope, &frontend.moduleResolver, frontend.builtinTypes, &frontend.iceHandler); + Luau::SourceModule module; module.root = parseResult.root; module.mode = Luau::Mode::Nonstrict; - Luau::TypeChecker typeck(&moduleResolver, &iceHandler); - typeck.globalScope = sharedEnv.globalScope; - try { typeck.check(module, Luau::Mode::Nonstrict); diff --git a/jni/Android.mk b/jni/Android.mk index ad23483e6..0fbd2d001 100644 --- a/jni/Android.mk +++ b/jni/Android.mk @@ -17,9 +17,9 @@ LOCAL_CFLAGS := -O2 -DLUA_USE_MKSTEMP -std=c++17 LOCAL_C_INCLUDES += $(addprefix $(LOCAL_PATH)/../,VM/src VM/include Common/Include Compiler/include Ast/include ../libgvfs) LOCAL_SRC_FILES += $(addsuffix .cpp, \ - $(addprefix ../VM/src/,lapi laux lbaselib lbitlib lbuiltins lcorolib ldblib ldebug ldo lfunc lgc lgcdebug linit lint64lib liolib lmathlib lmem lnumprint lobject loslib lperf lstate lstring lstrlib \ + $(addprefix ../VM/src/,lapi laux lbaselib lbitlib lbuffer lbuflib lbuiltins lcorolib ldblib ldebug ldo lfunc lgc lgcdebug linit lint64lib liolib lmathlib lmem lnumprint lobject loslib lperf lstate lstring lstrlib \ ltable ltablib ltm ludata lutf8lib lvmexecute lvmload lvmutils) \ - $(addprefix ../Compiler/src/,Builtins BuiltinFolding BytecodeBuilder ConstantFolding Compiler CostModel lcode PseudoCode TableShape ValueTracking) \ + $(addprefix ../Compiler/src/,Builtins BuiltinFolding BytecodeBuilder ConstantFolding Compiler CostModel lcode PseudoCode TableShape Types ValueTracking) \ $(addprefix ../Ast/src/,Ast Confusables Lexer Location Parser StringUtils TimeTrace)) LOCAL_LDLIBS := -ldl diff --git a/luau.pro b/luau.pro index a8adf48aa..72461dca2 100644 --- a/luau.pro +++ b/luau.pro @@ -38,9 +38,9 @@ INCLUDEPATH += \ ../libgvfs SOURCES += \ - $$expand(lapi laux lbaselib lbitlib lbuiltins lcorolib ldblib ldebug ldo lfunc lgc lgcdebug linit lint64lib liolib lmathlib lmem lnumprint lobject loslib lperf lstate lstring lstrlib \ + $$expand(lapi laux lbaselib lbitlib lbuffer lbuflib lbuiltins lcorolib ldblib ldebug ldo lfunc lgc lgcdebug linit lint64lib liolib lmathlib lmem lnumprint lobject loslib lperf lstate lstring lstrlib \ ltable ltablib ltm ludata lutf8lib lvmexecute lvmload lvmutils,VM/src/,.cpp) \ - $$expand(Builtins BuiltinFolding BytecodeBuilder ConstantFolding Compiler CostModel lcode PseudoCode TableShape ValueTracking,Compiler/src/,.cpp) \ + $$expand(Builtins BuiltinFolding BytecodeBuilder ConstantFolding Compiler CostModel lcode PseudoCode TableShape Types ValueTracking,Compiler/src/,.cpp) \ $$expand(Ast Confusables Lexer Location Parser StringUtils TimeTrace,Ast/src/,.cpp) win32 { diff --git a/luawinrt/luawinrt/luawinrt.Shared/luawinrt.Shared.vcxitems b/luawinrt/luawinrt/luawinrt.Shared/luawinrt.Shared.vcxitems index 63f281329..3854c6f6d 100644 --- a/luawinrt/luawinrt/luawinrt.Shared/luawinrt.Shared.vcxitems +++ b/luawinrt/luawinrt/luawinrt.Shared/luawinrt.Shared.vcxitems @@ -33,11 +33,14 @@ + + + @@ -69,6 +72,7 @@ + diff --git a/papers/.gitignore b/papers/.gitignore deleted file mode 100644 index 86338e143..000000000 --- a/papers/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -*.aux -*.bbl -*.blg -*.dvi -*.fdb_latexmk -*.fls -*.log -*.out -*.xcp -*.nav -*.snm -*.toc diff --git a/papers/hatra21/Logo-Roblox-Black-Full.png b/papers/hatra21/Logo-Roblox-Black-Full.png deleted file mode 100644 index a792fd0b6..000000000 Binary files a/papers/hatra21/Logo-Roblox-Black-Full.png and /dev/null differ diff --git a/papers/hatra21/README.md b/papers/hatra21/README.md deleted file mode 100644 index 3105bb5bb..000000000 --- a/papers/hatra21/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# HATRA 21 position paper - -A position paper on Luau for [Human Aspects of Types and Reasoning Assistants](https://2021.splashcon.org/home/hatra-2021) (HATRA) 2021. - -## Installing latexmk - -First install basictex -``` -sudo brew install basictex -``` - -Then install the dependencies for the paper (sigh, by hand): - -``` -sudo tlmgr update --all -sudo tlmgr install acmart -sudo tlmgr install iftex -sudo tlmgr install xstring -sudo tlmgr install environ -sudo tlmgr install totpages -sudo tlmgr install trimspaces -sudo tlmgr install manyfoot -sudo tlmgr install ncctools -sudo tlmgr install comment -sudo tlmgr install balance -sudo tlmgr install preprint -sudo tlmgr install libertine -sudo tlmgr install inconsolata -sudo tlmgr install newtx -sudo tlmgr install latexmk -sudo tlmgr install montserrat -sudo tlmgr install ly1 -``` - -## Building the paper - -To build the paper: -``` -latexmk --pdf hatra21 -``` - -To run latexmk in watching mode (where it rebuilds the PDF on each change): -``` -latexmk --pdf --pvc hatra21 -``` diff --git a/papers/hatra21/bibliography.bib b/papers/hatra21/bibliography.bib deleted file mode 100644 index 3c1936186..000000000 --- a/papers/hatra21/bibliography.bib +++ /dev/null @@ -1,177 +0,0 @@ -@Misc{Roblox, - author = {Roblox}, - title = {What is {Roblox}}, - year = 2021, - url = {https://corp.roblox.com}, -} - -@Misc{Luau, - author = {Roblox}, - title = {The {Luau} Programming Language}, - year = 2021, - url = {https://luau-lang.org}, -} - -@Misc{Lua, - author = {Lua.org and {PUC}-Rio}, - title = {The {Lua} Programming Language}, - year = 2021, - url = {https://lua.org}, -} - -@Misc{AllEducators, - author = {Roblox}, - title = {Roblox Education: All Educators}, - year = {2021}, - url = {https://education.roblox.com/en-us/educators}, -} - -@Misc{RobloxDevelopers, - author = {Roblox}, - title = {Roblox Developers Expected to Earn Over \$250 Million in 2020; Platform Now Has Over 150 Million Monthly Active Users -}, - year = {2020}, - url = {https://corp.roblox.com/2020/07/roblox-developers-expected-earn-250-million-2020-platform-now-150-million-monthly-active-users/}, -} - -@Book{TAPL, - author = {Benjamin C. Pierce}, - title = {Types and Programming Languages}, - publisher = {{MIT} Press}, - year = {2002}, - isbn = {0-262-16209-1}, -} - -@Book{TDDIdris, - author = {Edwin Brady}, - title = {Type-Driven Development with {Idris}}, - publisher = {Manning}, - year = {2017}, - isbn = {9781617293023}, -} - -@PhdThesis{TopQuality, - author = {Bastiaan J. Heeren}, - title = {Top Quality Type Error Messages}, - school = {U. Utrecht}, - year = {2005}, -} - -@PhdThesis{RepairingTypeErrors, - author = {Bruce J. McAdam}, - title = {Repairing Type Errors in Functional Programs}, - school = {U. Edinburgh}, - year = {2002}, -} - -@InProceedings{GradualTyping, - author = {Jeremy G. Siek and Walid Taha}, - title = {Gradual Typing for Functional Languages}, - booktitle = {Proc. Scheme and Functional Programming Workshop}, - year = {2006}, - pages = {81-92}, -} - -@InProceedings{WellTyped, - author = {Philip Wadler and Robert B. Findler}, - title = {Well-typed Programs Can’t be Blamed}, - booktitle = {Proc. European Symp. Programming}, - year = {2009}, - pages = {1-16}, -} - -@InProceedings{Contracts, - author = {Robert B. Findler and Matthias Felleisen}, - title = {Contracts for Higher-order Functions}, - booktitle = {Proc. Int. Conf. Functional Programming}, - year = {2002}, - pages = {48-59}, -} - -@inproceedings{SuccessTyping, - author = {Lindahl, Tobias and Sagonas, Konstantinos}, - title = {Practical Type Inference Based on Success Typings}, - year = {2006}, - booktitle = {Proc. Int. Conf. Principles and Practice of Declarative Programming}, - pages = {167–178}, -} - -@InProceedings{IncorrectnessLogic, - author = {O'Hearn, Peter W.}, - title = {Incorrectness Logic}, - year = {2020}, - booktitle = {Proc. Symp. Principles of Programming Languages}, - articleno = {10}, - pages = {1-32}, -} - -@Misc{HowToDrawAnOwl, - author = {Know Your Meme}, - title = {How To Draw An Owl}, - year = {2010}, - url = {https://knowyourmeme.com/memes/how-to-draw-an-owl}, -} - -@Misc{RustBook, - author = {Klabnik, Steve and Nichols, Carol and the Rust Community}, - title = {The Rust Programming Language}, - year = {2021}, - url = {https://doc.rust-lang.org/book/}, -} - -@article{TypeClasses, -author = {Hall, Cordelia V. and Hammond, Kevin and Peyton Jones, Simon L. and Wadler, Philip L.}, -title = {Type Classes in Haskell}, -year = {1996}, -volume = {18}, -number = {2}, -journal = {ACM Trans. Program. Lang. Syst.}, -pages = {109–138}, -} - -@InProceedings{Hazel, - author = {Cyrus Omar and Ian Voysey and Ravi Chugh and Matthew Hammer}, - title = {Live Functional Programming with Typed Holes}, - booktitle = {Proc. Symp. Principles of Programming Languages}, - year = {2019}, - pages = {14:1-14:28}, -} - -@InProceedings{MigratoryTyping, - author = {Sam Tobin-Hochstadt and Matthias Felleisen and Robert Bruce Findler and Matthew Flatt and Ben Greenman and Andrew M. Kent and Vincent St-Amour and T. Stephen Strickland and Asumu Takikawa}, - title = {Migratory Typing: Ten Years Later}, - booktitle = {Proc. Summit on Advances in Programming Languages}, - year = {2017}, -} - -@InProceedings{LinkingTypes, - author = {Daniel Patterson and Amal Ahmed}, - title = {Linking Types for Multi-Language Software: Have Your Cake and Eat It Too}, - booktitle = {Proc. Summit on Advances in Programming Languages}, - year = {2017}, -} - -@InProceedings{QuickLook, -author = {Serrano, Alejandro and Hage, Jurriaan and Peyton Jones, Simon and Vytiniotis, Dimitrios}, -title = {A quick look at impredicativity}, -booktitle = {Proc. Int. Conf. Functional Programming}, -year = {2020}, -} - -@InProceedings{Boehm85, - author = {Partial polymorphic type inference is undecidable}, - title = {Hans-J. Boehm}, - booktitle = {Proc. Symp. Foundations of Computer Science}, - year = {1985}, - pages = {339-345}, -} - -@article{LocalTypeInference, -author = {Pierce, Benjamin C. and Turner, David N.}, -title = {Local Type Inference}, -year = {2000}, -volume = {22}, -number = {1}, -journal = {ACM Trans. Program. Lang. Syst.}, -pages = {1–44}, -} \ No newline at end of file diff --git a/papers/hatra21/cc-by.png b/papers/hatra21/cc-by.png deleted file mode 100644 index c8473a247..000000000 Binary files a/papers/hatra21/cc-by.png and /dev/null differ diff --git a/papers/hatra21/hatra21.pdf b/papers/hatra21/hatra21.pdf deleted file mode 100644 index 05ab916ec..000000000 Binary files a/papers/hatra21/hatra21.pdf and /dev/null differ diff --git a/papers/hatra21/hatra21.tex b/papers/hatra21/hatra21.tex deleted file mode 100644 index bd4a5d583..000000000 --- a/papers/hatra21/hatra21.tex +++ /dev/null @@ -1,446 +0,0 @@ -\documentclass[acmsmall]{acmart} - -\setcopyright{rightsretained} -\copyrightyear{2021} -\acmYear{2021} -\acmConference[HATRA '21]{Human Aspects of Types and Reasoning Assistants}{October 2021}{Chicago, IL} -\acmBooktitle{HATRA '21: Human Aspects of Types and Reasoning Assistants} -\acmDOI{} -\acmISBN{} -\expandafter\def\csname @copyrightpermission\endcsname{\raisebox{-1ex}{\includegraphics[height=3.5ex]{cc-by}} This work is licensed under a Creative Commons Attribution 4.0 International License.} -\expandafter\def\csname @copyrightowner\endcsname{Roblox.} - -\newcommand{\squnder}[1]{\color{red}\underline{{\color{black}#1}}\color{black}} -\newcommand{\infer}[2]{\frac{\textstyle#1}{\textstyle#2}} -\newcommand{\erase}{\mathrm{erase}} -\newcommand{\evCtx}{\mathcal{E}} -\newcommand{\NIL}{\mathsf{nil}} -\newcommand{\ANY}{\mathsf{any}} -\newcommand{\TRUE}{\mathsf{true}} -\newcommand{\FALSE}{\mathsf{false}} -\newcommand{\NUMBER}{\mathsf{number}} -\newcommand{\STRING}{\mathsf{string}} -\newcommand{\ERROR}{\mathsf{error}} -\newcommand{\IF}{\mathsf{if}\,} -\newcommand{\LOCAL}{\mathsf{local}\,} -\newcommand{\THEN}{\,\mathsf{then}\,} -\newcommand{\ELSE}{\,\mathsf{else}\,} -\newcommand{\END}{\,\mathsf{end}} -\newcommand{\FUNCTION}{\mathsf{function}\,} -\newcommand{\RETURN}{\mathsf{return}\,} -\newcommand{\FIND}{\mathsf{find}} -\newcommand{\PRINT}{\mathsf{print}} -\newcommand{\strlit}[1]{\mbox{``#1''}} - -\begin{document} - -\title{Position Paper: Goals of the Luau Type System} - -\author{Lily Brown} -\author{Andy Friesen} -\author{Alan Jeffrey} -\affiliation{ - \institution{Roblox} - \city{San Mateo} - \state{CA} - \country{USA} -} - -\begin{abstract} - Luau is the scripting language that powers user-generated experiences on the - Roblox platform. It is a statically-typed language, based on the - dynamically-typed Lua language, with type inference. These types are used for providing - editor assistance in Roblox Studio, the IDE for authoring Roblox experiences. - Due to Roblox's uniquely heterogeneous developer community, Luau must operate - in a somewhat different fashion than a traditional statically-typed language. - In this paper, we describe some of the goals of the Luau type system, - focusing on where the goals differ from those of other type systems. -\end{abstract} - -\maketitle - -\section{Introduction} - -The Roblox platform allows anyone to create shared, -immersive, 3D experiences. As of July 2021, there are -approximately 20~million experiences available on Roblox, created -by 8~million developers~\cite{Roblox}. Roblox creators are often young: there are -over 200~Roblox kids' coding camps in 65~countries -listed by the company as education resources~\cite{AllEducators}. -The Luau programming language~\cite{Luau} is the scripting language -used by creators of Roblox experiences. Luau is derived from the Lua -programming language~\cite{Lua}, with additional capabilities, -including a type inference engine. - -This paper will discuss some of the goals of the Luau type system, such -as supporting goal-driven learning, non-strict typing semantics, and -mixing strict and non-strict types. Particular focus is placed on how -these goals differ from traditional type systems' goals. - -\section{Needs of the Roblox platform} -\subsection{Heterogeneous developer community} - -Need: \emph{a language that is powerful enough to support professional users, yet accessible to beginners} - -Quoting a Roblox 2020 report \cite{RobloxDevelopers}: -\begin{itemize} -\item \emph{Adopt Me!} now has over 10 billion plays and surpassed 1.6 million concurrent users earlier this year. -\item \emph{Piggy}, launched in January 2020, has close to 5 billion visits in just over six months. -\item There are now 345,000 developers on the platform who are monetizing their games. -\end{itemize} -This demonstrates the heterogeneity of the Roblox developer community: -developers of experiences with billions of plays are on the same -platform as children first learning to code. Both of these groups are important to -support: the professional development studios bring high-quality experiences to the -platform, and the beginning creators contribute to the energetic creative community, -forming the next generation of developers. - -\subsection{Goal-driven learning} - -Need: \emph{organic learning for achieving specific goals} - -All developers are goal-driven, but this is especially true for -learners. A learner will download Roblox Studio -(the creation environment for the Roblox platform) with an -experience in mind, such as designing an obstacle course -to play in with their friends. - -The user experience of developing a Roblox experience is primarily a -3D interactive one, seen in Fig.~\ref{fig:studio}(a). The user designs -and deploys 3D assets such as terrain, parts and joints, providing -them with physics attributes such as mass and orientation. The user -can interact with the experience in Studio, and deploy it to a Roblox -server so anyone with the Roblox app can play it. Physics, rendering -and multiplayer are all immediately accessible to creators. - -\begin{figure} -\includegraphics[width=0.48\textwidth]{studio-mow.png} -\includegraphics[width=0.48\textwidth]{studio-script-editor.png} -\caption{Roblox Studio's 3D environment editor (a), and script editor (b)} -\label{fig:studio} -\end{figure} - -At some point during experience design, the experience creator has a need -that can't be met by the game engine alone, such as ``the stairs should -light up when a player walks on them'' or ``a firework is set off -every few seconds''. At this point, they will discover the script -editor, seen in Fig.~\ref{fig:studio}(b). - -This onboarding experience is different from many initial exposures to -programming, in that by the time the user first opens the script -editor, they have already built much of their creation, and have a -very specific concrete aim. As such, Luau must allow users to perform a -specific task with as much help as possible from tools. - -A common workflow for getting started is to Google the task, then -cut-and-paste the resulting code, adapting it as needed. Since this -is so common, backward compatibility of Luau with existing code is -important, even for learners who do not have an existing code base to -maintain. - -Type-driven tools are useful to all creators, in as much as they help -them achieve their current goals. For example type-driven -autocomplete, or type-driven API documentation, are of immediate -benefit. Traditional typechecking can be useful, for example for -catching spelling mistakes, but for most goal-driven developers, the -type system should help or get out of the way. - -\subsection{Type-driven development} - -Need: \emph{a language that supports large-scale codebases and defect detection} - -Professional development studios are also goal-directed (though the -goals may be more abstract, such as ``decrease user churn'' or -``improve frame rate'') but have additional needs: -\begin{itemize} - -\item \emph{Code planning}: - code spends much of its time in an incomplete state, with holes - that will be filled in later. - -\item \emph{Code refactoring}: - code evolves over time, and it is easy for changes to - break previously-held invariants. - -\item \emph{Defect detection}: - code has errors, and detecting these at runtime (for example by crash telemetry) - can be expensive and recovery can be time-consuming. - -\end{itemize} -Detecting defects ahead-of-time is a traditional goal of type systems, -resulting in an array of techniques for establishing safety results, -surveyed for example in~\cite{TAPL}. Supporting code planning and -refactoring are some of the goals of \emph{type-driven -development}~\cite{TDDIdris} under the slogan ``type, define, -refine''. A common use of type-driven development is renaming a -property, which is achieved by changing the name in one place, -and then fixing the resulting type errors---once the type system stops -reporting errors, the refactoring is complete. - -To help support the transition from novice to experienced developer, -types are introduced gradually, through API documentation and type discovery. -Type inference provides many of the benefits of type-driven development -even to creators who are not explicitly providing types. - -\section{Goals of the type system} -\subsection{Infallible types} - -Goal: \emph{provide type information even for ill-typed or syntactically invalid programs.} - -Programs spend much of their time under development in an ill-typed or incomplete state, even if the -final artifact is well-typed. If tools such as autocomplete and API documentation are type-driven, -this means that tooling needs to rely on type information even for ill-typed -or syntactically invalid programs. An analogy is infallible parsers, which perform error recovery and -provide an AST for all input texts, even if they don't adhere to the parser's syntax. - -Program analysis can still flag type errors, which may be presented -to the user with red squiggly underlining. Formalizing this, rather -than a judgment -$\Gamma\vdash M:T$, for an input term $M$, there is a judgment -$\Gamma \vdash M \Rightarrow N : T$ where $N$ is an output term -where some subterms are \emph{flagged} as having type errors, written $\squnder{N}$. Write $\erase(N)$ -for the result of erasing flaggings: $\erase(\squnder{N}) = \erase(N)$. - -For example, in Lua, the $\STRING.\FIND$ function expects two strings, and returns the -offsets for that string: -\[ - \STRING.\FIND(\strlit{hello}, \strlit{ell}) \rightarrow (2, 4) -\qquad - \STRING.\FIND(\strlit{world}, \strlit{ell}) \rightarrow (\NIL, \NIL) -\] -and in Luau it has the type: -\[ - \STRING.\FIND : (\STRING, \STRING) \rightarrow (\NUMBER?, \NUMBER?) -\] -In a conventional type system, there is no judgment for ill-typed terms -such as $\STRING.\FIND(\strlit{hello}, 37)$ but in an infallible system we flag the error -and approximate the type, for example: -\[ - {} \vdash - \STRING.\FIND(\strlit{hello}, 37) - \Rightarrow - \squnder{\STRING.\FIND(\strlit{hello}, 37)} - : - (\NUMBER?, \NUMBER?) -\] -The goal of infallible types is that every term has a typing judgment -given by flagging ill-typed subterms: -\begin{itemize} -\item \emph{Typability}: for every $M$ and $\Gamma$, - there are $N$ and $T$ such that $\Gamma \vdash M \Rightarrow N : T$. -\item \emph{Erasure}: if $\Gamma \vdash M \Rightarrow N : T$ - then $\erase(M) = \erase(N)$ -\end{itemize} -Some issues raised by infallible types: -\begin{itemize} -\item Which heuristics should be used to provide types for flagged programs? For example, could one - use minimal edit distance to correct for spelling mistakes in field names? -\item How can we avoid cascading type errors, where a developer is - faced with type errors that are artifacts of the heuristics, rather - than genuine errors? -\item How can the goals of an infallible type system be formalized? -\end{itemize} -\emph{Related work}: -there is a large body of work on type error reporting -(see, for example, the survey in~\cite[Ch.~3]{TopQuality}) -and on type-directed program repair -(see, for example, the survey in~\cite[Ch.~3]{RepairingTypeErrors}), -but less on type repair. -The closest work is Hazel's~\cite{Hazel} \emph{typed holes} -where $\squnder{N}$ is treated as a partially-filled hole in the program, -though in that work partially-filled holes are not erased at run-time. -Many compilers perform -error recovery during typechecking, but do not provide a semantics -for programs with type errors. - -\subsection{Strict types} - -Goal: \emph{no false negatives.} - -For developers who are interested in defect detection, Luau provides a \emph{strict mode}, -which acts much like a traditional, sound, type system. This has the goal of ``no false negatives'' -where any possible run-time error is flagged. This is formalized using: -\begin{itemize} -\item \emph{Operational semantics}: a reduction judgment $M \rightarrow N$ on terms. -\item \emph{Values}: a subset of terms representing a successfully completed evaluation. -\end{itemize} -Error states at runtime are represented as stuck states (terms that are not -values but cannot reduce), and showing that no well-typed program is -stuck. This is not true if typing is infallible, but can fairly -straightforwardly be adapted. We extend the operational semantics to flagged terms, -where $M \rightarrow M'$ implies $\squnder{M} \rightarrow \squnder{M'}$, and -for any value $V$ we have $\squnder{V} \rightarrow V$, then show: -\begin{itemize} -\item \emph{Progress}: if ${} \vdash M \Rightarrow N : T$, then either $N \rightarrow N'$ or $N$ is a value or $N$ has a flagged subterm. -\item \emph{Preservation}: if ${} \vdash M \Rightarrow N : T$ and $N \rightarrow N'$ then $M \rightarrow^*M'$ and ${} \vdash M' \Rightarrow N' : T$. -\end{itemize} -For example in typechecking the program: -\[ - \LOCAL (i,j) = \STRING.\FIND(x, y); - \IF i \THEN \PRINT(j-i) \END -\] -the interesting case is $i-j$ in a context where $i$ has type -$\NUMBER$ (since it is guarded by the $\IF$) but $j$ has type -$\NUMBER?$. Since subtraction has type $(\NUMBER, \NUMBER) \rightarrow \NUMBER$, -this is a type error, so the relevant typing judgment is: -\[\begin{array}{r@{}l} - x: \STRING, y: \STRING \vdash {}& - (\LOCAL (i,j) = \STRING.\FIND(x, y); - \IF i \THEN \PRINT(j-i) \END) \\ - \Rightarrow {}& - (\LOCAL (i,j) = \STRING.\FIND(x, y); - \IF i \THEN \PRINT(\squnder{j-i}) \END) -\end{array}\] -Some issues raised by soundness for infallible types: -\begin{itemize} -\item How should the judgments and their metatheory be set up? -\item How should type inference and generic functions be handled? -\item Is the operational semantics of flagged values - ($\squnder{V} \rightarrow V$) the right one? -\end{itemize} -\emph{Related work}: gradual typing and blame analysis, e.g.~\cite{GradualTyping,WellTyped,Contracts}. -The main difference between this approach and that of migratory typing~\cite{MigratoryTyping} -is that (due to backward compatibility with existing Lua) we cannot introduce -extra code during migration. - -\subsection{Nonstrict types} - -Goal: \emph{no false positives.} - -For developers who are not interested in defect detection, type-driven -tools and techniques such as autocomplete, API documentation -and type-driven refactoring are still useful. -For such developers, Luau provides a -\emph{nonstrict mode}, which we hope will eventually be useful for all -developers. This non-strict typing mode is particularly useful when -adopting Luau types in pre-existing code that was not authored with -the type system in mind. Non-strict mode does \emph{not} aim for -soundness, but instead has the goal of ``no false positives``, in the -sense that any flagged code is guaranteed to produce a runtime error -when executed. - -Our previous example was, in fact, a false positive since a programmer -can make use of the fact that $\STRING.\FIND(x, y)$ is either $\NIL$ -in both results or neither, so if $i$ is non-$\NIL$ then so is $j$. -This is discussed in the English-language documentation but not reflected -in the type. So flagging $(i - j)$ is a false positive. - -On the face of it, detecting all errors without false positives is undecidable, since a program such as -$(\IF f() \THEN \ERROR \END)$ will produce a runtime error when $f()$ is -$\TRUE$. Instead we can aim for a weaker property: that all flagged code -is either dead code or will produce an error. Either of these is a -defect, so deserves flagging, even if the tool does not know -which reason applies. - -We can formalize this by defining an \emph{evaluation context} -$\evCtx[\bullet]$, and saying $M$ is \emph{incorrectly flagged} -if it is of the form $\evCtx[\squnder{V}]$. We can then define: -\begin{itemize} -\item \emph{Correct flagging}: if ${} \vdash M \Rightarrow N : T$ - then $N$ is correctly flagged. -\end{itemize} -Some issues raised by nonstrict types: -\begin{itemize} - -\item Will nonstrict types result in errors being flagged in function call sites - rather than definitions? - -\item In Luau, ill-typed property update of most tables succeeds - (the property is inserted if it did not exist), and so functions which - update properties cannot be flagged. Can we still provide meaningful - error messages in such cases? - -\item Does nonstrict typing require whole program analysis, - to find all the possible types a property might be updated with? - -\item The natural formulation of function types in a nonstrict setting - is that of~\cite{SuccessTyping}: if $f: T \rightarrow U$ and $f(V) \rightarrow^* W$ - then $V:T$ and $W:U$. This formulation is \emph{covariant} in $T$, - not \emph{contravariant}; what impact does this have? - -\end{itemize} -\emph{Related work}: success types~\cite{SuccessTyping} and incorrectness logic~\cite{IncorrectnessLogic}. - -\subsection{Mixing types} - -Goal: \emph{support mixed strict/nonstrict development}. - -Like every active software community, Roblox developers share code -with one another constantly. First- and third-party developers alike -frequently share entire software packages written in Luau. To add to -this, many Roblox experiences are authored by a team. It is therefore -crucial that we offer first-class support for mixing code written in -strict and nonstrict modes. - -Some questions raised by mixed-mode types: -\begin{itemize} - -\item How much feedback can we offer for a nonstrict script that is - importing strict-mode code? - -\item In strict mode, how do we talk about values and types that are - drawn from nonstrict code? - -\item How can we combine the goals of strict and nonstrict types? - -\item Can we have strict and non-strict mode infer the same types, - only with different flagging? - -\item Is strict-mode code sound when it relies on non-strict code, - which has weaker invariants? - -\item How can we avoid introducing function wrappers in higher-order code - at the strict/nonstrict boundary? - -\end{itemize} -\emph{Related work}: there has been work on interoperability between different type systems, -notably~\cite{LinkingTypes}, but there the overall goals of the systems were similar safety properties. -In our case, the two type systems have different goals. - -\subsection{Type inference} - -Goal: \emph{infer types to allow gradual adoption of type annotations.} - -Since backward compatibility with existing code is important, we have -to provide types for code without explicit annotations. Moreover, we -want to make use of type-directed tools such as autocomplete, so we -cannot adopt the common strategy of treating all untyped variables as -having type $\ANY$. This leads us to type inference. - -To make use of type-driven technologies for programs -without explicit type annotations, we use a type inference algorithm. -Since Luau includes System~F, type inference is undecidable~\cite{Boehm85}, -but we can still make use of heuristics such as local type inference~\cite{LocalTypeInference}. - -It remains to be seen if type inference can satisfy the goals of -strict and non-strict types. The current Luau system -infers different types in the two modes, which is unsatisfactory as it -makes changing mode a non-local breaking change. In addition, -non-strict inference is currently too imprecise to support -type-directed tools such as autocomplete. - -Some questions raised by type inference: -\begin{itemize} - -\item How many cases in strict mode cannot be inferred by the type inference system? Minimizing - this kind of error is desirable, to make the type system as unobtrusive as possible. -\item Can something like the Rust traits system~\cite{RustBook} or Haskell classes~\cite{TypeClasses} be used to provide types for overloaded operators, without hopelessly confusing learners? -\item Type inference currently infers monotypes for unannotated - functions, in contrast to QuickLook~\cite{QuickLook}, which can infer generic types. - Will this be good enough for idiomatic Luau scripts? -\item Can type inference be used to infer the same types in strict and nonstrict mode, to ease migrating between modes, with the only difference being error reporting? -\end{itemize} -\emph{Related work}: there is a large body of work on type inference, largely summarized in~\cite{TAPL}. - -\section{Conclusions} - -In this paper, we have presented some of the goals of the Luau type -system, and how they map to the needs of the Roblox creator -community. We have also explored how these goals differ from traditional -type systems, where it is necessary to accommodate the unique needs of -the Roblox platform. We have sketched what a solution might look like; -all that remains is to draw the owl~\cite{HowToDrawAnOwl}. - -\bibliographystyle{ACM-Reference-Format} \bibliography{bibliography} - -\end{document} diff --git a/papers/hatra21/studio-mow.png b/papers/hatra21/studio-mow.png deleted file mode 100644 index 71a10a07b..000000000 Binary files a/papers/hatra21/studio-mow.png and /dev/null differ diff --git a/papers/hatra21/studio-script-editor.png b/papers/hatra21/studio-script-editor.png deleted file mode 100644 index 92a83f75d..000000000 Binary files a/papers/hatra21/studio-script-editor.png and /dev/null differ diff --git a/papers/hatra21/talk.pdf b/papers/hatra21/talk.pdf deleted file mode 100644 index 042a7a7e9..000000000 Binary files a/papers/hatra21/talk.pdf and /dev/null differ diff --git a/papers/hatra21/talk.tex b/papers/hatra21/talk.tex deleted file mode 100644 index 1c8627fb1..000000000 --- a/papers/hatra21/talk.tex +++ /dev/null @@ -1,203 +0,0 @@ -\documentclass[aspectratio=169]{beamer} - -\usecolortheme{whale} -\setbeamertemplate{navigation symbols}{} -\definecolor{background}{rgb}{0.945,0.941,0.96} -\definecolor{bluish}{rgb}{0.188,0.455,0.863} -\usepackage{montserrat} -\setbeamerfont{frametitle}{size=\Large,series=\bfseries} -\setbeamerfont{title}{size=\Huge,series=\bfseries} -\setbeamerfont{date}{shape=\itshape} -\setbeamercolor{title}{bg=bluish} -\setbeamercolor{frametitle}{bg=bluish} -\setbeamercolor{background canvas}{bg=background} -\setbeamercolor{itemize item}{fg=bluish} -\setbeamercolor{part name}{fg=background} -\setbeamercolor{part title}{bg=bluish} -\setbeamertemplate{footline}[text line]{\hfill\raisebox{5ex}{\insertshorttitle~~~~\insertframenumber/\inserttotalframenumber~~~~\includegraphics[width=5em]{Logo-Roblox-Black-Full.png}}} -\AtBeginPart{{\setbeamertemplate{footline}{}\frame{\partpage}}} - -\newcommand{\erase}{\mathsf{erase}} - -\title{Goals of the Luau~Type~System} -\author{Lily Brown \and Andy Friesen \and Alan Jeffrey} -\institute[Roblox]{\includegraphics[width=15em]{Logo-Roblox-Black-Full.png}} -\date[HATRA '21]{\textit{Human Aspects of Types and Reasoning Assistants} 2021} - -\begin{document} - -{\setbeamertemplate{footline}{}\frame{\titlepage}} - -\part{Creator Goals} - -\begin{frame} - -\frametitle{Roblox} - -A platform for creating shared immersive 3D experiences: -\begin{itemize} - \item \textbf{Many}: 20 million experiences, 8 million creators. - \item \textbf{At scale}: e.g.~\emph{Adopt Me!} has 10 billion plays. - \item \textbf{Learners}: e.g.~200+ kids' coding camps in 65+ countries. - \item \textbf{Professional}: 345k creators monetizing experiences. -\end{itemize} -A very heterogeneous community. - -\end{frame} - -\begin{frame} - -\frametitle{Roblox developer community} - -All developers are important: -\begin{itemize} - \item \textbf{Learners}: energetic creative community. - \item \textbf{Professionals}: high-quality experiences. - \item \textbf{Everyone inbetween}: some learners become professionals! -\end{itemize} -Satisfying everyone is sometimes challenging. - -\end{frame} - -\begin{frame} - -\frametitle{Roblox Studio} - -Demo time! - -\end{frame} - -\begin{frame} - -\frametitle{Learners have immediate goals} - -E.g. ``when a player steps on the button, advance the slide''. -\begin{itemize} - \item \textbf{3D scene editor} meets most goals, e.g.~model parts. - \item \textbf{Programming} is needed for reacting to events, e.g.~collisions. - \item \textbf{Onboarding} is very different from ``let's learn to program''. - \item \textbf{Google Stack Overflow} is a common workflow. - \item \textbf{Type-driven tools} are useful, e.g.~autocomplete or API help. - \item \textbf{Type errors} may be useful (e.g.~catching typos) but some are not. -\end{itemize} -Type systems should help or get out of the way. - -\end{frame} - -\begin{frame} - -\frametitle{Professionals have long-term goals} - -E.g. ``decrease user churn'' or ``improve frame rate''. -\begin{itemize} -\item \textbf{Code planning}: programs are incomplete. -\item \textbf{Code refactoring}: programs change. -\item \textbf{Defect detection}: programs have bugs. -\end{itemize} -Type-driven development is a useful technique! - -\end{frame} - -\part{Luau Type System} - -\begin{frame} - -\frametitle{Infallible types} - -Goal: \emph{support type-driven tools (e.g. autocomplete) for all programs.} -\begin{itemize} -\item \textbf{Traditional typing judgment} says nothing about ill-typed terms. -\item \textbf{Infallible judgment}: every term gets a type. -\item \textbf{Flag type errors}: elaboration introduces \emph{flagged} subterms. -\end{itemize} - -\emph{Related work}: -\begin{itemize} -\item Type error reporting, program repair. -\item Typed holes (e.g. in Hazel). -\end{itemize} - -\end{frame} - -\begin{frame} - -\frametitle{Strict types} - -Goal: \emph{no false negatives}. - -\begin{itemize} -\item \textbf{Strict mode} enabled by developers who want defect detection. -\item \textbf{Business as usual} soundness via progress + preservation. -\item \textbf{Gradual types} for programs with flagged type errors. -\end{itemize} - -\emph{Related work}: -\begin{itemize} -\item Lots and lots for type safety. -\item Gradual typing, blame analysis, migratory types\dots -\end{itemize} - -\end{frame} - -\begin{frame} - -\frametitle{Nonstrict types} - -Goal: \emph{no false positives}. - -\begin{itemize} -\item \textbf{Nonstrict mode} enabled by developers who want type-drive tools. -\item \textbf{Victory condition} does not have an obvious definition! -\item \textbf{A shot at it}: a program is \emph{incorrectly flagged} if it contains - a flagged value (i.e.~a flagged program has successfully terminated). -\item \textbf{Progress + correct flagging} is what we want??? -\end{itemize} - -\emph{Related work}: -\begin{itemize} -\item Success types (e.g. Erlang Dialyzer). -\item Incorrectness Logic. -\end{itemize} - -\end{frame} - -\begin{frame} - -\frametitle{Mixing types} - -Goal: \emph{support mixed strict/nonstrict development}. - -\begin{itemize} -\item \textbf{Per-module} strict/nonstrict mode. -\item \textbf{Combined} progress + preservation with progress + correct flagging? -\end{itemize} - -\emph{Related work}: -\begin{itemize} -\item Some on mixed languages, but with shared safety properties. -\end{itemize} - -\end{frame} - -\begin{frame} - -\frametitle{Type inference} - -Goal: \emph{provide benefits of type-directed tools to everyone}. - -\begin{itemize} -\item \textbf{Infer types} for all variables. Resist the urge to give up and ascribe a top type when an error is encountered. -\item \textbf{System F} is in Luau, so everything is undecidable. Yay heuristics! -\item \textbf{Different modes} currently infer different types. Boo! -\end{itemize} - -\emph{Related work}: -\begin{itemize} -\item Lots, though not on mixed modes. -\end{itemize} - -\end{frame} - -\part{Thank you!\\Roblox is hiring!} - -\end{document} diff --git a/rfcs/README.md b/rfcs/README.md deleted file mode 100644 index 4b5e7b04f..000000000 --- a/rfcs/README.md +++ /dev/null @@ -1,60 +0,0 @@ -Background -=== - -Whenever Luau language changes its syntax or semantics (including behavior of builtin libraries), we need to consider many implications of the changes. - -Whenever new syntax is introduced, we need to ask: - -- Is it backwards compatible? -- Is it easy for machines and humans to parse? -- Does it create grammar ambiguities for current and future syntax? -- Is it stylistically coherent with the rest of the language? -- Does it present challenges with editor integration like autocomplete? - -For changes in semantics, we should be asking: - -- Is behavior easy to understand and non-surprising? -- Can it be implemented performantly today? -- Can it be sandboxed assuming malicious usage? -- Is it compatible with type checking and other forms of static analysis? - -For new standard library functions, we should be asking: - -- Is the new functionality used/useful often enough in existing code? -- Does the standard library implementation carry important performance benefits that can't be achieved in user code? -- Is the behavior general and unambiguous, as opposed to solving a problem / providing an interface that's too specific? -- Is the function interface amenable to type checking / linting? - -In addition to these questions, we also need to consider that every addition carries a cost, and too many features will result in a language that is harder to learn, harder to implement and ensure consistent implementation quality throughout, slower, etc. In addition, any language is greater than the sum of its parts and features often have non-intuitive interactions with each other. - -Since reversing these decisions is incredibly costly and can be impossible due to backwards compatibility implications, all user facing changes to Luau language and core libraries must go through an RFC process. - -Process -=== - -To open an RFC, a Pull Request must be opened which creates a new Markdown file in `rfcs/` folder. The RFCs should follow the template `rfcs/TEMPLATE.md`, and should have a file name that is a short human readable description of the feature (using lowercase alphanumeric characters and dashes only). Try using the general area of the RFC as a prefix, e.g. `syntax-generic-functions.md` or `function-debug-info.md`. - -**Please make sure to add `rfc` label to PRs *before* creating them!** This makes sure that our automatic notifications work correctly. - -Every open RFC will be open for at least two calendar weeks. This is to make sure that there is sufficient time to review the proposal and raise concerns or suggest improvements. The discussion points should be reflected on the PR comments; when discussion happens outside of the comment stream, the points salient to the RFC should be summarized as a followup. - -When the initial comment period expires, the RFC can be merged if there's consensus that the change is important and that the details of the syntax/semantics presented are workable. The decision to merge the RFC is made by the Luau team. - -When revisions on the RFC text that affect syntax/semantics are suggested, they need to be incorporated before a RFC is merged; a merged RFC represents a maximally accurate version of the language change that is going to be implemented. - -In some cases RFCs may contain conditional compatibility clauses. E.g. there are cases where a change is potentially not backwards compatible, but is believed to be substantially beneficial that it can be implemented if, in practice, the backwards compatibility implications are minimal. As a strawman example, if we wanted to introduce a non-context-specific keyword `globallycoherent`, we would be able to do so if our analysis of Luau code (based on the Roblox platform at the moment) informs us that no script in existence uses this keyword. In cases like this an RFC may need to be revised after the initial implementation attempt based on the data that we gather. - -In general, RFCs can also be updated after merging to make the language of the RFC more clear, but should not change their meaning. When a new feature is built on top of an existing feature that has an RFC, a new RFC should be created instead of editing an existing RFC. - -When there's no consensus that the feature is broadly beneficial and can be implemented, an RFC will be closed. The decision to close the RFC is made by the Luau team. - -Note that in some cases an RFC may be closed because we don't have sufficient data or believe that at this point in time, the stars do not line up sufficiently for this change to be worthwhile, but this doesn't mean that it may never be considered again; an RFC PR may be reopened if new data is available since the original discussion, or if the PR has changed substantially to address the core problems raised in the prior round. - -Implementation -=== - -When an RFC gets merged, the feature *can* be implemented; however, there's no set timeline for that implementation. In some cases implementation may land in a matter of days after an RFC is merged, in some it may take months. - -To avoid having permanently stale RFCs, in rare cases Luau team can *remove* a previously merged RFC when the landscape is believed to change enough for a feature like this to warrant further discussion. - -When an RFC is implemented and the implementation is enabled via feature flags, RFC should be updated to include "**Status**: Implemented" at the top level (before *Summary* section). diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md deleted file mode 100644 index 41d69f015..000000000 --- a/rfcs/STATUS.md +++ /dev/null @@ -1,28 +0,0 @@ -This document tracks unimplemented RFCs. - -## Deprecate getfenv/setfenv - -[RFC: Deprecate getfenv/setfenv](https://github.com/Roblox/luau/blob/master/rfcs/deprecate-getfenv-setfenv.md) - -**Status**: Needs implementation. - -**Notes**: Implementing this RFC triggers warnings across the board in the apps ecosystem, in particular in testing libraries. Pending code changes / decisions. - -## Deprecate table.getn/foreach/foreachi - -[RFC: Deprecate table.getn/foreach/foreachi](https://github.com/Roblox/luau/blob/master/rfcs/deprecate-table-getn-foreach.md) - -**Status**: Needs implementation. - -## Read-only and write-only properties - -[RFC: Read-only properties](https://github.com/Roblox/luau/blob/master/rfcs/property-readonly.md) | -[RFC: Write-only properties](https://github.com/Roblox/luau/blob/master/rfcs/property-writeonly.md) - -**Status**: Needs implementation - -## Expanded Subtyping for Generic Function Types - -[RFC: Expanded Subtyping for Generic Function Types](https://github.com/Roblox/luau/blob/master/rfcs/generic-function-subtyping.md) - -**Status**: Implemented but not fully rolled out yet. diff --git a/rfcs/TEMPLATE.md b/rfcs/TEMPLATE.md deleted file mode 100644 index 266922b27..000000000 --- a/rfcs/TEMPLATE.md +++ /dev/null @@ -1,21 +0,0 @@ -# Feature name - -## Summary - -One paragraph explanation of the feature. - -## Motivation - -Why are we doing this? What use cases does it support? What is the expected outcome? - -## Design - -This is the bulk of the proposal. Explain the design in enough detail for somebody familiar with the language to understand, and include examples of how the feature is used. - -## Drawbacks - -Why should we *not* do this? - -## Alternatives - -What other designs have been considered? What is the impact of not doing this? diff --git a/rfcs/behavior-eq-metamethod.md b/rfcs/behavior-eq-metamethod.md deleted file mode 100644 index eeb768f02..000000000 --- a/rfcs/behavior-eq-metamethod.md +++ /dev/null @@ -1,59 +0,0 @@ -# Always call `__eq` when comparing for equality - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -`__eq` metamethod will always be called during `==`/`~=` comparison, even for objects that are rawequal. - -## Motivation - -Lua 5.x has the following algorithm it uses for comparing userdatas and tables: - -- If two objects are not of the same type (userdata vs number), they aren't equal -- If two objects are referentially equal, they are equal (!) -- If no object has a metatable with `__eq` metamethod, they are equal iff they are referentially equal -- Otherwise, pick one of the `__eq` metamethods, call it with both objects as arguments and return the result. - -In mid-2019, we've released Luau which implements a fast path for userdata comparison. This fast path accidentally omitted step 2 for userdatas with C `__eq` implementations (!), and thus comparing a userdata object vs itself would actually run `__eq` metamethod. This is significant as it allowed users to use `v == v` as a NaN check for vectors, coordinate frames, and other objects that have floating point contents. - -Since this was a bug, we're in a rather inconsistent state: - -- `==` and `~=` in the code always call `__eq` for userdata with C `__eq` -- `==` and `~=` don't call `__eq` for tables and custom newproxy-like userdatas with Lua `__eq` when objects are ref. equal -- `table.find` *doesn't* call `__eq` when objects are ref. equal - -## Design - -Since developers started relying on `==` behavior for NaN checks in the last two years since Luau release, the bug has become a feature. Additionally, it's sort of a good feature since it allows to implement NaN semantics for custom types - userdatas, tables, etc. - -Thus the proposal suggests changing the rules so that when `__eq` metamethod is present, `__eq` is always called even when comparing the object to itself. - -This would effectively make the current ruleset for userdata objects official, and change the behavior for `table.find` (which is probably not significant) and, more significantly, start calling user-provided `__eq` even when the object is the same. It's expected that any reasonable `__eq` implementation can handle comparing the object to itself so this is not expected to result in breakage. - -## Drawbacks - -This represents a difference in a rather core behavior from all upstream versions of Lua. - -## Alternatives - -We could instead equalize (ha!) the behavior between Luau and Lua. In fact, this is what we tried to do initially as the userdata behavior was considered a bug, but encountered the issue with games already depending on the new behavior. - -We could work with developers to change their games to stop relying on this. However, this is more complicated to deploy and - upon reflection - makes `==` less intuitive than the main proposal when comparing objects with NaN, since e.g. it means that these two functions have a different behavior: - -``` -function compare1(a: Vector3, b: Vector3) - return a == b -end - -function compare2(a: Vector3, b: Vector3) - return a.X == b.X and a.Y == b.Y and a.Z == b.Z -end -``` - -## References - -https://devforum.roblox.com/t/call-eq-even-when-tables-are-rawequal/1088886 -https://devforum.roblox.com/t/nan-vector3-comparison-broken-cframe-too/1130778 diff --git a/rfcs/change-global-version.md b/rfcs/change-global-version.md deleted file mode 100644 index cdb6c1e72..000000000 --- a/rfcs/change-global-version.md +++ /dev/null @@ -1,21 +0,0 @@ -# Change \_VERSION global to "Luau" - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Change \_VERSION global to "Luau" to differentiate Luau from Lua - -## Motivation - -Provide an official way to distinguish Luau from Lua implementation. - -## Design - -We inherit the global string \_VERSION from Lua (this is distinct from Roblox `version()` function that returns a full version number such as 0.432.43589). - -The string is set to "Lua 5.1" for us (and "Lua 5.2" etc for newer versions of Lua. - -Since our implementation is sufficiently divergent from upstream, this proposal suggests setting \_VERSION to "Luau". diff --git a/rfcs/config-luaurc.md b/rfcs/config-luaurc.md deleted file mode 100644 index b321b94b3..000000000 --- a/rfcs/config-luaurc.md +++ /dev/null @@ -1,68 +0,0 @@ -# Configure analysis via .luaurc - -**Status**: Implemented - -## Summary - -Introduces a way to configure type checker and linter using JSON-like .luaurc files - -## Motivation - -While Luau analysis tools try to provide sensible defaults, it's difficult to establish the rules that work for all code. -For example, some packages may decide that unused variables aren't interesting; other packages may decide that all files should be using strict typechecking mode. - -While it's possible to configure some aspects of analysis behavior using --! comments, it can be cumbersome to replicate this in all files. - -## Design - -To solve this problem, we are going to introduce support for `.luaurc` files for users of command-line Luau tools. -For a given .lua file, Luau will search for .luaurc files starting from the folder that the .lua file is in; all files in the ancestry chain will be parsed and their configuration -applied. When multiple files are used, the file closer to the .lua file overrides the settings. - -.luaurc is a JSON file that can also contain comments and trailing commas. The file can have the following keys: - -- `"languageMode"`: type checking mode, can be one of "nocheck", "nonstrict", "strict" -- `"lint"`: lints to enable; points to an object that maps string literals that correspond to the names of linting rules (see https://luau-lang.org/lint), or `"*"` that means "all rules", to a boolean (to enable/disable the lint) -- `"lintErrors"`: a boolean that controls whether lint issues are reported as errors or warnings (off by default) -- `"typeErrors"`: a boolean that controls whether type issues are reported as errors or warnings (on by default) -- `"globals"`: extra global values; points to an array of strings where each string names a global that the type checker and linter must assume is valid and of type `any` - -Example of a valid .luaurc file: - -```json5 -{ - "languageMode": "nonstrict", - "lint": { "*": true, "LocalUnused": false }, - "lintErrors": true, - "globals": ["expect"] // TestEZ -} -``` - -Note that in absence of a configuration file, we will use default settings: languageMode will be set to nonstrict, a set of lint warnings is going to be enabled by default (this proposal doesn't detail that set - that will be subject to a different proposal), type checking issues are going to be treated as errors, lint issues are going to be treated as warnings. - -## Design -- compatibility - -Today we support .robloxrc files; this proposal will keep parsing legacy specification of configuration for compatibility: - -- Top-level `"language"` key can refer to an object that has `"languageMode"` key that also defines language mode -- Top-level `"lint"` object values can refer to a string `"disabled"`/`"enabled"`/`"fatal"` instead of a boolean as a value. - -These keys are only going to be supported for compatibility and only when the file name is .robloxrc (which is only going to be parsed by internal Roblox command line tools but this proposal mentions it for completeness). - -## Drawbacks - -The introduction of configuration files means that it's now impossible to type check or lint sources in isolation, which complicates the code setup. - -File-based JSON configuration may or may not map cleanly to environments that don't support files, such as Roblox Studio. - -Using JSON5 instead of vanilla JSON limits the interoperability. - -There's no way to force specific lints to be fatal, although this can be solved in the future by promoting the "compatibility" feature where one can specify a string to a non-compatibility feature. - -## Alternatives - -It's possible to consider forcing users to specify the source settings via `--!` comments exclusively. This is problematic as it may require excessive amounts of annotation though, which this proposal aims to simplify. - -The format of the configuration file does not have to be JSON; for example, it can be a valid Luau source file which is the approach luacheck takes. This makes it more difficult to repurpose the .luaurc file to use third-party processing tools though, e.g. a package manager would need to learn how to parse Luau syntax to store configuration in .luaurc. - -It's possible to use the old style of lint rule specification with "enabled"/"fatal"/etc., but it's more verbose and is more difficult to use in common scenarios, such as "all enabled lints are fatal and these are the ones we need to enable in addition to the default set" is impossible to specify. diff --git a/rfcs/deprecate-getfenv-setfenv.md b/rfcs/deprecate-getfenv-setfenv.md deleted file mode 100644 index e5af17306..000000000 --- a/rfcs/deprecate-getfenv-setfenv.md +++ /dev/null @@ -1,36 +0,0 @@ -# Deprecate getfenv/setfenv - -## Summary - -Mark getfenv/setfenv as deprecated - -## Motivation - -getfenv and setfenv are problematic for a host of reasons: - -- They allow uncontrolled mutation of global environment, which results in deoptimization; various important performance features -like builtin calls or imports are disabled when these functions are used. -- Because of the uncontrolled mutation code that uses getfenv/setfenv can't be typechecked correctly; in particular, injecting new -globals is going to produce "unknown globals" warnings, and modifying existing globals can trivially violate soundness wrt type -checking -- While these functions can be used for good (once you ignore the issues above), such as custom module systems, statistically speaking -they are mostly used to obfuscate code to hide malicious intent. - -## Design - -We will mark getfenv and setfenv as deprecated. The only consequence of this change is that the linter will start emitting warnings when they are used. - -Removing support for getfenv/setfenv, while tempting, is not planned in the foreseeable future because it will cause significant backwards compatibility issues. - -## Drawbacks - -There are valid uses for getfenv/setfenv, that include extra logging (in Roblox code this manifests as `getfenv(1).script`), monkey patching for mocks in unit tests, and custom -module systems that inject globals into the calling environment. We do have a replacement for logging use cases, `debug.info`, and we do have an officially recommended replacement -for custom module systems, which is to use `require` that doesn't result in issues that fenv modification carries and can be understood by the type checker, we do not have an -alternative for mocks. As such, testing frameworks that implement mocking via setfenv/getfenv will need to use `--!nolint DeprecatedGlobal` to avoid this warning. - -## Alternatives - -Besides the obvious alternative "do nothing", we could also consider implementing Lua 5.2 support for _ENV. However, since we do not have a way to load script files other than -via `require` that doesn't support _ENV, and `loadstring` is supported but discouraged, we do not currently plan to implement `_ENV` although it's possible that this will happen -in the future. diff --git a/rfcs/deprecate-table-getn-foreach.md b/rfcs/deprecate-table-getn-foreach.md deleted file mode 100644 index c6c889dc3..000000000 --- a/rfcs/deprecate-table-getn-foreach.md +++ /dev/null @@ -1,29 +0,0 @@ -# Deprecate table.getn/foreach/foreachi - -## Summary - -Mark table.getn/foreach/foreachi as deprecated - -## Motivation - -`table.getn`, `table.foreach` and `table.foreachi` were deprecated in Lua 5.1 that Luau is based on, and removed in Lua 5.2. - -`table.getn(x)` is equivalent to `rawlen(x)` when `x` is a table; when `x` is not a table, `table.getn` produces an error. It's difficult to imagine code where `table.getn(x)` is better than either `#x` (idiomatic) or `rawlen(x)` (fully compatible replacement). However, `table.getn` is slower and provides yet another way to perform an operation, leading new users of the language to use it unknowingly. - -`table.foreach` is equivalent to a `for .. pairs` loop; `table.foreachi` is equivalent to a `for .. ipairs` loop; both may also be replaced by generalized iteration. Both functions are significantly slower than equivalent `for` loop replacements, are more restrictive because the function can't yield, and result in new users (particularly coming from JS background) unknowingly using these thus producing non-idiomatic non-performant code. - -In both cases, the functions bring no value over other library or language alternatives, and thus just serve as a distraction. - -## Design - -We will mark all three functions as deprecated. The only consequence of this change is that the linter will start emitting warnings when they are used. - -Removing support for these functions doesn't provide any measurable value and as such is not planned in the foreseeable future because it may cause backwards compatibility issues. - -## Drawbacks - -None - -## Alternatives - -If we consider table.getn/etc as supported, we'd want to start optimizing their usage which gets particularly tricky with foreach and requires more compiler machinery than this is probably worth. diff --git a/rfcs/disallow-proposals-leading-to-ambiguity-in-grammar.md b/rfcs/disallow-proposals-leading-to-ambiguity-in-grammar.md deleted file mode 100644 index d9c5c7d73..000000000 --- a/rfcs/disallow-proposals-leading-to-ambiguity-in-grammar.md +++ /dev/null @@ -1,129 +0,0 @@ -# Disallow `name T` and `name(T)` in future syntactic extensions for type annotations - -## Summary - -We propose to disallow the syntax `` `('`` as well as ` ` in future syntax extensions for type annotations to ensure that all existing programs continue to parse correctly. This still keeps the door open for future syntax extensions of different forms such as `` `<' `>'``. - -## Motivation - -Lua and by extension Luau's syntax is very free form, which means that when the parser finishes parsing a node, it doesn't try to look for a semi-colon or any termination token e.g. a `{` to start a block, or `;` to end a statement, or a newline, etc. It just immediately invokes the next parser to figure out how to parse the next node based on the remainder's starting token. - -That feature is sometimes quite troublesome when we want to add new syntax. - -We have had cases where we talked about using syntax like `setmetatable(T, MT)` and `keyof T`. They all look innocent, but when you look beyond that, and try to apply it onto Luau's grammar, things break down really fast. - -### `F(T)`? - -An example that _will_ cause a change in semantics: - -``` -local t: F -(u):m() -``` - -where today, `local t: F` is one statement, and `(u):m()` is another. If we had the syntax for `F(T)` here, it becomes invalid input because it gets parsed as - -``` -local t: F(u) -:m() -``` - -This is important because of the `setmetatable(T, MT)` case: - -``` -type Foo = setmetatable({ x: number }, { ... }) -``` - -For `setmetatable`, the parser isn't sure whether `{}` is actually a type or an expression, because _today_ `setmetatable` is parsed as a type reference, and `({}, {})` is the remainder that we'll attempt to parse as a statement. This means `{ x: number }` is invalid table _literal_. Recovery by backtracking is technically possible here, but this means performance loss on invalid input + may introduce false positives wrt how things are parsed. We'd much rather take a very strict stance about how things get parsed. - -### `F T`? - -An example that _will_ cause a change in semantics: - -``` -local function f(t): F T - (t or u):m() -end -``` - -where today, the return type annotation `F T` is simply parsed as just `F`, followed by a ambiguous parse error from the statement `T(t or u)` because its `(` is on the next line. If at some point in the future we were to allow `T` followed by `(` on the next line, then there's yet another semantic change. `F T` could be parsed as a type annotation and the first statement is `(t or u):m()` instead of `F` followed by `T(t or u):m()`. - -For `keyof`, here's a practical example of the above issue: - -``` -type Vec2 = {x: number, y: number} - -local function f(t, u): keyof Vec2 - (t or u):m() -end -``` - -There's three possible outcomes: - 1. Return type of `f` is `keyof`, statement throws a parse error because `(` is on the next line after `Vec2`, - 2. Return type of `f` is `keyof Vec2` and next statement is `(t or u):m()`, or - 3. Return type of `f` is `keyof` and next statement is `Vec2(t or u):m()` (if we allow `(` on the next line to be part of previous line). - -This particular case is even worse when we keep going: - -``` -local function f(t): F - T(t or u):m() -end -``` - -``` -local function f(t): F T - {1, 2, 3} -end -``` - -where today, `F` is the return type annotation of `f`, and `T(t or u):m()`/`T{1, 2, 3}` is the first statement, respectively. - -Adding some syntax for `F T` **will** cause the parser to change the semantics of the above three examples. - -### But what about `typeof(...)`? - -This syntax is grandfathered in because the parser supported `typeof(...)` before we stabilized our syntax, and especially before type annotations were released to the public, so we didn't need to worry about compatibility here. We are very glad that we used parentheses in this case, because it's natural for expressions to belong within parentheses `()`, and types to belong within angles `<>`. - -## The One Exception with a caveat - -This is a strict requirement! - -`function() -> ()` has been talked about in the past, and this one is different despite falling under the same category as `` `('``. The token `function` is in actual fact a "hard keyword," meaning that it cannot be parsed as a type annotation because it is not an identifier, just a keyword. - -Likewise, we also have talked about adding standalone `function` as a type annotation (semantics of it is irrelevant for this RFC) - -It's possible that we may end up adding both, but the requirements are as such: - 1. `function() -> ()` must be added first before standalone `function`, OR - 2. `function` can be added first, but with a future-proofing parse error if `<` or `(` follows after it - -If #1 is what ends up happening, there's not much to worry about because the type annotation parser will parse greedily already, so any new valid input will remain valid and have same semantics, except it also allows omitting of `(` and `<`. - -If #2 is what ends up happening, there could be a problem if we didn't future-proof against `<` and `(` to follow `function`: - -``` - return f :: function(T) -> U -``` - -which would be a parse error because at the point of `(` we expect one of `until`, `end`, or `EOF`, and - -``` - return f :: function
(a) -> a -``` - -which would also be a parse error by the time we reach `->`, that is the production of the above is semantically equivalent to `(f < a) > (a)` which would compare whether the value of `f` is less than the value of `a`, then whether the result of that value is greater than `a`. - -## Alternatives - -Only allow these syntax when used inside parentheses e.g. `(F T)` or `(F(T))`. This makes it inconsistent with the existing `typeof(...)` type annotation, and changing that over is also breaking change. - -Support backtracking in the parser, so if `: MyType(t or u):m()` is invalid syntax, revert and parse `MyType` as a type, and `(t or u):m()` as an expression statement. Even so, this option is terrible for: - 1. parsing performance (backtracking means losing progress on invalid input), - 2. user experience (why was this annotation parsed as `X(...)` instead of `X` followed by a statement `(...)`), - 3. has false positives (`foo(bar)(baz)` may be parsed as `foo(bar)` as the type annotation and `(baz)` is the remainder to parse) - -## Drawbacks - -To be able to expose some kind of type-level operations using `F` syntax, means one of the following must be chosen: - 1. introduce the concept of "magic type functions" into type inference, or - 2. introduce them into the prelude as `export type F = ...` (where `...` is to be read as "we haven't decided") diff --git a/rfcs/function-bit32-countlz-countrz.md b/rfcs/function-bit32-countlz-countrz.md deleted file mode 100644 index b4ccb1973..000000000 --- a/rfcs/function-bit32-countlz-countrz.md +++ /dev/null @@ -1,52 +0,0 @@ -# bit32.countlz/countrz - -**Status**: Implemented - -## Summary - -Add bit32.countlz (count left zeroes) and bit32.countrz (count right zeroes) to accelerate bit scanning - -## Motivation - -All CPUs have instructions to determine the position of first/last set bit in an integer. These instructions have a variety of uses, the popular ones being: - -- Fast implementation of integer logarithm (essentially allowing to compute `floor(log2(value))` quickly) -- Scanning set bits in an integer, which allows efficient traversal of compact representation of bitmaps -- Allocating bits out of a bitmap quickly - -Today it's possible to approximate `countlz` using `floor` and `log` but this approximation is relatively slow; approximating `countrz` is difficult without iterating through each bit. - -## Design - -`bit32` library will gain two new functions, `countlz` and `countrz`: - -``` -function bit32.countlz(n: number): number -function bit32.countrz(n: number): number -``` - -`countlz` takes an integer number (converting the input number to a 32-bit unsigned integer as all other `bit32` functions do), and returns the number of consecutive left-most zero bits - that is, the number of most significant zero bits in a 32-bit number until the first 1. The result is in `[0, 32]` range. - -For example, when the input number is `0`, it's `32`. When the input number is `2^k`, the result is `31-k`. - -`countrz` takes an integer number (converting the input number to a 32-bit unsigned integer as all other `bit32` functions do), and returns the number of consecutive right-most zero bits - that is, -the number of least significant zero bits in a 32-bit number until the first 1. The result is in `[0, 32]` range. - -For example, when the input number is `0`, it's `32`. When the input number is `2^k`, the result is `k`. - -> Non-normative: a proof of concept implementation shows that a polyfill for `countlz` takes ~34 ns per loop iteration when computing `countlz` for an increasing number sequence, whereas -> a builtin implementation takes ~4 ns. - -## Drawbacks - -None known. - -## Alternatives - -These functions can be alternatively specified as "find the position of the most/least significant bit set" (e.g. "ffs"/"fls" for "find first set"/"find last set"). This formulation -can be more immediately useful since the bit position is usually more important than the number of bits. However, the bit position is undefined when the input number is zero, -returning a sentinel such as -1 seems non-idiomatic, and returning `nil` seems awkward for calling code. Counting functions don't have this problem. - -An early version of this proposal suggested `clz`/`ctz` (leading/trailing) as names; however, using a full verb is more consistent with other operations like shift/rotate, and left/right may be easier to understand intuitively compared to leading/trailing. left/right are used by C++20. - -Of the two functions, `countlz` is vastly more useful than `countrz`; we could implement just `countlz`, but having both is nice for symmetry. diff --git a/rfcs/function-coroutine-close.md b/rfcs/function-coroutine-close.md deleted file mode 100644 index b9ffbf6f6..000000000 --- a/rfcs/function-coroutine-close.md +++ /dev/null @@ -1,36 +0,0 @@ -# coroutine.close - -**Status**: Implemented - -## Summary - -Add `coroutine.close` function from Lua 5.4 that takes a suspended coroutine and makes it "dead" (non-runnable). - -## Motivation - -When implementing various higher level objects on top of coroutines, such as promises, it can be useful to cancel the coroutine execution externally - when the caller is not -interested in getting the results anymore, execution can be aborted. Since coroutines don't provide a way to do that externally, this requires the framework to implement -cancellation on top of coroutines by keeping extra status/token and checking that token in all places where the coroutine is resumed. - -Since coroutine execution can be aborted with an error at any point, coroutines already implement support for "dead" status. If it were possible to externally transition a coroutine -to that status, it would be easier to implement cancellable promises on top of coroutines. - -## Design - -We implement Lua 5.4 behavior exactly with the exception of to-be-closed variables that we don't support. Quoting Lua 5.4 manual: - -> coroutine.close (co) -> Closes coroutine co, that is, puts the coroutine in a dead state. The given coroutine must be dead or suspended. In case of error (either the original error that stopped the coroutine or errors in closing methods), returns false plus the error object; otherwise returns true. - -The `co` argument must be a coroutine object (of type `thread`). - -After closing the coroutine, it gets transitioned to dead state which means that `coroutine.status` will return `"dead"` and attempts to resume the coroutine will fail. In addition, the coroutine stack (which can be accessed via `debug.traceback` or `debug.info`) will become empty. Calling `coroutine.close` on a closed coroutine will return `true` - after closing, the coroutine transitions into a "dead" state with no error information. - -## Drawbacks - -None known, as this function doesn't introduce any existing states to coroutines, and is similar to running the coroutine to completion/error. - -## Alternatives - -Lua's name for this function is likely in part motivated by to-be-closed variables that we don't support. As such, a more appropriate name could be `coroutine.cancel` which also -aligns with use cases better. However, since the semantics is otherwise the same, using the same name as Lua 5.4 reduces library fragmentation. diff --git a/rfcs/function-debug-info.md b/rfcs/function-debug-info.md deleted file mode 100644 index 5f486db4b..000000000 --- a/rfcs/function-debug-info.md +++ /dev/null @@ -1,109 +0,0 @@ -# debug.info - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Add `debug.info` as programmatic debug info access API, similarly to Lua's `debug.getinfo` - -## Motivation - -Today Luau provides only one method to get the callstack, `debug.traceback`. This method traverses the entire stack and returns a string containing the call stack details - with no guarantees about the format of the call stack. As a result, the string doesn't present a formal API and can't be parsed programmatically. - -There are a few cases where this can be inconvenient: - -- Sometimes it is useful to pass the resulting call stack to some system expecting a structured input, e.g. for crash aggregation -- Sometimes it is useful to use the information about the caller for logging or filtering purposes; in these cases using just the script name can be useful, and getting script name out of the traceback is slow and imprecise - -Additionally, in some cases instead of getting the information (such as script or function name) from the callstack, it can be useful to get it from a function object for diagnostic purposes. For example, maybe you want to call a callback and if it doesn't return expected results, display a user-friendly error message that contains the function name & script location - these aren't possible today at all. - -## Design - -The proposal is to expose a function from Lua standard library, `debug.getinfo`, to fix this problem - but change the function's signature for efficiency: - -> debug.info([thread], [function | level], options) -> any... - -(note that the function has been renamed to make it more obvious that the behavior differs from that of Lua) - -The parameters of the function match that of Lua's variant - the first argument is either a function object or a stack level (which is a number starting from 1, where 1 means "my caller"), or a thread (followed by the stack level), followed by a string that contains a list of things the result needs to contain: - - * s - function source identifier, in Roblox environment this is equal to the full name of the script the function is defined in - * l - line number that the function is defined on (when examining a function) or line number of the stack frame (when examining a stack frame) - * n - function name if present; this can be absent for anonymous functions or some C functions that don't have an assigned debug name - * a - function arity information, which refers to the parameter count and whether the function is variadic or not - * f - function object - -Unlike Lua version, which would use the options given to fill a resulting table (e.g. "l" would map to a "currentline" and "linedefined" fields of the output table), our version will return the requested information in the order that it was requested in in the string - all letters specified above map to one extra returned value, "a" maps to a pair of a parameter number and a boolean indicating variadic status. - -For example, here's how you implement a stack trace function: - -``` - for i=1,100 do -- limit at 100 entries for very deep stacks - local source, name, line = debug.info(i, "snl") - if not source then break end - if line >= 0 then - print(string.format("%s(%d): %s", source, line, name or "anonymous")) - else - print(string.format("%s: %s", source, name or "anonymous")) - end - end -``` - -output: - -``` - cs.lua(3): stacktrace - cs.lua(17): bar - cs.lua(13): foo - [C]: pcall - cs.lua(20): anonymous -``` - -When the first argument is a number and the input level is out of bounds, the function returns no values. - -### Why the difference from Lua? - -Lua's variant of this function has the same string as an input and the same thread/function/level combo as arguments before that, but returns a table with the requested data - or nil, when stack is exhausted. - -The problem with this solution is performance. It results in generating excessive garbage by wrapping results in a table, which slows down the function call itself and generates extra garbage that needs to be collected later. This is not a problem for error handling scenarios, but can be an issue when logging is required; for example, `debug.info` with options containing a single result, "s" (mapping to source identifier aka script name), runs 3-4x slower when using a table variant with the current implementation of both functions in our VM. - -While the difference in behavior is unfortunate, note that Lua has a long-standing precedent of using characters in strings to define the set of inputs or outputs for functions; of particular note is string.unpack which closely tracks this proposal where input string characters tell the implementation what data to return. - -### Why not hardcode the options? - -One possibility is that we could return all data associated with the function or a stack frame as a tuple. - -This would work but has issues: - -1. Because of the tuple-like API, the code becomes more error prone and less self-descriptive. -2. Some data is more expensive to access than other data - by forcing all callers to process all possible data we regress in performance; this is also why the original Lua API has an options string - -To make sure we appropriately address 1, unlike Lua API in our API options string is mandatory to specify. - -### Sandboxing risk? - -Compared to information that you can already parse from traceback, the only extra data we expose is the function object. This is valuable when collecting stacks because retrieving the function object is faster than retrieving the associated source/name data - for example a very performant stack tracing implementation could collect data using "fl" (function and line number), and later when it comes the time to display the results, use `debug.info` again with "sn" to get script & name data from the object. - -This technically wasn't possible to get before - this means in particular that if your function is ever called by another function, a malicious script could grab that function object again and call it with different arguments. However given that it's already possible to mutate global environment of any function on the callstack using getfenv/setfenv, the extra risk presented here seems minimal. - -### Options delta from Lua - -Lua presents the following options in getinfo: - -* `n´ selects fields name and namewhat -* `f´ selects field func -* `S´ selects fields source, short_src, what, and linedefined -* `l´ selects field currentline -* `u´ selects field nup - -We chose to omit `namewhat` as it's not meaningful in our implementation, omit `what` as it's redundant wrt source/short_src for C functions, replace source/short_src with only a single option (`s`) to avoid leaking script source via callstack API, remove `u` because there are no use cases for knowing the number of upvalues without debug.getupvalue API, and add `a` which has been requested by Roact team before for complex backwards compatibility workarounds wrt passed callbacks. - -## Drawbacks - -Having a different way to query debug information from Lua requires language-specific dispatch for code that wants to work on Lua and Luau. - -## Alternatives - -We could expose `debug.getinfo` from Lua as is; the problem is that in addition to performance issues highlighted above, Luau implementation doesn't track the same data and as such can't provide a fully compatible implementation short of implementing a shim for the sake of compatibility - an option this proposal keeps open. diff --git a/rfcs/function-string-pack-unpack.md b/rfcs/function-string-pack-unpack.md deleted file mode 100644 index 5315f4c3c..000000000 --- a/rfcs/function-string-pack-unpack.md +++ /dev/null @@ -1,71 +0,0 @@ -# string.pack/unpack/packsize from Lua 5.3 - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Add string pack/unpack from Lua 5.3 for binary interop, with small tweaks to format specification to make format strings portable. - -## Motivation - -While the dominant usecase for Luau is a game programming language, for backend work it's sometimes the case that developers need to work with formats defined outside of Roblox. When these are structured as JSON, it's easy, but if they are binary, it's not. Additionally for the game programming, often developers end up optimizing their data transmission using custom binary codecs where they know the range of the data (e.g. it's much more efficient to send a number using 1 byte if you know the number is between 0 and 1 and 8 bits is enough, but RemoteEvent/etc won't do it for you because it guarantees lossless roundtrip). For both working with external data and optimizing data transfer, it would be nice to have a way to work with binary data. - -This is doable in Luau using `string.byte`/`string.char`/`bit32` library/etc. but tends to be a bit cumbersome. Lua 5.3 provides functions `string.pack`/`string.unpack`/`string.packsize` that, while not solving 100% of the problems, often make working with binary much easier and much faster. This proposal suggests adding them to Luau - this will both further our goal to be reasonably compatible with latest Lua versions, and make it easier for developers to write some types of code. - -## Design - -Concretely, this proposal suggests adding the following functions: - -``` -string.pack (fmt, v1, v2, ···) -``` - -Returns a binary string containing the values v1, v2, etc. packed (that is, serialized in binary form) according to the format string fmt. - -``` -string.packsize (fmt) -``` - -Returns the size of a string resulting from string.pack with the given format. The format string cannot have the variable-length options 's' or 'z'. - -``` -string.unpack (fmt, s [, pos]) -``` - -Returns the values packed in string s (see string.pack) according to the format string fmt. An optional pos marks where to start reading in s (default is 1). After the read values, this function also returns the index of the first unread byte in s. - -The format string is a sequence of characters that define the data layout that is described here in full: https://www.lua.org/manual/5.3/manual.html#6.4.2. We will adopt this wholesale, but we will guarantee that the resulting code is cross-platform by: - -a) Ensuring native endian is little endian (de-facto true for all our platforms) -b) Fixing sizes of native formats to 2b short, 4b int, 8b long -c) Treating `size_t` in context of `T` and `s` formats as a 32-bit integer - -Of course, the functions are memory-safe; if the input string is too short to provide all relevant data they will fail with "data string is too short" error. - -This may seem slightly unconventional but it's very powerful and expressive, in much the same way format strings and regular expressions are :) Here's a basic example of how you might transmit a 3-component vector with this: - -``` --- returns a 24-byte string with 64-bit double encoded three times, similar to how we'd replicate 3 raw numbers -string.pack("ddd", x, y, z) - --- returns a 12-byte string with 32-bit float encoded three times, similar to how we'd replicate Vector3 -string.pack("fff", x, y, z) - --- returns a 3-byte string with each value stored in 8 bits --- assumes -1..1 range; this code doesn't round the right way because I'm too lazy -string.pack("bbb", x * 127, y * 127, z * 127) -``` - -The unpacking of the data is symmetrical - using the same format string and `string.unpack` you get the encoded data back. - -## Drawbacks - -The format specification is somewhat arbitrary and is likely to be unfamiliar to people who come with prior experience in other languages (having said that, this feature closely follows equivalent functionality from Ruby). - -The implementation of string pack/unpack requires yet another format string matcher, which increases complexity of the builtin libraries and static analysis (since we need to provide linting for another format string syntax). - -## Alternatives - -We could force developers to rely on existing functionality for string packing; it is possible to replicate this proposal in a library, although at a much reduced performance. diff --git a/rfcs/function-table-clear.md b/rfcs/function-table-clear.md deleted file mode 100644 index 92279928b..000000000 --- a/rfcs/function-table-clear.md +++ /dev/null @@ -1,21 +0,0 @@ -# table.clear - -> Note: this RFC was adapted from an internal proposal that predates RFC process and as such doesn't follow the template precisely - -**Status**: Implemented - -## Summary - -Add `table.clear` function that removes all elements from the table but keeps internal capacity allocated. - -## Design - -`table.clear` adds a fast way to clear a Lua table. This is effectively a sister function to `table.create()`, only for reclaiming an existing table's memory rather than pre-allocating a new one. Use cases: - -* Often you want to recalculate a set or map data structure based on a table. Currently there is no good way to do this, the fastest way is simply to throw away the old table and construct a new empty one to work with. This is wasteful since often the new structure will take a similar amount of memory to the old one. - -* Sometimes you have a shared table which multiple scripts access. In order to clear this kind of table, you have no other option than to use a slow for loop setting each index to nil. - -These use cases can technically be accomplished via `table.move` moving from an empty table to the table which is to be edited, but I feel that they are frequent enough to warrant a clearer more understandable method which has an opportunity to be more efficient. - -Like `table.move`, does not invoke any metamethods. Not that it would anyways, given that assigning nil to an index never invokes a metamethod. diff --git a/rfcs/function-table-clone.md b/rfcs/function-table-clone.md deleted file mode 100644 index 8cb979845..000000000 --- a/rfcs/function-table-clone.md +++ /dev/null @@ -1,66 +0,0 @@ -# table.clone - -**Status**: Implemented - -## Summary - -Add `table.clone` function that, given a table, produces a copy of that table with the same keys/values/metatable. - -## Motivation - -There are multiple cases today when cloning tables is a useful operation. - -- When working with tables as data containers, some algorithms may require modifying the table that can't be done in place for some reason. -- When working with tables as objects, it can be useful to obtain an identical copy of the object for further modification, preserving the metatable. -- When working with immutable data structures, any modification needs to clone some parts of the data structure to produce a new version of the object. - -While it's possible to implement this function in user code today, it's impossible to implement it with maximum efficiency; furthermore, cloning is a reasonably fundamental -operation so from the ergonomics perspective it can be expected to be provided by the standard library. - -## Design - -`table.clone(t)` takes a table, `t`, and returns a new table that: - -- has the same metatable -- has the same keys and values -- is not frozen, even if `t` was - -The copy is shallow: implementing a deep recursive copy automatically is challenging (for similar reasons why we decided to avoid this in `table.freeze`), and often only certain keys need to be cloned recursively which can be done after the initial clone. - -The table can be modified after cloning; as such, functions that compute a slightly modified copy of the table can be easily built on top of `table.clone`. - -`table.clone(t)` is functionally equivalent to the following code, but it's more ergonomic (on the account of being built-in) and significantly faster: - -```lua -assert(type(t) == "table") -local nt = {} -for k,v in pairs(t) do - nt[k] = v -end -if type(getmetatable(t)) == "table" then - setmetatable(nt, getmetatable(t)) -end -``` - -The reason why `table.clone` can be dramatically more efficient is that it can directly copy the internal structure, preserving capacity and exact key order, and is thus -limited purely by memory bandwidth. In comparison, the code above can't predict the table size ahead of time, has to recreate the internal table structure one key at a time, -and bears the interpreter overhead (which can be avoided for numeric keys with `table.move` but that doesn't work for the general case of dictionaries). - -Out of the abundance of caution, `table.clone` will fail to clone the table if it has a protected metatable. This is motivated by the fact that you can't do this today, so -there are no new potential vectors to escape various sandboxes. Superficially it seems like it's probably reasonable to allow cloning tables with protected metatables, but -there may be cases where code manufactures tables with unique protected metatables expecting 1-1 relationship and cloning would break that, so for now this RFC proposes a more -conservative route. We are likely to relax this restriction in the future. - -## Drawbacks - -Adding a new function to `table` library theoretically increases complexity. In practice though, we already effectively implement `table.clone` internally for some VM optimizations, so exposing this to the users bears no cost. - -Assigning a type to this function is a little difficult if we want to enforce the "argument must be a table" constraint. It's likely that we'll need to type this as `table.clone(T): T` for the time being, which is less precise. - -## Alternatives - -We can implement something similar to `Object.assign` from JavaScript instead, that simultaneously assigns extra keys. However, this won't be fundamentally more efficient than -assigning the keys afterwards, and can be implemented in user space. Additionally, we can later extend `clone` with an extra argument if we so choose, so this proposal is the -minimal viable one. - -We can immediately remove the rule wrt protected metatables, as it's not clear that it's actually problematic to be able to clone tables with protected metatables. diff --git a/rfcs/function-table-create-find.md b/rfcs/function-table-create-find.md deleted file mode 100644 index 671e16afd..000000000 --- a/rfcs/function-table-create-find.md +++ /dev/null @@ -1,28 +0,0 @@ -# table.create and table.find - -> Note: this RFC was adapted from an internal proposal that predates RFC process and as such doesn't follow the template precisely - -**Status**: Implemented - -## Design - -This proposal suggests adding two new builtin table functions: - -`table.create(count, value)`: Creates an array with count values, initialized to value. This can be useful to preallocate large tables - repeatedly appending an element to the table repeatedly reallocates it. count is converted to an integer using standard conversion/coercion rules (strings are converted to doubles, doubles are converted to integers using truncation). Negative counts result in the function failing. Positive counts that are too large and would cause a heap allocation error also result in function failing. When value is nil or omitted, table is preallocated without storing anything in it - this is roughly equivalent to creating a large table literal filled with `nil`, or preallocating a table by assigning a sufficiently large numeric index to a value and then erasing it by reassigning it to nil. - -`table.find(table, value [, init])`: Looks for value in the array part of the table; returns index of first occurrence or nil if value is not found. Comparison is performed using standard equality (non-raw) to make sure that objects like Vector3 etc. can be found. The first nil value in the array part of the table terminates the traversal. init is an optional numeric index where the search starts and it defaults to 1; this can be useful to go through repeat occurrences. - -`table.create` can not be replicated efficiently in Lua at all; `table.find` is provided as a faster and more convenient option compared to the code above. - -`table.find` is roughly equivalent to the following code modulo semantical oddities with #t and performance: - -``` -function find(table, value, init) - for i=init or 1, #table do - if rawget(table, i) == value then - return i - end - end - return nil -end -``` diff --git a/rfcs/function-table-freeze.md b/rfcs/function-table-freeze.md deleted file mode 100644 index ca819882e..000000000 --- a/rfcs/function-table-freeze.md +++ /dev/null @@ -1,55 +0,0 @@ -# table.freeze - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Add `table.freeze` which allows to make a table read-only in a shallow way. - -## Motivation - -Lua tables by default are freely modifiable in every possible way: you can add new fields, change values for existing fields, or set or unset the metatable. - -Today it is possible to customize the behavior for *adding* new fields by setting a metatable that overrides `__newindex` (including setting `__newindex` to a function that always errors to prohibit additions of new fields). - -Today it is also possible to customize the behavior of setmetatable by "locking" the metatable - this can be achieved by setting a meta-index `__metatable` to something, which would block setmetatable from functioning and force metatable to return the provided value. With this it's possible to prohibit customizations of a table's behavior, but existing fields can still be assigned to. - -To make an existing table read-only, one needs to combine these mechanisms, by creating a new table with a locked metatable, which has an `__index` function pointing to the old table. However, this results in iteration and length operator not working on the resulting table, and carries a performance cost - both for creating the table, and for repeated property access. - -## Design - -This proposal proposes formalizing the notion of "read-only" tables by providing two new table functions: - -- `table.freeze(t)`: given a non-frozen table t, freezes it; fails when t is not a table or is already frozen. Returns t. -- `table.isfrozen(t)`: given a table t, returns a boolean indicating the frozen status; fails when t is not a table. - -When a table is frozen, the following is true: - -- Attempts to modify the existing keys of the table fail (regardless of how they are performed - via table assignments, rawset, or any other methods like table.sort) -- Attempts to add new keys to the table fail, unless `__newindex` is defined on the metatable (in which case the assignment is routed through `__newindex` as usual) -- Attempts to change the metatable of the table fail -- Reading the table fields or iterating through the table proceeds as usual - -This feature is useful for two reasons: - -a) It allows an easier way to expose sandboxed objects that aren't possible to monkey-patch for security reasons. We actually already have support for freezing and use it internally on various builtin tables like `math`, we just don't expose it to Lua. - -b) It allows an easier way to expose immutable objects for consistency/correctness reasons. For example, Cryo library provides an implementation of immutable data structures; with this functionality, it's possible to implement a lighter-weight library by, for example, extending a table with methods to return mutated versions of the table, but retaining the usual table interface - -To limit the use of `table.freeze` to cases when table contents can be freely manipulated, `table.freeze` shall fail when the table has a locked metatable (but will succeed if the metatable isn't locked). - -## Drawbacks - -Exposing the internal "readonly" feature may have an impact on interoperability between scripts - for example, it becomes possible to freeze some tables that scripts may be expecting to have write access to from other scripts. Since we don't provide a way to unfreeze tables and freezing a table with a locked metatable fails, in theory the impact should not be any worse than allowing to change a metatable, but the full extents are unclear. - -There may be existing code in the VM that allows changing frozen tables in ways that are benign to the current sandboxing code, but expose a "gap" in the implementation that becomes significant with this feature; thus we would need to audit all table writes when implementing this. - -## Alternatives - -We've considered exposing a recursive freeze. The correct generic implementation is challenging since it requires supporting infinitely nested tables when working on the C stack (or a stackless implementation that requires heap allocation); also, to handle self-recursive tables requires a separate temporary tracking table since stopping the traversal at frozen sub-tables is insufficient as their children may not have been frozen. As such, we leave recursive implementation to user code. - -We've considered exposing thawing. The problem with this is that freezing is required for sandboxing, and as such we'd need to support "permafrozen" status that is separate from "frozen". This complicates implementation and we didn't find compelling use cases for thawing - if it becomes necessary we can always expose it separately. - -We've considered calling this "locking", but the term has connotations coming from multithreading that aren't applicable here, and in absence of unlocking, "locking" makes a bit less sense. diff --git a/rfcs/generalized-iteration.md b/rfcs/generalized-iteration.md deleted file mode 100644 index c28156ff9..000000000 --- a/rfcs/generalized-iteration.md +++ /dev/null @@ -1,126 +0,0 @@ -# Generalized iteration - -**Status**: Implemented - -## Summary - -Introduce support for iterating over tables without using `pairs`/`ipairs` as well as a generic customization point for iteration via `__iter` metamethod. - -## Motivation - -Today there are many different ways to iterate through various containers that are syntactically incompatible. - -To iterate over arrays, you need to use `ipairs`: `for i, v in ipairs(t) do`. The traversal goes over a sequence `1..k` of numeric keys until `t[k] == nil`, preserving order. - -To iterate over dictionaries, you need to use `pairs`: `for k, v in pairs(t) do`. The traversal goes over all keys, numeric and otherwise, but doesn't guarantee an order; when iterating over arrays this may happen to work but is not guaranteed to work, as it depends on how keys are distributed between array and hash portion. - -To iterate over custom objects, whether they are represented as tables (user-specified) or userdata (host-specified), you need to expose special iteration methods, for example `for k, v in obj:Iterator() do`. - -All of these rely on the standard Lua iteration protocol, but it's impossible to trigger them in a generic fashion. Additionally, you *must* use one of `pairs`/`ipairs`/`next` to iterate over tables, which is easy to forget - a naive `for k, v in tab do` doesn't work and produces a hard-to-understand error `attempt to call a table value`. - -This proposal solves all of these by providing a way to implement uniform iteration with self-iterating objects by allowing to iterate over objects and tables directly via convenient `for k, v in obj do` syntax, and specifies the default iteration behavior for tables, thus mostly rendering `pairs`/`ipairs` obsolete - making Luau easier to use and teach. - -## Design - -In Lua, `for vars in iter do` has the following semantics (otherwise known as the iteration protocol): `iter` is expanded into three variables, `gen`, `state` and `index` (using `nil` if `iter` evaluates to fewer than 3 results); after this the loop is converted to the following pseudocode: - -```lua -while true do - vars... = gen(state, index) - index = vars... -- copy the first variable into the index - if index == nil then break end - - -- loop body goes here -end -``` - -This is a general mechanism that can support iteration through many containers, especially if `gen` is allowed to mutate state. Importantly, the *first* returned variable (which is exposed to the user) is used to continue the process on the next iteration - this can be limiting because it may require `gen` or `state` to carry extra internal iteration data for efficiency. To work around this for table iteration to avoid repeated calls to `next`, Luau compiler produces a special instruction sequence that recognizes `pairs`/`ipairs` iterators and stores the iteration index separately. - -Thus, today the loop `for k, v in tab do` effectively executes `k, v = tab()` on the first iteration, which is why it yields `attempt to call a table value`. If the object defines `__call` metamethod then it can act as a self-iterating method, but this is not idiomatic, not efficient and not pure/clean. - -This proposal comes in two parts: general support for `__iter` metamethod and default implementation for tables without one. With both of these in place, there's going to be a single, idiomatic, general and performant way to iterate through the object of any type: - -```lua -for k, v in obj do -... -end -``` - -### __iter - -To support self-iterating objects, we modify the iteration protocol as follows: instead of simply expanding the result of expression `iter` into three variables (`gen`, `state` and `index`), we check if the first result has an `__iter` metamethod (which can be the case if it's a table, userdata or another composite object (e.g. a record in the future). If it does, the metamethod is called with `gen` as the first argument, and the returned three values replace `gen`/`state`/`index`. This happens *before* the loop: - -```lua -local genmt = rawgetmetatable(gen) -- pseudo code for getmetatable that bypasses __metatable -local iterf = genmt and rawget(genmt, "__iter") -if iterf then - gen, state, index = iterf(gen) -end -``` - -This check is comparatively trivial: usually `gen` is a function, and functions don't have metatables; as such we can simply check the type of `gen` and if it's a table/userdata, we can check if it has a metamethod `__iter`. Due to tag-method cache, this check is also very cheap if the metamethod is absent. - -This allows objects to provide a custom function that guides the iteration. Since the function is called once, it is easy to reuse other functions in the implementation, for example here's a node object that exposes iteration through its children: - -```lua -local Node = {} -Node.__index = Node - -function Node.new(children) - return setmetatable({ children = children }, Node) -end - -function Node:__iter() - return next, self.children -end -``` - -Luau compiler already emits a bytecode instruction, FORGPREP*, to perform initial loop setup - this is where we can evaluate `__iter` as well. - -Naturally, this means that if the table has `__iter` metamethod and you need to iterate through the table fields instead of using the provided metamethod, you can't rely on the general iteration scheme and need to use `pairs`. This is similar to other parts of the language, like `t[k]` vs `rawget(t, 'k')`, where the default behavior is overrideable but a library function can help peek behind the curtain. - -### Default table iteration - -If the argument is a table and it does not implement `__iter` metamethod, we treat this as an attempt to iterate through the table using the builtin iteration order. - -> Note: we also check if the table implements `__call`; if it does, we fall back to the default handling. We may be able to remove this check in the future, but we will need this initially to preserve backwards compatibility with custom table-driven iterator objects that implement `__call`. In either case, we will be able to collect detailed analytics about the use of `__call` in iteration, and if neither is present we can emit a specialized error message such as `object X is not iteratable`. - -To have a single, unified, iteration scheme over tables regardless of whether they are arrays or dictionaries, we establish the following semantics: - -- First, the traversal goes over numeric keys in range `1..k` up until reaching the first `k` such that `t[k] == nil` -- Then, the traversal goes over the remaining keys (with non-nil values), numeric and otherwise, in unspecified order. - -For arrays with gaps, this iterates until the first gap in order, and the remaining order is not specified. - -> Note: This behavior is similar to what `pairs` happens to provide today, but `pairs` doesn't give any guarantees, and it doesn't always provide this behavior in practice. - -To ensure that this traversal is performant, the actual implementation of the traversal involves going over the array part (in index order) and then over the hash part (in hash order). For that implementation to satisfy the criteria above, we need to make two additional changes to table insertion/rehash: - -- When inserting key `k` in the table when `k == t->sizearray + 1`, we force the table to rehash (resize its array portion). Today this is only performed if the hash portion is full, as such sometimes numeric keys can end up in the hash part. -- When rehashing the table, we ensure that the hash part doesn't contain the key `newsizearray + 1`. This requires checking if the table has this key, which may require an additional hash lookup but we only need to do this in rare cases based on the analysis of power-of-two key buckets that we already collect during rehash. - -These changes guarantee that the order observed via standard traversal with `next`/`pairs` matches the guarantee above, which is nice because it means we can minimize the complexity cost of this change by reusing the traversal code, including VM optimizations. They also mean that the array boundary (aka `#t`) can *always* be computed from just the array portion, which simplifies the table length computation and may slightly speed it up. - -## Drawbacks - -This makes `for` desugaring and implementation a little more complicated; it's not a large complexity factor in Luau because we already have special handling for `for` loops in the VM, but it's something to keep in mind. - -While the proposed iteration scheme should be a superset to both `pairs` and `ipairs` for tables, for arrays `ipairs` may in some cases be faster because it stops at the first `nil`, whereas the proposed new scheme (like `pairs`) needs to iterate through the rest of the table's array storage. This may be fixable in the future, if we replace our cached table length (`aboundary`) with Lua 5.4's `alimit`, which maintains the invariant that all values after `alimit` in the array are `nil`. This would make default table iteration maximally performant as well as help us accelerate GC in some cases, but will require extra checks during table assignments which is a cost we may not be willing to pay. Thus it is theoretically possible that we will end up with `ipairs` being a slightly faster equivalent for array iteration forever. - -The resulting iteration behavior, while powerful, increases the divergence between Luau and Lua, making more programs that are written for Luau not runnable in Lua. Luau language in general does not consider this type of compatibility essential, but this is noted for posterity. - -The changes in insertion behavior that facilitate single iteration order may have a small cost; that said, they are currently understood to belong to paths that are already slow and the added cost is minimal. - -The extra semantics will make inferring the types of the variables in a for loop more difficult - if we know the type of the expression that is being iterated through it probably is not a problem though. - -## Alternatives - -Other major designs have been considered. - -A minor variation of the proposal involves having `__iter` be called on every iteration instead of at loop startup, effectively having `__iter` work as an alternative to `__call`. The issue with this variant is that while it's a little simpler to specify and implement, it restricts the options when implementing custom iteratable objects, because it would be difficult for iteratable objects to store custom iteration state elsewhere since `__iter` method would effectively need to be pure, as it can't modify the object itself as more than one concurrent iteration needs to be supported. - -A major variation of the proposal involves instead supporting `__pairs` from Lua 5.2. The issue with this variant is that it still requires the use of a library method, `pairs`, to work, which doesn't make the language simpler as far as table iteration, which is the 95% case, is concerned. Additionally, with some rare exceptions metamethods today extend the *language* behavior, not the *library* behavior, and extending extra library functions with metamethods does not seem true to the core of the language. Finally, this only works if the user uses `pairs` to iterate and doesn't work with `ipairs`/`next`. - -Another variation involves using a new pseudo-keyword, `foreach`, instead of overloading existing `for`, and only using the new `__iter` semantics there. This can more cleanly separate behavior, requiring the object to have an `__iter` metamethod (or be a table) in `foreach` - which also avoids having to deal with `__call` - but it also requires teaching the users a new keyword which fragments the iteration space a little bit more. Compared to that, the main proposal doesn't introduce new divergent syntax, and merely tweaks existing behavior to be more general, thus making an existing construct easier to use. - -Finally, the author also considered and rejected extending the iteration protocol as part of this change. One problem with the current protocol is that the iterator requires an allocation (per loop execution) to keep extra state that isn't exposed to the user. The builtin iterators like `pairs`/`ipairs` work around this by feeding the user-visible index back to the search function, but that's not always practical. That said, having a different iteration protocol in effect only when `__iter` is used makes the language more complicated for unclear efficiency gains, thus this design doesn't suggest a new core protocol to favor simplicity. diff --git a/rfcs/generic-function-subtyping.md b/rfcs/generic-function-subtyping.md deleted file mode 100644 index b9c0c430b..000000000 --- a/rfcs/generic-function-subtyping.md +++ /dev/null @@ -1,211 +0,0 @@ -# Expanded Subtyping for Generic Function Types - -## Summary - -Extend the subtyping relation for function types to relate generic function -types with compatible instantiated function types. - -## Motivation - -As Luau does not have an explicit syntax for instantiation, there are a number -of places where the typechecker will automatically perform instantiation with -the goal of permitting more programs. These instances of instantiation are -ad-hoc and strategic, but useful in practice for permitting programs such as: - -```lua -function id(x: T): T - return x -end - -local idNum : (number) -> number -idNum = id -- ok -``` - -However, they have also been a source of some typechecking bugs because of how -they actually make a determination as to whether the instantation should happen, -and they currently open up some potential soundness holes when instantiating -functions in table types since properties of tables are mutable and thus need to -be invariant (which the automatic-instantiation potentially masks). - -## Design - -The goal then is to rework subtyping to support the relationship we want in the -first place: allowing polymorphic functions to be used where instantiated -functions are expected. In particular, this means adding instantiation itself to -the subtyping relation. Formally, that'd look something like: - -``` -instantiate((T1) -> T2) = (T1') -> T2' -(T1') -> T2' <: (T3) -> T4 --------------------------------------------- -(T1) -> T2 <: (T3) -> T4) -``` - -Or informally, we'd say that a generic function type is a subtype of another -function type if we can instantiate it and show that instantiated function type -to be a subtype of the original function type. Implementation-wise, this loose -formal rule suggests a strategy of when we'll want to apply instantiation. -Namely, whenever the subtype and supertype are both functions with the potential -subtype having some generic parameters and the supertype having none. So, if we -look once again at our simple example from motivation, we can walk through how -we expect it to type check: - -```lua -function id(x: T): T - return x -end - -local idNum : (number) -> number -idNum = id -- ok -``` - -First, `id` is given the type `(T) -> T` and `idNum` is given the type -`(number) -> number`. When we actually perform the assignment, we must show that -the type of the right-hand side is compatible with the type of the left-hand -side according to subtyping. That is, we'll ask if `(T) -> T` is a subtype of -`(number) -> number` which matches the rule to apply instantiation since the -would-be subtype has a generic parameter while the would-be supertype has no -generic parameters. This contrasts with the current implementation which, before -asking the subtyping question, checks if the type of the right-hand side -contains any generics at any point and if the type of the left-hand side cannot -_possibly_ contain generics and instantiates the right-hand side if so. - -Adding instantiation to subtyping does pose some additional questions still -about when exactly to instantiate. Namely, we need to consider cases like -function application. We can see why by looking at some examples: - -```lua -function rank2(f: (a) -> a): (number) -> number - return f -end -``` - -In this case, we expect to allow the instantiation of `f` from `(a) -> a` to -`(number) -> number`. After all, we can consider other cases like where the body -instead applies `f` to some particular value, e.g. `f(42)`, and we'd want the -instantiation to be allowed there. However, this means we'd potentially run into -issues if we allowed call sites to `rank2` to pass in non-polymorphic functions. -A naive approach to implementing this proposal would do exactly that because we -currently treat contravariant subtyping positions (i.e. for the arguments of -functions) as being the same as our normal (i.e. covariant) subtyping relation -but with the arguments reversed. So, to type check an application like -`rank2(function(str: string) return str + "s" end)` (where the function argument -is of type `(string) -> string`), we would ask if `(a) -> a` is a subtype of -`(string) -> string`. This is precisely the question we asked in the original -example, but in the contravariant context, this is actually unsound since -`rank2` would then function as a general coercion from, e.g., -`(string) -> string` to `(number) -> number`. - -This sort of behavior does come up in other languages that mix polymorphism and -subtyping. If we consider the same example in F#, we can compare its behavior: - -```fsharp -let ranktwo (f : 'a -> 'a) : int -> int = f -let pluralize (s : string) : string = s + "s" -let x = ranktwo pluralize -``` - -For this example, F# produces one warning and one error. The warning is applied -to the function definition of `ranktwo` itself (coded `FS0064`), and says "This -construct causes code to be less generic than indicated by the type annotations. -The type variable 'a has been constrained to be type 'int'." This warning -highlights the actual difference between our example in Luau and the F# -translation. In F#, `'a` is really a free type variable, rather than a generic -type parameter of the function `ranktwo`, as such, this code actually -constrains the type of `ranktwo` to be `(int -> int) -> (int -> int)`. As such, -the application on line 3 errors because our `(string -> string)` function is -simply not compatible with that type. With higher-rank polymorphic function -parameters, it doesn't make sense to warn on their instantiation (as illustrated -by the example of actually applying `f` to some particular data in the -definition of `rank2`), but it's still just as problematic if we were to accept -instantiated functions at polymorphic types. Thus, it's important that we -actually ensure that we only instantiate in covariant contexts. So, we must -ensure that subtyping only instantiates in covariant contexts. - -It may also be helpful to consider an example of rank-1 polymorphism to -understand the full scope of the behavior. So, we can look at what happens if we -simply move the type parameter out in our working example: - -```lua -function rank1(f: (a) -> a): (number) -> number - return f -end -``` - -In this case, we expect an error to occur because the type of `f` depends on -what we instantiate `rank1` with. If we allowed this, it would naturally be -unsound because we could again provide a `(string) -> string` argument (by -instantiating `a` with `string`). This reinforces the idea that the presence of -the generic type parameter is likely to be a good option for determining -instantiation (at least when compared to the presence of free type variables). - -## Drawbacks - -One of the aims of this proposal is to provide a clear and predictable mental -model of when instantiation will take place in Luau. The author feels this -proposal is step forward compared to the existing ad-hoc usage of instantiation -in the typechecker, but it's possible that programmers are already comfortable -with the mental model they have built for the existing implementation. -Hopefully, this is mitigated by the fact that the new setup should allow all of -the _sound_ uses of instantiation permitted by the existing system. Notably, -however, programmers may be surprised by the added restriction when it comes to -properties in tables. In particular, we can consider a small variation of our -original example with identity functions: - -```lua -function id(x: T): T - return x -end - -local poly : { id : (a) -> a } = { id = id } - -local mono : { id : (number) -> number } -mono = poly -- error! -mono.id = id -- also an error! -``` - -In this case, the fact that we're dealing with a _property_ of a table type -means that we're in a context that needs to be invariant (i.e. not allow -subtyping) to avoid unsoundness caused by interactions between mutable -references and polymorphism (see things like the [value -restriction in OCaml][value-restriction] to understand why). In most cases, we -believe programmers will be using functions in tables as an implementation of -methods for objects, so we don't anticipate that they'll actually _want_ to do -the unsound thing here. The accepted RFC for [read-only -properties][read-only-props] gives us a technically-precise solution since -read-only properties would be free to be typechecked as a covariant context -(since they disallow mutation), and thus if the property `id` was marked -read-only, we'd be able to do both of the assignments in the above example. - -## Alternatives - -The main alternatives would likely be keeping the existing solution (and -likely having to tactically fix future bugs where instantiation either happens -too much or not enough), or removing automatic instantiation altogether in favor -of manual instantiation syntax. The former solution (changing nothing) is cheap -now (both in terms of runtime performance and also development cost), but the -existing implementation involves extra walks of both types to make a decision -about whether or not to perform instantiation. To minimize the performance -impact, the functions that perform these questions (`isGeneric` and -`maybeGeneric`) actually do not perform a full walk, and instead try to -strategically look at only enough to make the decision. We already found and -fixed one bug that was caused by these functions being too imprecise against -their spec, but fleshing them out entirely could potentially be a noticeable -performance regression since the decision to potentially instantiate is one that -comes up often. - -Removing automatic instantiation altogether, by contrast, will definitely be -"correct" in that we'll never instantiate in the wrong spot and programmers will -always have the ability to instantiate, but it would be a marked regression on -developer experience since it would increase the annotation burden considerably -and generally runs counter to the overall design strategy of Luau (which focuses -heavily on type inference). It would also require us to actually pick a syntax -for manual instantiation (which we are still open to do in the future if we -maintain an automatic instantiation solution) which is frought with parser -ambiguity issues or requires the introduction of a sigil like Rust's turbofish -for instantiation. Discussion of that syntax is present in the [generic -functions][generic-functions] RFC. - -[value-restriction]: https://stackoverflow.com/questions/22507448/the-value-restriction#22507665 -[read-only-props]: https://github.com/Roblox/luau/blob/master/rfcs/property-readonly.md -[generic-functions]: https://github.com/Roblox/luau/blob/master/rfcs/generic-functions.md diff --git a/rfcs/generic-functions.md b/rfcs/generic-functions.md deleted file mode 100644 index 3ac1bbba3..000000000 --- a/rfcs/generic-functions.md +++ /dev/null @@ -1,155 +0,0 @@ -# Generic functions - -**Status**: Implemented - -## Summary - -Extend the syntax and semantics of functions to support explicit generic functions, which can bind type parameters as well as data parameters. - -## Motivation - -Currently Luau allows generic functions to be inferred but not given explicit type annotations. For example - -```lua -function id(x) return x end -local x: string = id("hi") -local y: number = id(37) -``` - -is fine, but there is no way for a user to write the type of `id`. - -## Design - -Allow functions to take type parameters as well as function parameters, similar to Java/Typescript/... - -```lua -function id(x : a) : a return x end -``` - -Functions may also take generic type pack arguments for varargs, for instance: - -```lua -function compose(... : a...) -> (a...) return ... end -``` - -Generic type and type pack parameters can also be used in function types, for instance: - -```lua -local id: (a)->a = function(x) return x end -``` - -This change is *not* only syntax, as explicit type parameters need to be part of the semantics of types. For example, we can define a generic identity function - -```lua -local function id(x) return x end -local x: string = id("hi") -local y: number = id(37) -type Id = typeof(id) -``` - -and two functions - -```lua -function f() - return id -end -function g() - local y - function oh(x) - if not(y) then y = x end - return y - end - return oh -end -``` - -The types of these functions are - -```lua - f : () -> (a) -> a - g : () -> (a) -> a -``` - -so this is okay: - -```lua - local i: Id = f() - local x: string = i("hi") - local y: number = i(37) -``` - -but this is not: - -```lua - -- This assignment shouldn't typecheck! - local i: Id = g() - local x: string = i("hi") - -- This is unsound, since it assigns a string to a variable of type number - local y: number = i(37) -``` - -Currently, Luau does not have explicit type binders, so `f` and `g` have the same type. We propose making type binders part of the semantics of types as well as their syntax (so `f` and `g` have different types, and the unsound example does not typecheck). - -We propose supporting type parameters which can be instantiated with any type (jargon: Rank-N Types) but not type functions (jargon: Higher Kinded Types) or types with constraints (jargon: F-bounded polymorphism). - -## Turbofish - -Note that this RFC proposes a syntax for adding generic parameters to functions, but it does *not* propose syntax for adding generic arguments to function call site. For example, for `id` function you *can* write: - -```lua - -- generic type gets inferred as a number in all these cases -local x = id(4) -local x = id(y) :: number -local x: number = id(y) -``` - -but you can *not* write `id(y)`. - -This syntax is difficult to parse as it's ambiguous wrt grammar for comparison, and disambiguating it requires being able to parse types in expression context which makes parsing slow and complicated. It's also worth noting that today there are programs with this syntax that are grammatically correct (eg `id('4')` parses as "compare variable `id` to variable `string`, and compare the result to string '4'"). The specific example with a single argument will always fail at runtime because booleans can't be compared with relational operators, but multi-argument cases such as `print(foo(4))` can execute without errors in certain cases. - -Note that in many cases the types can be inferred, whether through function arguments (`id(4)`) or through expected return type (`id(y) :: number`). It's also often possible to cast the function object to a given type, even though that can be unwieldy (`(id :: (number)->number)(y)`). Some languages don't have a way to specify the types at call site either, Swift being a prominent example. Thus it's not a given we need this feature in Luau. - -If we ever want to implement this though, we can use a solution inspired by Rust's turbofish and require an extra token before `<`. Rust uses `::<` but that doesn't work in Luau because as part of this RFC, `id::(a)->a` is a valid, if redundant, type ascription, so we need to choose a different prefix. - -The following two variants are grammatically unambiguous in expression context in Luau, and are a better parallel for Rust's turbofish (in Rust, `::` is more similar to Luau's `:` or `.` than `::`, which in Rust is called `as`): - -```lua -foo:() -- require : before <; this is only valid in Luau in variable declaration context, so it's safe to use in expression context -foo.() -- require . before <; this is currently never valid in Luau -``` - -This RFC doesn't propose using either of these options, but notes that either one of these options is possible to specify & implement in the future if we so desire. - -## Drawbacks - -This is a breaking change, in that examples like the unsound program above will no longer typecheck. - -Types become more complex, so harder for programmers to reason about, and adding to their space usage. This is particularly noticeable anywhere the typechecker has exponential blowup, since small increases in type size can result in large increases in space or time usage. - -Not having higher-kinded types stops some examples which are parameterized on container types, for example: - -```lua - function g(f : (a) -> c) : (b) -> c> - return function(x) return f(f(x)) end - end -``` - -Not having bounded types stops some examples like giving a type to the function that sums an non-empty array: - -```lua - function sum(xs) - local result = x[0] - for i=1,#xs - result += x[i] - end - return result - end -``` - -## Alternatives - -We did originally consider Rank-1 types, but the problem is that's not backward-compatible, as DataBrain pointed out in the [Dev Forum](https://devforum.roblox.com/t/luau-recap-march-2021/1141387/29), since `typeof` allows users to construct generic types even without syntax for them. Rank-1 types give a false positive type error in this case, which comes from deployed code. - -We could introduce syntax for generic types without changing the semantics, but then there'd be a gap between the syntax (where the types `() -> (a) -> a` and `() -> (a) -> a` are different) and the semantics (where they are not). As noted above, this isn't sound. - -Rather than using Rank-N types, we could use SML-style polymorphism, but this would need something like the [value restriction](http://users.cis.fiu.edu/~smithg/cop4555/valrestr.html) to be sound. diff --git a/rfcs/len-metamethod-rawlen.md b/rfcs/len-metamethod-rawlen.md deleted file mode 100644 index 60278dda0..000000000 --- a/rfcs/len-metamethod-rawlen.md +++ /dev/null @@ -1,45 +0,0 @@ -# Support `__len` metamethod for tables and `rawlen` function - -**Status**: Implemented - -## Summary - -`__len` metamethod will be called by `#` operator on tables, matching Lua 5.2 - -## Motivation - -Lua 5.1 invokes `__len` only on userdata objects, whereas Lua 5.2 extends this to tables. In addition to making `__len` metamethod more uniform and making Luau -more compatible with later versions of Lua, this has the important advantage which is that it makes it possible to implement an index based container. - -Before `__iter` and `__len` it was possible to implement a custom container using `__index`/`__newindex`, but to iterate through the container a custom function was -necessary, because Luau didn't support generalized iteration, `__pairs`/`__ipairs` from Lua 5.2, or `#` override. - -With generalized iteration, a custom container can implement its own iteration behavior so as long as code uses `for k,v in obj` iteration style, the container can -be interfaced with the same way as a table. However, when the container uses integer indices, manual iteration via `#` would still not work - which is required for some -more complicated algorithms, or even to simply iterate through the container backwards. - -Supporting `__len` would make it possible to implement a custom integer based container that exposes the same interface as a table does. - -## Design - -`#v` will call `__len` metamethod if the object is a table and the metamethod exists; the result of the metamethod will be returned if it's a number (an error will be raised otherwise). - -`table.` functions that implicitly compute table length, such as `table.getn`, `table.insert`, will continue using the actual table length. This is consistent with the -general policy that Luau doesn't support metamethods in `table.` functions. - -A new function, `rawlen(v)`, will be added to the standard library; given a string or a table, it will return the length of the object without calling any metamethods. -The new function has the previous behavior of `#` operator with the exception of not supporting userdata inputs, as userdata doesn't have an inherent definition of length. - -## Drawbacks - -`#` is an operator that is used frequently and as such an extra metatable check here may impact performance. However, `#` is usually called on tables without metatables, -and even when it is, using the existing metamethod-absence-caching approach we use for many other metamethods a test version of the change to support `__len` shows no -statistically significant difference on existing benchmark suite. This does complicate the `#` computation a little more which may affect JIT as well, but even if the -table doesn't have a metatable the process of computing `#` involves a series of condition checks and as such will likely require slow paths anyway. - -This is technically changing semantics of `#` when called on tables with an existing `__len` metamethod, and as such has a potential to change behavior of an existing valid program. -That said, it's unlikely that any table would have a metatable with `__len` metamethod as outside of userdata it would not anything, and this drawback is not feasible to resolve with any alternate version of the proposal. - -## Alternatives - -Do not implement `__len`. diff --git a/rfcs/lower-bounds-calculation.md b/rfcs/lower-bounds-calculation.md deleted file mode 100644 index 7208bf1a2..000000000 --- a/rfcs/lower-bounds-calculation.md +++ /dev/null @@ -1,219 +0,0 @@ -# Lower Bounds Calculation - -**Status**: Abandoned in favor of a future design for full local inference - -## Summary - -We propose adapting lower bounds calculation from Pierce's Local Type Inference paper into the Luau type inference algorithm. - -https://www.cis.upenn.edu/~bcpierce/papers/lti-toplas.pdf - -## Motivation - -There are a number of important scenarios that occur where Luau cannot infer a sensible type without annotations. - -Many of these revolve around type variables that occur in contravariant positions. - -### Function Return Types - -A very common thing to write in Luau is a function to try to find something in some data structure. These functions habitually return the relevant datum when it is successfully found, or `nil` in the case that it cannot. For instance: - -```lua --- A.lua -function find_first_if(vec, f) - for i, e in ipairs(vec) do - if f(e) then - return i - end - end - - return nil -end -``` - -This function has two `return` statements: One returns `number` and the other `nil`. Today, Luau flags this as an error. We ask authors to add a return annotation to make this error go away. - -We would like to automatically infer `find_first_if : ({T}, (T) -> boolean) -> number?`. - -Higher order functions also present a similar problem. - -```lua --- B.lua -function foo(f) - f(5) - f("string") -end -``` - -There is nothing wrong with the implementation of `foo` here, but Luau fails to typecheck it all the same because `f` is used in an inconsistent way. This too can be worked around by introducing a type annotation for `f`. - -The fact that the return type of `f` is never used confounds things a little, but for now it would be a big improvement if we inferred `f : ((number | string) -> T...) -> ()`. - -## Design - -We introduce a new kind of TypeVar, `ConstrainedTypeVar` to represent a TypeVar whose lower bounds are known. We will never expose syntax for a user to write these types: They only temporarily exist as type inference is being performed. - -When unifying some type with a `ConstrainedTypeVar` we _broaden_ the set of constraints that can be placed upon it. - -It may help to realize that what we have been doing up until now has been _upper bounds calculation_. - -When we `quantify` a function, we will _normalize_ each type and convert each `ConstrainedTypeVar` into a `UnionTypeVar`. - -### Normalization - -When computing lower bounds, we need to have some process by which we reduce types down to a minimal shape and canonicalize them, if only to have a clean way to flush out degenerate unions like `A | A`. Normalization is about reducing union and intersection types to a minimal, canonicalizable shape. - -A normalized union is one where there do not exist two branches on the union where one is a subtype of the other. It is quite straightforward to implement. - -A normalized intersection is a little bit more complicated: - -1. The tables of an intersection are always combined into a single table. Coincident properties are merged into intersections of their own. - * eg `normalize({x: number, y: string} & {y: number, z: number}) == {x: number, y: string & number, z: number}` - * This is recursive. eg `normalize({x: {y: number}} & {x: {y: string}}) == {x: {y: number & string}}` -1. If two functions in the intersection have a subtyping relationship, the normalization results only in the super-type-most function. (more on function subtyping later) - -### Function subtyping relationships - -If we are going to infer intersections of functions, then we need to be very careful about keeping combinatorics under control. We therefore need to be very deliberate about what subtyping rules we have for functions of differing arity. We have some important requirements: - -* We'd like some way to canonicalize intersections of functions, and yet -* optional function arguments are a great feature that we don't want to break - -A very important use case for us is the case where the user is providing a callback to some higher-order function, and that function will be invoked with extra arguments that the original customer doesn't actually care about. For example: - -```lua --- C.lua -function map_array(arr, f) - local result = {} - for i, e in ipairs(arr) do - table.insert(result, f(e, i, arr)) - end - return result -end - -local example = {1, 2, 3, 4} -local example_result = map_array(example, function(i) return i * 2 end) -``` - -This function mirrors the actual `Array.map` function in JavaScript. It is very frequent for users of this function to provide a lambda that only accepts one argument. It would be annoying for callers to be forced to provide a lambda that accepts two unused arguments. This obviously becomes even worse if the function later changes to provide yet more optional information to the callback. - -This use case is very important for Roblox, as we have many APIs that accept callbacks. Implementors of those callbacks frequently omit arguments that they don't care about. - -Here is an example straight out of the Roblox developer documentation. ([full example here](https://developer.roblox.com/en-us/api-reference/event/BasePart/Touched)) - -```lua --- D.lua -local part = script.Parent - -local function blink() - -- ... -end - -part.Touched:Connect(blink) -``` - -The `Touched` event actually passes a single argument: the part that touched the `Instance` in question. In this example, it is omitted from the callback handler. - -We therefore want _oversaturation_ of a function to be allowed, but this combines with optional function arguments to create a problem with soundness. Consider the following: - -```lua --- E.lua -type Callback = (Instance) -> () - -local cb: Callback -function register_callback(c: Callback) - cb = c -end - -function invoke_callback(i: Instance) - cb(i) -end - ---- - -function bad_callback(x: number?) -end - -local obscured: () -> () = bad_callback - -register_callback(obscured) - -function good_callback() -end - -register_callback(good_callback) -``` - -The problem we run into is, if we allow the subtyping rule `(T?) -> () <: () -> ()` and also allow oversaturation of a function, it becomes easy to obscure an argument type and pass the wrong type of value to it. - -Next, consider the following type alias - -```lua --- F.lua -type OldFunctionType = (any, any) -> any -type NewFunctionType = (any) -> any -type FunctionType = OldFunctionType & NewFunctionType -``` - -If we have a subtyping rule `(T0..TN) <: (T0..TN-1)` to permit the function subtyping relationship `(T0..TN-1) -> R <: (T0..TN) -> R`, then the above type alias normalizes to `(any) -> any`. In order to call the two-argument variation, we would need to permit oversaturation, which runs afoul of the soundness hole from the previous example. - -We need a solution here. - -To resolve this, let's reframe things in simpler terms: - -If there is never a subtyping relationship between packs of different length, then we don't have any soundness issues, but we find ourselves unable to register `good_callback`. - -To resolve _that_, consider that we are in truth being a bit hasty when we say `good_callback : () -> ()`. We can pass any number of arguments to this function safely. We could choose to type `good_callback : () -> () & (any) -> () & (any, any) -> () & ...`. Luau already has syntax for this particular sort of infinite intersection: `good_callback : (any...) -> ()`. - -So, we propose some different inference rules for functions: - -1. The AST fragment `function(arg0..argN) ... end` is typed `(T0..TN, any...) -> R` where `arg0..argN : T0..TN` and `R` is the inferred return type of the function body. Function statements are inferred the same way. -1. Type annotations are unchanged. `() -> ()` is still a nullary function. - -For reference, the subtyping rules for unions and functions are unchanged. We include them here for clarity. - -1. `A <: A | B` -1. `B <: A | B` -1. `A | B <: T` if `A <: T` or `B <: T` -1. `T -> R <: U -> S` if `U <: T` and `R <: S` - -We propose new subtyping rules for type packs: - -1. `(T0..TN) <: (U0..UN)` if, for each `T` and `U`, `T <: U` -1. `(U...)` is the same as `() | (U) | (U, U) | (U, U, U) | ...`, therefore -1. `(T0..TN) <: (U...)` if for each `T`, `T <: U`, therefore -1. `(U...) -> R <: (T0..TN) -> R` if for each `T`, `T <: U` - -The important difference is that we remove all subtyping rules that mention options. Functions of different arities are no longer considered subtypes of one another. Optional function arguments are still allowed, but function as a feature of function calls. - -Under these rules, functions of different arities can never be converted to one another, but actual functions are known to be safe to oversaturate with anything, and so gain a type that says so. - -Under these subtyping rules, snippets `C.lua` and `D.lua`, check the way we want: literal functions are implicitly safe to oversaturate, so it is fine to cast them as the necessary callback function type. - -`E.lua` also typechecks the way we need it to: `(Instance) -> () ()` and so `obscured` cannot receive the value `bad_callback`, which prevents it from being passed to `register_callback`. However, `good_callback : (any...) -> ()` and `(any...) -> () <: (Instance) -> ()` and so it is safe to register `good_callback`. - -Snippet `F.lua` is also fixed with this ruleset: There is no subtyping relationship between `(any) -> ()` and `(any, any) -> ()`, so the intersection is not combined under normalization. - -This works, but itself creates some small problems that we need to resolve: - -First, the `...` symbol still needs to be unavailable for functions that have been given this implicit `...any` type. This is actually taken care of in the Luau parser, so no code change is required. - -Secondly, we do not want to silently allow oversaturation of direct calls to a function if we know that the arguments will be ignored. We need to treat these variadic packs differently when unifying for function calls. - -Thirdly, we don't want to display this variadic in the signature if the author doesn't expect to see it. - -We solve these issues by adding a property `bool VariadicTypePack::hidden` to the implementation and switching on it in the above scenarios. The implementation is relatively straightforward for all 3 cases. - -## Drawbacks - -There is a potential cause for concern that we will be inferring unions of functions in cases where we previously did not. Unions are known to be potential sources of performance issues. One possibility is to allow Luau to be less intelligent and have it "give up" and produce less precise types. This would come at the cost of accuracy and soundness. - -If we allow functions to be oversaturated, we are going to miss out on opportunities to warn the user about legitimate problems with their program. I think we will have to work out some kind of special logic to detect when we are oversaturating a function whose exact definition is known and warn on that. - -Allowing indirect function calls to be oversaturated with `nil` values only should be safe, but a little bit unfortunate. As long as we statically know for certain that `nil` is actually a permissible value for that argument position, it should be safe. - -## Alternatives - -If we are willing to sacrifice soundness, we could adopt success typing and come up with an inference algorithm that produces less precise type information. - -We could also technically choose to do nothing, but this has some unpalatable consequences: Something I would like to do in the near future is to have the inference algorithm assume the same `self` type for all methods of a table. This will make inference of common OO patterns dramatically more intuitive and ergonomic, but inference of polymorphic methods requires some kind of lower bounds calculation to work correctly. diff --git a/rfcs/never-and-unknown-types.md b/rfcs/never-and-unknown-types.md deleted file mode 100644 index 5ad216ef0..000000000 --- a/rfcs/never-and-unknown-types.md +++ /dev/null @@ -1,146 +0,0 @@ -# never and unknown types - -**Status**: Implemented - -## Summary - -Add `unknown` and `never` types that are inhabited by everything and nothing respectively. - -## Motivation - -There are lots of cases in local type inference, semantic subtyping, -and type normalization, where it would be useful to have top and -bottom types. Currently, `any` is filling that role, but it has -special "switch off the type system" superpowers. - -Any use of `unknown` must be narrowed by type refinements unless another `unknown` or `any` is expected. For -example a function which can return any value is: - -```lua - function anything() : unknown ... end -``` - -and can be used as: - -```lua - local x = anything() - if type(x) == "number" then - print(x + 1) - end -``` - -The type of this function cannot be given concisely in current -Luau. The nearest equivalent is `any`, but this switches off the type system, for example -if the type of `anything` is `() -> any` then the following code typechecks: - -```lua - local x = anything() - print(x + 1) -``` - -This is fine in nonstrict mode, but strict mode should flag this as an error. - -The `never` type comes up whenever type inference infers incompatible types for a variable, for example - -```lua - function oops(x) - print("hi " .. x) -- constrains x must be a string - print(math.abs(x)) -- constrains x must be a number - end -``` - -The most general type of `x` is `string & number`, so this code gives -a type error, but we still need to provide a type for `oops`. With a -`never` type, we can infer the type `oops : (never) -> ()`. - -or when exhaustive type casing is achieved: - -```lua - function f(x: string | number) - if type(x) == "string" then - -- x : string - elseif type(x) == "number" then - -- x : number - else - -- x : never - end - end -``` - -or even when the type casing is simply nonsensical: - -```lua - function f(x: string | number) - if type(x) == "string" and type(x) == "number" then - -- x : string & number which is never - end - end -``` - -The `never` type is also useful in cases such as tagged unions where -some of the cases are impossible. For example: - -```lua - type Result = { err: false, val: T } | { err: true, err: E } -``` - -For code which we know is successful, we would like to be able to -indicate that the error case is impossible. With a `never` type, we -can do this with `Result`. Similarly, code which cannot succeed -has type `Result`. - -These types can _almost_ be defined in current Luau, but only quite verbosely: - -```lua - type never = number & string - type unknown = nil | number | boolean | string | {} | (...never) -> (...unknown) -``` - -But even for `unknown` it is impossible to include every single data types, e.g. every root class. - -Providing `never` and `unknown` as built-in types makes the code for -type inference simpler, for example we have a way to present a union -type with no options (as `never`). Otherwise we have to contend with ad hoc -corner cases. - -## Design - -Add: - -* a type `never`, inhabited by nothing, and -* a type `unknown`, inhabited by everything. - -And under success types (nonstrict mode), `unknown` is exactly equivalent to `any` because `unknown` -encompasses everything as does `any`. - -The interesting thing is that `() -> (never, string)` is equivalent to `() -> never` because all -values in a pack must be inhabitable in order for the pack itself to also be inhabitable. In fact, -the type `() -> never` is not completely accurate, it should be `() -> (never, ...never)` to avoid -cascading type errors. Ditto for when an expression list `f(), g()` where the resulting type pack is -`(never, string, number)` is still the same as `(never, ...never)`. - -```lua - function f(): never error() end - function g(): string return "" end - - -- no cascading type error where count mismatches, because the expression list f(), g() - -- was made to return (never, ...never) due to the presence of a never type in the pack - local x, y, z = f(), g() - -- x : never - -- y : never - -- z : never -``` - -## Drawbacks - -Another bit of complexity budget spent. - -These types will be visible to creators, so yay bikeshedding! - -Replacing `any` with `unknown` is a breaking change: code in strict mode may now produce errors. - -## Alternatives - -Stick with the current use of `any` for these cases. - -Make `never` and `unknown` type aliases rather than built-ins. diff --git a/rfcs/property-readonly.md b/rfcs/property-readonly.md deleted file mode 100644 index 6d09212d1..000000000 --- a/rfcs/property-readonly.md +++ /dev/null @@ -1,148 +0,0 @@ -# Read-only properties - -## Summary - -Allow properties of classes and tables to be inferred as read-only. - -## Motivation - -Currently, Roblox APIs have read-only properties of classes, but our -type system does not track this. As a result, users can write (and -indeed due to autocomplete, an encouraged to write) programs with -run-time errors. - -In addition, user code may have properties (such as methods) -that are expected to be used without modification. Currently there is -no way for user code to indicate this, even if it has explicit type -annotations. - -It is very common for functions to only require read access to a parameter, -and this can be inferred during type inference. - -## Design - -### Properties - -Add a modifier to table properties indicating that they are read-only. - -This proposal is not about syntax, but it will be useful for examples to have some. Write: - -* `get p: T` for a read-only property of type `T`. - -For example: -```lua -function f(t) - t.p = 1 + t.p + t.q -end -``` -has inferred type: -``` -f: (t: { p: number, get q: number }) -> () -``` -indicating that `p` is used read-write but `q` is used read-only. - -### Subtyping - -Read-only properties are covariant: - -* If `T` is a subtype of `U` then `{ get p: T }` is a subtype of `{ get p: U }`. - -Read-write properties are a subtype of read-only properties: - -* If `T` is a subtype of `U` then `{ p: T }` is a subtype of `{ get p: U }`. - -### Indexers - -Indexers can be marked read-only just like properties. In -particular, this means there are read-only arrays `{get T}`, that are -covariant, so we have a solution to the "covariant array problem": - -```lua -local dogs: {Dog} -function f(a: {get Animal}) ... end -f(dogs) -``` - -It is sound to allow this program, since `f` only needs read access to -the array, and `{Dog}` is a subtype of `{get Dog}`, which is a subtype -of `{get Animal}`. This would not be sound if `f` had write access, -for example `function f(a: {Animal}) a[1] = Cat.new() end`. - -### Functions - -Functions are not normally mutated after they are initialized, so -```lua -local t = {} -function t.f() ... end -function t:m() ... end -``` - -should have type -``` -t : { - get f : () -> (), - get m : (self) -> () -} -``` - -If developers want a mutable function, -they can use the anonymous function version -```lua -t.g = function() ... end -``` - -For example, if we define: -```lua - type RWFactory = { build : () -> A } -``` - -then we do *not* have that `RWFactory` is a subtype of `RWFactory` -since the build method is read-write, so users can update it: -```lua - local mkdog : RWFactory = { build = Dog.new } - local mkanimal : RWFactory = mkdog -- Does not typecheck - mkanimal.build = Cat.new -- Assigning to methods is OK for RWFactory - local fido : Dog = mkdog.build() -- Oh dear, fido is a Cat at runtime -``` - -but if we define: -```lua - type ROFactory = { get build : () -> A } -``` - -then we do have that `ROFactory` is a subtype of `ROFactory` -since the build method is read-write, so users can update it: -```lua - local mkdog : ROFactory = { build = Dog.new } - local mkanimal : ROFactory = mkdog -- Typechecks now! - mkanimal.build = Cat.new -- Fails to typecheck, since build is read-only -``` - -Since most idiomatic Lua does not update methods after they are -initialized, it seems sensible for the default access for methods should -be read-only. - -*This is a possibly breaking change.* - -### Classes - -Classes can also have read-only properties and accessors. - -Methods in classes should be read-only by default. - -Many of the Roblox APIs an be marked as having getters but not -setters, which will improve accuracy of type checking for Roblox APIs. - -## Drawbacks - -This is adding to the complexity budget for users, -who will be faced with inferred get modifiers on many properties. - -## Alternatives - -Rather than making read-write access the default, we could make read-only the -default and add a new modifier for read-write. This is not backwards compatible. - -We could continue with read-write access to methods, -which means no breaking changes, but means that users may be faced with type -errors such as "`Factory` is not a subtype of `Factory`". diff --git a/rfcs/property-writeonly.md b/rfcs/property-writeonly.md deleted file mode 100644 index 1a49c26b5..000000000 --- a/rfcs/property-writeonly.md +++ /dev/null @@ -1,179 +0,0 @@ -# Write-only properties - -## Summary - -Allow properties of classes and tables to be inferred as write-only. - -## Motivation - -This RFC is a follow-on to supporting read-only properties. - -Read-only properties have many obvious use-cases, but write-only properties -are more technical. - -The reason for wanting write-only properties is that it means -that we can infer a most specific type for functions, which we can't do if -we only have read-write and read-only properties. - -For example, consider the function -```lua - function f(t) t.p = Dog.new() end -``` - -The obvious type for this is -```lua - f : ({ p: Dog }) -> () -``` - -but this is not the most specific type, since read-write properties -are invariant, We could have inferred `f : ({ p: Animal }) -> ()`. -These types are incomparable (neither is a subtype of the other) -and there are uses of `f` that fail to typecheck depending which one choose. - -If `f : ({ p: Dog }) -> ()` then -```lua - local x : { p : Animal } = { p = Cat.new() } - f(x) -- Fails to typecheck -``` - -If `f : ({ p: Animal }) -> ()` then -```lua - local x : { p : Dog } = { p = Dog.new() } - f(x) -- Fails to typecheck -``` - -The reason for these failures is that neither of these is the most -specific type. It is one which includes that `t.p` is written to, and -not read from. -```lua - f : ({ set p: Dog }) -> () -``` - -This allows both example uses of `f` to typecheck. To see that it is more specific than `({ p: Animal }) -> ()`: - -* `Dog` is a subtype of `Animal` -* so (since write-only properties are contravariant) `{ set p: Dog }` is a supertype of `{ set p: Animal }` -* and (since read-write properties are a subtype of write-only properties) `{ set p: Animal }` is a supertype of `{ p: Animal }` -* so (by transitivity) `{ set p: Dog }` is a supertype of `{ set p: Animal }` is a supertype of `{ p: Animal }` -* so (since function arguments are contravariant `({ set p: Dog }) -> ()` is a subtype of `({ p: Animal }) -> ()` - -and similarly `({ set p: Dog }) -> ()` is a subtype of `({ p: Dog }) -> ()`. - -Local type inference depends on the existence of most specific (and most general) types, -so if we want to use it "off the shelf" we will need write-only properties. - -There are also some security reasons why properties should be -write-only. If `t` is a shared table, and any security domain can -write to `t.p`, then it may be possible to use this as a back-channel -if `t.p` is readable. If there is a dynamic check that a property is -write-only then we may wish to present a script analysis error if a -user tries reading it. - -## Design - -### Properties - -Add a modifier to table properties indicating that they are write-only. - -This proposal is not about syntax, but it will be useful for examples to have some. Write: - -* `set p: T` for a write-only property of type `T`. - -For example: -```lua -function f(t) - t.p = 1 + t.q -end -``` -has inferred type: -``` -f: (t: { set p: number, get q: number }) -> () -``` -indicating that `p` is used write-only but `q` is used read-only. - -### Adding read-only and write-only properties - -There are various points where type inference adds properties to types, we now have to consider how to treat each of these. - -When reading a property from a free table, we should add a read-only -property if there is no such property already. If there is already a -write-only property, we should make it read-write. - -When writing a property to a free table, we should add a write-only -property if there is no such property already. If there is already a -read-only property, we should make it read-write. - -When writing a property to an unsealed table, we should add a read-write -property if there is no such property already. - -When declaring a method in a table or class, we should add a read-only property for the method. - -### Subtyping - -Write-only properties are contravariant: - -* If `T` is a subtype of `U` then `{ set p: U }` is a subtype of `{ set p: T }`. - -Read-write properties are a subtype of write-only properties: - -* If `T` is a subtype of `U` then `{ p: U }` is a subtype of `{ set p: T }`. - -### Indexers - -Indexers can be marked write-only just like properties. In -particular, this means there are write-only arrays `{set T}`, that are -contravariant. These are sometimes useful, for example: - -```lua -function move(src, tgt) - for i,v in ipairs(src) do - tgt[i] = src[i] - src[i] = nil - end -end -``` - -we can give this function the type -``` - move: ({a},{set a}) -> () -``` - -and since write-only arrays are contravariant, we can call this with differently-typed -arrays: -```lua - local dogs : {Dog} = {fido,rover} - local animals : {Animal} = {tweety,sylvester} - move (dogs,animals) -``` - -This program does not type-check with read-write arrays. - -### Classes - -Classes can also have write-only properties and indexers. - -Some Roblox APIs which manipulate callbacks are write-only for security reasons. - -### Separate read and write types - -Once we have read-only properties and write-only properties, type intersection -gives read-write properties with different types. - -```lua - { get p: T } & { set p : U } -``` - -If we infer such types, we may wish to present them differently, for -example TypeScript allows both a getter and a setter. - -## Drawbacks - -This is adding to the complexity budget for users, who will be faced -with inferred set modifiers on many properties. There is a trade-off -here about how to spend the user's complexity budget: on understanding -inferred types with write-only properties, or debugging false positive -type errors caused by variance issues). - -## Alternatives - -Just stick with read-only and read-write accesses. diff --git a/rfcs/recursive-type-restriction.md b/rfcs/recursive-type-restriction.md deleted file mode 100644 index 6f69d43a0..000000000 --- a/rfcs/recursive-type-restriction.md +++ /dev/null @@ -1,65 +0,0 @@ -# Recursive type restriction - -**Status**: Implemented - -## Summary - -Restrict generic type aliases to only be able to refer to the exact same instantiation of the generic that's being declared. - -## Motivation - -Luau supports recursive type aliases, but with an important restriction: -users can declare functions of recursive types, such as: -```lua - type Tree = { data: a, children: {Tree} } -``` -but *not* recursive type functions, such as: -```lua - type Weird = { data: a, children: Weird<{a}> } -``` -If types such as `Weird` were allowed, they would have infinite unfoldings for example: -```lua - Weird = { data: number, children: Weird<{number}> }` - Weird<{number}> = { data: {number}, children: Weird<{{number}}> } - Weird<{{number}}> = { data: {{number}}, children: Weird<{{{number}}}> } - ... -``` - -Currently Luau has this restriction, but does not enforce it, and instead -produces unexpected types, which can result in free types leaking into -the module exports. - -## Design - -To enforce the restriction that recursive types aliases produce functions of -recursive types, we require that in any recursive type alias defining `T`, -in any recursive use of `T`, we have that `gs` and `Us` are equal. - -This allows types such as: -```lua - type Tree = { data: a, children: {Tree} } -``` -but *not*: -```lua - type Weird = { data: a, children: Weird<{a}> } -``` -since in the recursive use `a` is not equal to `{a}`. - -This restriction applies to mutually recursive types too. - -## Drawbacks - -This restriction bans some type declarations which do not produce infinite unfoldings, -such as: -```lua - type WeirdButFinite = { data: a, children: WeirdButFinite } -``` -This restriction is stricter than TypeScript, which allows programs such as: -```typescript -interface Foo { x: Foo[]; y: a; } -let x: Foo = { x: [], y: 37 } -``` - -## Alternatives - -We could adopt a solution more like TypeScript's, which is to lazily rather than eagerly instantiate types. diff --git a/rfcs/sealed-table-subtyping.md b/rfcs/sealed-table-subtyping.md deleted file mode 100644 index 73714909b..000000000 --- a/rfcs/sealed-table-subtyping.md +++ /dev/null @@ -1,106 +0,0 @@ -# Sealed table subtyping - -**Status**: Implemented - -## Summary - -In Luau, tables have a state, which can, among others, be "sealed". A sealed table is one that we know the full shape of and cannot have new properties added to it. We would like to introduce subtyping for sealed tables, to allow users to express some subtyping relationships that they currently cannot. - -## Motivation - -We would like this code to type check: -```lua -type Interface = { - name: string, -} - -type Concrete = { - name: string, - id: number, -} - -local x: Concrete = { - name = "foo", - id = 123, -} - -local function getImplementation(): Interface - return x -end -``` -Right now this code fails to type check, because `x` contains an extra property, `id`. Allowing sealed tables to be subtypes of other sealed tables would permit this code to type check successfully. - -## Design - -In order to do this, we will make sealed tables act as a subtype of other sealed tables if they contain all the properties of the supertype. - -``` -type A = { - name: string, -} - -type B = { - name: string, - id: number, -} - -type C = { - id: number, -} - -local b: B = { - name = "foo", - id = 123, -} - --- works: B is a subtype of A -local a: A = b - --- works: B is a subtype of C -local c: C = b - --- fails: A is not a subtype of C -local a2: A = c -``` - -This change affects existing code, but it should be a strictly more permissive change - it won't break any existing code, but it will allow code that was previously denied before. - -## Drawbacks - -This change will mean that sealed tables that don't exactly match may be permitted. In the past, this was an error; users may be relying on the type checker to perform these checks. We think the risk of this is minimal, as the presence of extra properties is unlikely to break user code. This is an example of code that would have raised a type error before: - -```lua -type A = { - name: string, -} - -local a: A = { - name = "foo", - -- Before, we would have raised a type error here for the presence of the - -- extra property `id`. - id = 123, -} -``` - -## Alternatives - -In order to avoid any chance of breaking backwards-compatibility, we could introduce a new state for tables, "interface" or something similar, that can only be produced via new syntax. This state would act like a sealed table, except with the addition of the subtyping rule described in this RFC. An example syntax for this: - -```lua --- `interface` context-sensitive keyword denotes an interface table -type A = interface { - name: string, -} - -type B = { - name: string, - id: number, -} - -local b: B = { - name = "foo", - id = 123, -} - -local a: A = b -``` diff --git a/rfcs/syntax-array-like-table-types.md b/rfcs/syntax-array-like-table-types.md deleted file mode 100644 index 486f7debb..000000000 --- a/rfcs/syntax-array-like-table-types.md +++ /dev/null @@ -1,65 +0,0 @@ -# Array-like table types - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Add special syntax for array-like table types, `{ T }` - -## Motivation - -Luau supports annotating table types. Tables are quite complex beasts, acting as essentially an associative container mapping any value to any other value, and to make it possible to reason about them at type level we have a more constrained definition of what a table is: - -- A table can contain a set of string keys with a specific type for each key -- A table can additionally have an "indexer" for a given key/value type, meaning that it acts as an associative container mapping keys of type K to values of type V - -The syntax for this right now looks like this: - -``` -{ key1: Type1, key2: Type2, [KeyType]: ValueType } -``` - -This is an example of a hybrid table that has both an indexer and a list of specific key/value pairs. - -While Luau technically doesn't support arrays, canonically tables with integer keys are called arrays, or, more precisely, array-like tables. Luau way to specify these is to use an indexer with a number key: - -``` -{ [number]: ValueType } -``` - -(note that this permits use of non-integer keys, so it's technically richer than an array). - -As the use of arrays is very common - for example, many library functions such as `table.insert`, `table.find`, `ipairs`, work on array-like tables - Luau users who want to type-annotate their code have to use array-like table annotations a lot. - -`{ [number]: Type }` is verbose, and the only alternative is to provide a slightly shorter generic syntax: - -``` -type Array = { [number]: T } -``` - -... but this is necessary to specify in every single script, as we don't support preludes. - -## Design - -This proposal suggests adding syntactic sugar to make this less cumbersome: - -``` -{T} -``` - -This will be exactly equivalent to `{ [number]: T }`. `T` must be a type definition immediately followed by `}` (ignoring whitespace characters of course) - -Conveniently, `{T}` syntax matches the syntax for arrays in Typed Lua (a research project from 2014) and Teal (a recent initiative for a TypeScript-like Lua extension language from 2020). - -## Drawbacks - -This introduces a potential ambiguity wrt a tuple-like table syntax; to represent a table with two values, number and a string, it's natural to use syntax `{ number, string }`; however, how would you represent a table with just one value of type number? This may seem concerning but can be resolved by requiring a trailing comma for one-tuple table type in the future, so `{ number, }` would mean "a table with one number", vs `{ number }` which means "an array-like table of numbers". - -## Alternatives - -A different syntax along the lines of `[T]` or `T[]` was considered and rejected in favor of the current syntax: - -a) This allows us to, in the future - if we find a good workaround for b - introduce "real" arrays with a distinct runtime representation, maybe even starting at 0! (whether we do this or not is uncertain and outside of scope of this proposal) -b) Square brackets don't nest nicely due to Lua lexing rules, where [[foo]] is a string literal "foo", so with either syntax with square brackets array-of-arrays is not easy to specify diff --git a/rfcs/syntax-compound-assignment.md b/rfcs/syntax-compound-assignment.md deleted file mode 100644 index 6ab97f6a0..000000000 --- a/rfcs/syntax-compound-assignment.md +++ /dev/null @@ -1,49 +0,0 @@ -# Compound assignment using `op=` syntax - -> Note: this RFC was adapted from an internal proposal that predates RFC process and as such doesn't follow the template precisely - -**Status**: Implemented - -## Design - -A feature present in many many programming languages is assignment operators that perform operations on the left hand side, for example - -``` -a += b -``` - -Lua doesn't provide this right now, so it requires code that's more verbose, for example - -``` -data[index].cost = data[index].cost + 1 -``` - -This proposal suggests adding `+=`, `-=`, `*=`, `/=`, `%=`, `^=` and `..=` operators to remedy this. This improves the ergonomics of writing code, and occasionally results in code that is easier to read to also be faster to execute. - -The semantics of the operators is going to be as follows: - -- Only one value can be on the left and right hand side -- The left hand side is evaluated once as an l-value, similarly to the left hand side of an assignment operator -- The right hand side is evaluated as an r-value (which results in a single Lua value) -- The assignment-modification is performed, which can involve table access if the left hand side is a table dereference -- Unlike C++, these are *assignment statements*, not expressions - code like this `a = (b += 1)` is invalid. - -Crucially, this proposal does *not* introduce new metamethods, and instead uses the existing metamethods and table access semantics, for example - -``` -data[index].cost += 1 -``` - -translates to - -``` -local table = data[index] -local key = "cost" -table[key] = table[key] + 1 -``` - -Which can invoke `__index` and `__newindex` on table as necessary, as well as `__add` on the element. In this specific example, this is *faster* than `data[index].cost = data[index].cost + 1` because `data[index]` is only evaluated once, but in general the compound assignment is expected to have the same performance and the goal of this proposal is to make code easier and more pleasant to write. - -The proposed new operators are currently invalid in Lua source, and as such this is a backwards compatible change. - -From the implementation perspective, this requires adding new code/structure to AST but doesn't involve adding new opcodes, metatables, or any extra cost at runtime. diff --git a/rfcs/syntax-continue-statement.md b/rfcs/syntax-continue-statement.md deleted file mode 100644 index 94e2009a4..000000000 --- a/rfcs/syntax-continue-statement.md +++ /dev/null @@ -1,98 +0,0 @@ -# continue statement - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Add `continue` statement to `for`, `while` and `repeat` loops using a context-sensitive keyword to preserve compatibility. - -## Motivation - -`continue` statement is a feature present in basically all modern programming languages. It's great for ergonomics - often you want the loop to only process items of a specific kind, so you can say `if item.kind ~= "blah" then continue end` in the beginning of the loop. - -`continue` never makes code that was previously impossible to write possible, but it makes some code easier to write. - -We'd like to add this to Luau but we need to keep backwards compatibility - all existing scripts that parse correctly must parse as they do now. The rest of the proposal outlines the exact syntax and semantics that makes it possible. - -## Design - -`continue` statement shall be the statement that *starts* with "continue" identifier (*NOT* keyword - effectively it will be a context-sensitive keyword), and such that the *next* token is none of (`.`, `[`, `:`, `{`, `(`, `=`, string literal or ','). - -These rules effectively say that continue statement is the statement that *does not* parse as a function call or the beginning of an assignment statement. - -This is a continue statement: - -``` -do -continue -end -``` - -This is not a continue statement: - -``` -do -continue = 5 -end -``` - -This is not a continue statement: - -``` -do -continue(5) -end -``` - -This is not a continue statement either, why do you ask? - -``` -do -continue, foo = table.unpack(...) -end -``` - -These rules are simple to implement. In any Lua parser there is already a point where you have to disambiguate an identifier that starts an assignment statement (`foo = 5`) from an identifier that starts a function call (`foo(5)`). It's one of the few, if not the only, place in the Lua grammar where single token lookahead is not sufficient to parse Lua, because you could have `foo.bar(5)` or `foo.bar=5` or `foo.bar(5)[6] = 7`. - -Because of this, we need to parse the entire left hand side of an assignment statement (primaryexp in Lua's BNF) and then check if it was a function call; if not, we'd expect it to be an assignment statement. - -Alternatively in this specific case we could parse "continue", parse the next token, and if it's one of the exclusion list above, roll the parser state back and re-parse the non-continue statement. Our lexer currently doesn't support rollbacks but it's also an easy strategy that other implementations might employ for `continue` specifically. - -The rules make it so that the only time we interpret `continue` as a continuation statement is when in the old Lua the program would not have compiled correctly - because this is not valid Lua 5.x: - -``` -do -continue -end -``` - -There is one case where this can create new confusion in the newly written code - code like this: - -``` -do -continue -(foo())(5) -end -``` - -could be interpreted both as a function call to `continue` (which it is!) and as a continuation statement followed by a function call (which it is not!). Programmers writing this code might expect the second treatment which is wrong. - -We have an existing linter rule to prevent this, however *for now* we will solve this in a stronger way: - -Once we parse `continue`, we will treat this as a block terminator - similarly to `break`/`return`, we will expect the block to end and the next statement will have to be `end`. This will make sure there's no ambiguity. We may relax this later and rely on the linter to tell people when the code is wrong. - -Semantically, continue will work as you would expect - it would skip the rest of the loop body, evaluate the condition for loop continuation (e.g. check the counter value for numeric loops, call the loop iterator for generic loops, evaluate while/repeat condition for while/repeat loops) and proceed accordingly. Locals declared in the loop body would be closed as well. - -One special case is the `until` expression: since it has access to the entire scope of `repeat` statement, using `continue` is invalid when it would result in `until` expression accessing local variables that are declared after `continue`. - -## Drawbacks - -Adding `continue` requires a context-sensitive keyword; this makes editor integration such as syntax highlighting more challenging, as you can't simply assume any occurrence of the word `continue` is referring to the statement - this is different from `break`. - -Implementing `continue` requires special care for `until` statement as highlighted in the design, which may make compiler slower and more complicated. - -## Alternatives - -In later versions of Lua, instead of `continue` you can use `goto`. However, that changes control flow to be unstructured and requires more complex implementation and syntactic changes. diff --git a/rfcs/syntax-default-type-alias-type-parameters.md b/rfcs/syntax-default-type-alias-type-parameters.md deleted file mode 100644 index 443bbac3b..000000000 --- a/rfcs/syntax-default-type-alias-type-parameters.md +++ /dev/null @@ -1,97 +0,0 @@ -# Default type alias type parameters - -**Status**: Implemented - -## Summary - -Introduce syntax to provide default type values inside the type alias type parameter list. - -## Motivation - -Luau has support for type parameters for type aliases and functions. -In languages with similar features like C++, Rust, Flow and TypeScript, it is possible to specify default values for looser coupling and easier composability, and users with experience in those languages would like to have these design capabilities in Luau. - -Here is an example that is coming up frequently during development of GraphQL Luau library: -```lua -export type GraphQLFieldResolver< - TSource, - TContext, - TArgs = { [string]: any } -> = (TSource, TArgs, TContext, GraphQLResolveInfo) -> any -``` -If we could specify defaults like that, we won't have to write long type names when type alias is used unless specific customization is required. -Some engineers already skip these extra arguments and use `'any'` to save time, which gives worse typechecking quality. - -Without default parameter values it's also harder to refactor the code as each type alias reference that uses 'common' type arguments has to be updated. - -While previous example uses a concrete type for default type value, it should also be possible to reference generic types from the same list: -```lua -type Eq = (l: T, r: U) -> boolean - -local a: Eq = ... -local b: Eq = ... -``` - -Generic functions in Luau also have a type parameter list, but it's not possible to specify type arguments at the call site and because of that, default type parameter values for generic functions are not proposed. - -## Design - -If a default type parameter value is assigned, following type parameters (on the right) must also have default type parameter values. -```lua -type A = ... -- not allowed -``` - -Default type parameter values can reference type parameters which were defined earlier (to the left): -```lua -type A = ...-- ok - -type A = ... -- not allowed -``` - -Default type parameter values are also allowed for type packs: -```lua -type A -- ok, variadic type pack -type B -- ok, type pack with no elements -type C -- ok, type pack with one element -type D -- ok, type pack with two elements -type E -- ok, variadic type pack with a different first element -type F -- ok, same type pack as T... -``` - ---- - -Syntax for type alias type parameter is extended as follows: - -```typeparameter ::= Name [`...'] [`=' typeannotation]``` - -Instead of storing a simple array of names in AstStatTypeAlias, we will store an array of structs containing the name and an optional default type value. - -When type alias is referenced, missing type parameters are replaced with default type values, if they are available. - -If all type parameters have a default type value, it is now possible to reference that without providing a type parameter list: -```lua -type All = ... - -local a: All -- ok -local b: All<> -- ok as well -``` - -If type is exported from a module, default type parameter values will still be available when module is imported. - ---- -Type annotations in Luau are placed after `':'`, but we use `'='` here to assign a type value, not to imply that the type parameter on the left has a certain type. - -Type annotation with `':'` could be used in the future for bounded quantification which is orthogonal to the default type value. - -## Drawbacks - -Other languages might allow references to the type alias without arguments inside the scope of that type alias to resolve into a recursive reference to the type alias with the same arguments. - -While that is not allowed in Luau right now, if we decide to change that in the future, we will have an ambiguity when all type alias parameters have default values: -```lua --- ok if we allow Type to mean Type -type Type = { x: number, b: Type? } - --- ambiguity, Type could mean Type or Type -type Type = { x: number, b: Type? } -``` diff --git a/rfcs/syntax-if-expression.md b/rfcs/syntax-if-expression.md deleted file mode 100644 index 76f76cf14..000000000 --- a/rfcs/syntax-if-expression.md +++ /dev/null @@ -1,108 +0,0 @@ -# if-then-else expression - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Introduce a form of ternary conditional using `if cond then value else alternative` syntax. - -## Motivation - -Luau does not have a first-class ternary operator; when a ternary operator is needed, it is usually emulated with `and/or` expression, such as `cond and value or alternative`. - -This expression evaluates to `value` if `cond` and `value` are truthy, and `alternative` otherwise. In particular it means that when `value` is `false` or `nil`, the result of the entire expression is `alternative` even when `cond` is truthy - which doesn't match the expected ternary logic and is a frequent source of subtle errors. - -Instead of `and/or`, `if/else` statement can be used but since that requires a separate mutable variable, this option isn't ergonomic. An immediately invoked function expression is also unergonomic and results in performance issues at runtime. - -## Design - -To solve these problems, we propose introducing a first-class ternary conditional. Instead of `? :` common in C-like languages, we propose an `if-then-else` expression form that is syntactically similar to `if-then-else` statement, but lacks terminating `end`. - -Concretely, the `if-then-else` expression must match `if then else `; it can also contain an arbitrary number of `elseif` clauses, like `if then elseif then else `. Unlike if statements, `else` is mandatory. - -The result of the expression is the then-expression when condition is truthy (not `nil` or `false`) and else-expression otherwise. Only one of the two possible resulting expressions is evaluated. - -Example: - -```lua -local x = if FFlagFoo then A else B - -MyComponent.validateProps = t.strictInterface({ - layoutOrder = t.optional(t.number), - newThing = if FFlagUseNewThing then t.whatever() else nil, -}) -``` - -Note that `else` is mandatory because it's always better to be explicit. If it weren't mandatory, it opens the possiblity that someone might be writing a chain of if-then-else and forgot to add in the final `else` that _doesn't_ return a `nil` value! Enforcing this syntactically ensures the program does not run. Also, with it being mandatory, it solves many cases where parsing the expression is ambiguous due to the infamous [dangling else](https://en.wikipedia.org/wiki/Dangling_else). - -This example will not do what it looks like it's supposed to do! The if expression will _successfully_ parse and be interpreted as to return `h()` if `g()` evaluates to some falsy value, when in actual fact the clear intention is to evaluate `h()` only if `f()` is falsy. - -```lua -if f() then - ... - local foo = if g() then x -else - h() - ... -end -``` - -The only way to solve this had we chose optional `else` branch would be to wrap the if expression in parentheses or to place a semi-colon. - -## Drawbacks - -Studio's script editor autocomplete currently adds an indented block followed by `end` whenever a line ends that includes a `then` token. This can make use of the if expression unpleasant as developers have to keep fixing the code by removing auto-inserted `end`. We can work around this on the editor side by (short-term) differentiating between whether `if` token is the first on its line, and (long-term) by refactoring completion engine to use infallible parser for the block completer. - -Parser recovery can also be more fragile due to leading `if` keyword - when `if` was encountered previously, it always meant an unfinished expression, but now it may start an `if-expr` that, when confused with `if-end` statement can lead to a substantially incorrect parse that is difficult to recover from. However, similar issues occur frequently due to function call statements and as such it's not clear that this makes the recovery materially worse. - -While this is not a problem today, in the past we've contemplated adding support for mid-block `return` statements; these would create an odd grammatical quirk where an `if..then` statement following an empty `return` would parse as an `if` expression. This would happen even without `if` expressions though for function calls (e.g. `return` followed by `print(1)`), and is more of a problem with the potential `return` statement changes and less of a problem with this proposal. - -## Alternatives - -We've evaluated many alternatives for the proposed syntax. - -### Python syntax -``` -b if a else c -``` -Undesirable because expression evaluation order is not left-to-right which is a departure from all other Lua expressions. Additionally, since `b` may be ending a statement (followed by `if` statement), resolving this ambiguity requires parsing `a` as expression and backtracking if `else` is not found, which is expensive and likely to introduce further ambiguities. - -### C-style ternary operator -``` -a ? b : c -``` -Problematic because `:` is used for method calls. In Julia `? :` and `:` are both operators which are disambiguated by _requiring_ spaces in the first case and _prohibiting_ them in the second case; this breaks backwards compatibility and doesn't match the rest of the language where whitespace in the syntax is not significant. - -### Function syntax -``` -iff(a, b, c) -``` -If implemented as a regular function call, this would break short-circuit behavior. If implemented as a special builtin, it would look like a regular function call but have magical behavior -- something likely to confuse developers. - -### Perl 6 syntax -``` -a ?? b !! c -``` -Syntax deemed too unconventional to use in Luau. - -### Smaller variations -``` -(if a then b else c) -``` -Ada uses this syntax (with parentheses required for clarity). Similar solutions were discussed for `as` previously and rejected to make it easier for humans and machines to understand the language syntax. - -``` -a then b else c -``` -This is ambiguous in some cases (like within if condition) so not feasible from a grammar perspective. - -``` -if a then b else c end -``` -The `end` here is unnecessary since `c` is not a block of statements -- it is simply an expression. Thus, use of `end` here would be inconsistent with its other uses in the language. It also makes the syntax more cumbersome to use and could lead to developers sticking with the error-prone `a and b or c` alternative. - -### `elseif` support - -We discussed a simpler version of this proposal without `elseif` support. Unlike if statements, here `elseif` is purely syntactic sugar as it's fully equivalent to `else if`. However, supporting `elseif` makes if expression more consistent with if statement - it is likely that developers familiar with Luau are going to try using `elseif` out of habit. Since supporting `elseif` here is trivial we decided to keep it for consistency. diff --git a/rfcs/syntax-named-function-type-args.md b/rfcs/syntax-named-function-type-args.md deleted file mode 100644 index 536e5606d..000000000 --- a/rfcs/syntax-named-function-type-args.md +++ /dev/null @@ -1,58 +0,0 @@ -# Named function type arguments - -**Status**: Implemented - -## Summary - -Introduce syntax for optional names of function type arguments. - -## Motivation - -This feature will be useful to improve code documentation and provide additional information to LSP clients. - -## Design - -This proposal uses the same syntax that functions use to name the arguments: `(a: number, b: string) -> string` - -Names can be provided in any place where function type is used, for example: - -* in type aliases: -``` -type MyFunc = (cost: number, name: string) -> string -``` - -* in definition files for table types: -``` -declare string: { - rep: (pattern: string, repeats: number) -> string, - sub: (string, start: number, end: number?) -> string -- names are optional, here the first argument doesn't use a name -} -``` - -* for variables: -``` -local cb: (amount: number) -> number -local function foo(cb: (name: string) -> ()) -``` - -Variadic arguments cannot have a name, they are already written as ...: number. - -This feature can be found in other languages: - -* TypeScript (names are required): `let func: (p: type) => any` -* C++: `void (*f)(int cost, std::string name) = nullptr;` - -Implementation will store the names inside the function type description. - -Parsing the argument list will require a single-token lookahead that we already support. -Argument list parser will check if current token is an identifier and if the lookahead token is a colon, in which case it will consume both tokens. - -Function type comparisons will ignore the argument names, this proposal doesn't change the semantics of the language and how typechecking is performed. - -## Drawbacks - -Argument names require that we create unique function types even when these types are 'identical', so we can't compare types using pointer identity. - -This is already the case in current Luau implementation, but it might reduce the optimization opportunities in the future. - -There might also be cases of pointer identity checks that are currently hidden and named arguments might expose places where correct unification is required in the type checker. diff --git a/rfcs/syntax-number-literals.md b/rfcs/syntax-number-literals.md deleted file mode 100644 index 2ad6a6faf..000000000 --- a/rfcs/syntax-number-literals.md +++ /dev/null @@ -1,14 +0,0 @@ -# Extended numeric literal syntax - -> Note: this RFC was adapted from an internal proposal that predates RFC process and as such doesn't follow the template precisely - -**Status**: Implemented - -## Design - -This proposal suggests extending Lua number syntax with: - -1. Binary literals: `0b10101010101`. The prefix is either '0b' or '0B' (to match Lua's '0x' and '0X'). Followed by at least one 1 or 0. -2. Number literal separators: `1_034_123`. We will allow an arbitrary number and arrangement of underscores in all numeric literals, including hexadecimal and binary. This helps with readability of long numbers. - -Both of these features are standard in all modern languages, and can help write readable code. diff --git a/rfcs/syntax-singleton-types.md b/rfcs/syntax-singleton-types.md deleted file mode 100644 index 2c1f54425..000000000 --- a/rfcs/syntax-singleton-types.md +++ /dev/null @@ -1,91 +0,0 @@ -# Singleton types - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Introduce a new kind of type variable, called singleton types. They are just like normal types but has the capability to represent a constant runtime value as a type. - -## Motivation - -There are two primary drivers to add two kinds of singleton types: `string` and `boolean`. - -### `string` singleton types - -Luau type checker can get by mostly fine without constant string types, but it can shine at its best in user code. - -One popular pattern are the abstract data types, which could be supported: - -``` -type Ok = { type: "ok", value: T } -type Err = { type: "error", error: E } -type Result = Ok | Err - -local result: Result = ... -if result.type == "ok" then - -- result :: Ok - print(result.value) -else - -- result :: Err - error(result.error) -end -``` - -### `boolean` singleton types - -At the moment, Luau type checker is completely unable to discern the state of a boolean whatsoever, which makes it impossible to determine all the possible types of the expression from any variations of `a and b`, `a and b or c`, or `a or b`. - -## Design - -Both design components of singleton types should be intuitive for everyone by default. - -### Syntax - -A constant string token as well as a constant boolean token is now allowed to show up in type annotation context. - -``` -type Animals = "Dog" | "Cat" | "Bird" -type TrueOrNil = true? -``` - -Adding constant strings as type means that it is now legal to write -`{["foo"]:T}` as a table type. This should be parsed as a property, -not an indexer. For example: -```lua - type T = { - ["foo"]: number, - ["$$bar"]: string, - baz: boolean, - } -``` -The table type `T` is a table with three properties and no indexer. - -### Semantics - -You are allowed to provide a constant value to the generic primitive type. - -```lua -local foo: "Hello world" = "Hello world" -local bar: string = foo -- allowed - -local foo: true = true -local bar: boolean = foo -- also allowed -``` - -The inverse is not true, because you're trying to narrow any values to a specific value. - -```lua -local foo: string = "Hello world" -local bar: "Hello world" = foo -- not allowed - -local foo: boolean = true -local bar: true = foo -- not allowed -``` - -## Drawbacks - -This may increase the cost of type checking - since some types now need to carry a string literal value, it may need to be copied and compared. The cost can be mitigated through interning although this is not very trivial due to cross-module type checking and the need to be able to typecheck a module graph incrementally. - -This may make the type system a bit more complex to understand, as many programmers have a mental model of types that doesn't include being able to use literal values as a type, and having that be a subtype of a more general value type. diff --git a/rfcs/syntax-string-interpolation.md b/rfcs/syntax-string-interpolation.md deleted file mode 100644 index 2fbb04b0b..000000000 --- a/rfcs/syntax-string-interpolation.md +++ /dev/null @@ -1,169 +0,0 @@ -# String interpolation - -## Summary - -New string interpolation syntax. - -## Motivation - -The problems with `string.format` are many. - -1. Must be exact about the types and its corresponding value. -2. Using `%d` is the idiomatic default for most people, but this loses precision. - * `%d` casts the number into `long long`, which has a lower max value than `double` and does not support decimals. - * `%f` by default will format to the millionths, e.g. `5.5` is `5.500000`. - * `%g` by default will format up to the hundred thousandths, e.g. `5.5` is `5.5` and `5.5312389` is `5.53123`. It will also convert the number to scientific notation when it encounters a number equal to or greater than 10^6. - * To not lose too much precision, you need to use `%s`, but even so the type checker assumes you actually wanted strings. -3. No support for `boolean`. You must use `%s` **and** call `tostring`. -4. No support for values implementing the `__tostring` metamethod. -5. Using `%` is in itself a dangerous operation within `string.format`. - * `"Your health is %d% so you need to heal up."` causes a runtime error because `% so` is actually parsed as `(%s)o` and now requires a corresponding string. -6. Having to use parentheses around string literals just to call a method of it. - -## Design - -To fix all of those issues, we need to do a few things. - -1. A new string interpolation expression (fixes #5, #6) -2. Extend `string.format` to accept values of arbitrary types (fixes #1, #2, #3, #4) - -Because we care about backward compatibility, we need some new syntax in order to not change the meaning of existing strings. There are a few components of this new expression: - -1. A string chunk (`` `...{ ``, `}...{`, and `` }...` ``) where `...` is a range of 0 to many characters. - * `\` escapes `` ` ``, `{`, and itself `\`. - * The pairs must be on the same line (unless a `\` escapes the newline) but expressions needn't be on the same line. -2. An expression between the braces. This is the value that will be interpolated into the string. - * Restriction: we explicitly reject `{{` as it is considered an attempt to escape and get a single `{` character at runtime. -3. Formatting specification may follow after the expression, delimited by an unambiguous character. - * Restriction: the formatting specification must be constant at parse time. - * In the absence of an explicit formatting specification, the `%*` token will be used. - * For now, we explicitly reject any formatting specification syntax. A future extension may be introduced to extend the syntax with an optional specification. - -To put the above into formal EBNF grammar: - -``` -stringinterp ::= exp { exp} -``` - -Which, in actual Luau code, will look like the following: - -``` -local world = "world" -print(`Hello {world}!`) ---> Hello world! - -local combo = {5, 2, 8, 9} -print(`The lock combinations are: {table.concat(combo, ", ")}`) ---> The lock combinations are: 5, 2, 8, 9 - -local set1 = Set.new({0, 1, 3}) -local set2 = Set.new({0, 5, 4}) -print(`{set1} ∪ {set2} = {Set.union(set1, set2)}`) ---> {0, 1, 3} ∪ {0, 5, 4} = {0, 1, 3, 4, 5} - -print(`Some example escaping the braces \{like so}`) -print(`backslash \ that escapes the space is not a part of the string...`) -print(`backslash \\ will escape the second backslash...`) -print(`Some text that also includes \`...`) ---> Some example escaping the braces {like so} ---> backslash that escapes the space is not a part of the string... ---> backslash \ will escape the second backslash... ---> Some text that also includes `... -``` - -As for how newlines are handled, they are handled the same as other string literals. Any text between the `{}` delimiters are not considered part of the string, hence newlines are OK. The main thing is that one opening pair will scan until either a closing pair is encountered, or an unescaped newline. - -``` -local name = "Luau" - -print(`Welcome to { - name -}!`) ---> Welcome to Luau! - -print(`Welcome to \ -{name}!`) ---> Welcome to --- Luau! -``` - -We currently *prohibit* using interpolated strings in function calls without parentheses, this is illegal: - -``` -local name = "world" -print`Hello {name}` -``` - -> Note: This restriction is likely temporary while we work through string interpolation DSLs, an ability to pass individual components of interpolated strings to a function. - -The restriction on `{{` exists solely for the people coming from languages e.g. C#, Rust, or Python which uses `{{` to escape and get the character `{` at runtime. We're also rejecting this at parse time too, since the proper way to escape it is `\{`, so: - -```lua -print(`{{1, 2, 3}} = {myCoolSet}`) -- parse error -``` - -If we did not apply this as a parse error, then the above would wind up printing as the following, which is obviously a gotcha we can and should avoid. - -``` ---> table: 0xSOMEADDRESS = {1, 2, 3} -``` - -Since the string interpolation expression is going to be lowered into a `string.format` call, we'll also need to extend `string.format`. The bare minimum to support the lowering is to add a new token whose definition is to perform a `tostring` call. `%*` is currently an invalid token, so this is a backward compatible extension. This RFC shall define `%*` to have the same behavior as if `tostring` was called. - -```lua -print(string.format("%* %*", 1, 2)) ---> 1 2 -``` - -The offset must always be within bound of the numbers of values passed to `string.format`. - -```lua -local function return_one_thing() return "hi" end -local function return_two_nils() return nil, nil end - -print(string.format("%*", return_one_thing())) ---> "hi" - -print(string.format("%*", Set.new({1, 2, 3}))) ---> {1, 2, 3} - -print(string.format("%* %*", return_two_nils())) ---> nil nil - -print(string.format("%* %* %*", return_two_nils())) ---> error: value #3 is missing, got 2 -``` - -It must be said that we are not allowing this style of string literals in type annotations at this time, regardless of zero or many interpolating expressions, so the following two type annotations below are illegal syntax: - -```lua -local foo: `foo` -local bar: `bar{baz}` -``` - -String interpolation syntax will also support escape sequences. Except `\u{...}`, there is no ambiguity with other escape sequences. If `\u{...}` occurs within a string interpolation literal, it takes priority. - -```lua -local foo = `foo\tbar` -- "foo bar" -local bar = `\u{0041} \u{42}` -- "A B" -``` - -## Drawbacks - -If we want to use backticks for other purposes, it may introduce some potential ambiguity. One option to solve that is to only ever produce string interpolation tokens from the context of an expression. This is messy but doable because the parser and the lexer are already implemented to work in tandem. The other option is to pick a different delimiter syntax to keep backticks available for use in the future. - -If we were to naively compile the expression into a `string.format` call, then implementation details would be observable if you write `` `Your health is {hp}% so you need to heal up.` ``. When lowering the expression, we would need to implicitly insert a `%` character anytime one shows up in a string interpolation token. Otherwise attempting to run this will produce a runtime error where the `%s` token is missing its corresponding string value. - -## Alternatives - -Rather than coming up with a new syntax (which doesn't help issue #5 and #6) and extending `string.format` to accept an extra token, we could just make `%s` call `tostring` and be done. However, doing so would cause programs to be more lenient and the type checker would have no way to infer strings from a `string.format` call. To preserve that, we would need a different token anyway. - -Language | Syntax | Conclusion -----------:|:----------------------|:----------- -Python | `f'Hello {name}'` | Rejected because it's ambiguous with function call syntax. -Swift | `"Hello \(name)"` | Rejected because it changes the meaning of existing strings. -Ruby | `"Hello #{name}"` | Rejected because it changes the meaning of existing strings. -JavaScript | `` `Hello ${name}` `` | Viable option as long as we don't intend to use backticks for other purposes. -C# | `$"Hello {name}"` | Viable option and guarantees no ambiguities with future syntax. - -This leaves us with only two syntax that already exists in other programming languages. The current proposal are for backticks, so the only backward compatible alternative are `$""` literals. We don't necessarily need to use `$` symbol here, but if we were to choose a different symbol, `#` cannot be used. I picked backticks because it doesn't require us to add a stack of closing delimiters in the lexer to make sure each nested string interpolation literals are correctly closed with its opening pair. You only have to count them. diff --git a/rfcs/syntax-type-alias-type-packs.md b/rfcs/syntax-type-alias-type-packs.md deleted file mode 100644 index d5bb60654..000000000 --- a/rfcs/syntax-type-alias-type-packs.md +++ /dev/null @@ -1,218 +0,0 @@ -# Type alias type packs - -**Status**: Implemented - -## Summary - -Provide semantics for referencing type packs inside the body of a type alias declaration - -## Motivation - -We now have an ability to declare a placeholder for a type pack in type alias declaration, but there is no support to reference this pack inside the body of the alias: -```lua -type X = () -> A... -- cannot reference A... as the return value pack - -type Y = X -- invalid number of arguments -``` - -Additionally, while a simple introduction of these generic type packs into the scope will provide an ability to reference them in function declarations, we want to be able to use them to instantiate other type aliases as well. - -Declaration syntax also supports multiple type packs, but we don't have defined semantics on instantiation of such type alias. - -## Design - -We currently support type packs at these locations: -```lua --- for variadic function parameter when type pack is generic -local function f(...: a...) - --- for multiple return values -local function f(): a... - --- as the tail item of function return value pack -local function f(): (number, a...) -``` - -We want to be able to use type packs for type alias instantiation: -```lua -type X = -- - -type A = X -- T... = (S...) -``` - -Similar to function calls, we want to be able to assign zero or more regular types to a single type pack: -```lua -type A = X<> -- T... = () -type B = X -- T... = (number) -type C = X -- T... = (number, string) -``` - -Definition of `A` doesn't parse right now, we would like to make it legal going forward. - -Variadic types can also be assigned to type alias type pack: -```lua -type D = X<...number> -- T... = (...number) -``` - -### Multiple type pack parameters - -We have to keep in mind that it is also possible to declare a type alias that takes multiple type pack parameters. - -Again, type parameters that haven't been matched with type arguments are combined together into the first type pack. -After the first type pack parameter was assigned, following type parameters are not allowed. -Type pack parameters after the first one have to be type packs: -```lua -type Y = -- - -type A = Y -- T... = S..., U... = S... -type B = Y<...string, S...> -- T... = (...string), U... = S... -type C = Y -- T... = (number, string), U... = S... -type D = Y<...number> -- error, T = (...number), but U... = undefined, not (...number) even though one infinite set is enough to fill two, we may have '...number' inside a type pack argument and we'll be unable to see its content -type E = Y -- error, type parameters are not allowed after a type pack - -type Z = -- - -type F = Z -- T = number, U... = S... -type G = Z -- error, not enough regular type arguments, can't split the front of S... into T - -type W = -- - -type H = W -- U... = S..., V... = R... -type I = W -- U... = (string), V... = S... -``` - -### Explicit type pack syntax - -To enable additional control for the content of a type pack, especially in cases where multiple type pack parameters are expected, we introduce an explicit type pack syntax for use in type alias instantiation. - -Similar to variadic types `...a` and generic type packs `T...`, explicit type packs can only be used at type pack positions: -```lua -type Y = (T...) -> (U...) - -type F1 = Y<(number, string), (boolean)> -- T... = (number, string), U... = (boolean) -type F2 = Y<(), ()> -- T... = (), U... = () -type F3 = Y -- T... = (string, number), U... = (number, S...) -``` - -In type parameter list, types inside the parentheses always produce a type pack. -This is in contrast to function return type pack annotation, where `() -> number` is the same as `() -> (number)`. - -However, to preserve backwards-compatibility with optional parenthesis around regular types, type alias instantiation is allowed to assign a non-variadic type pack parameter with a single element to a type argument: -```lua -type X = (T) -> U? -type A = X<(number), (string)> -- T = number, U = string -type A = X<(number), string> -- same - -type Y = (T...) -> () -type B = Y<(number), (string)> -- error: too many type pack parameters -``` - -Explicit type pack syntax is not available in other type pack annotation contexts. - -## Drawbacks - -### Type pack element extraction - -Because our type alias instantiations are not lazy, it's impossible to split of a single type from a type pack: -```lua -type Car = T - -type X = Car -- number -type Y = Car -- error, not enough regular type arguments -type Z = Y -- error, Y doesn't have a valid definition -``` - -With our immediate instantiation, at the point of `Car`, we only know that `S...` is a type pack, but contents are not known. - -Splitting off a single type is is a common pattern with variadic templates in C++, but we don't allow type alias overloads, so use cases are more limited. - -### Type alias can't result in a type pack - -We don't propose type aliases to generate type packs, which could have looked as: -```lua -type Car = T -type Cdr = U... -type Cons = (T, U...) - ---[[ - using type functions to operate on type packs as a list of types -]] -``` - -We wouldn't be able to differentiate if an instantiation results in a type or a type pack and our type system only allows variadic types as the type pack tail element. - -Support for variadic types in the middle of a type pack can be found in TypeScript's tuples. - -## Alternatives - -### Function return type syntax for explicit type packs - -Another option that was considered is to parse `(T)` as `T`, like we do for return type annotation. - -This option complicates the match ruleset since the typechecker will never know if the user has written `T` or `(T)` so each regular type could be a single element type pack and vice versa. -```lua -type X -type C = X -- T... = (number, number) -type D = X<(number), (number)> -- T... = (number, number) - -type Y - ---- two items that were enough to satisfy only a single T... in X are enough to satisfy two T..., U... in Y -type E = Y -- T... = (number), U... = (number) -``` - -### Special mark for single type type packs - -In the Rust language, there is a special disambiguation syntax for single element tuples and single element type packs using a trailing comma: -```rust -(Type,) -``` - -In Python, the same idea is used for single element tuple values: -```python -value = (1, ) -``` - -Since our current ruleset no longer has a problem with single element type tuples, I don't think we need syntax-directed disambiguation option like this one. - -### Only type pack arguments for type pack parameters - -One option that we have is to remove implicit pack assignment from a set of types and always require new explicit type pack syntax: - -```lua -type X = -- - -type B = X<> -- invalid -type C = X -- invalid -type D = X -- invalid - -type B = X<()> -- T... = () -type C = X<(number)> -- T... = (number) -type D = X<(number, string)> -- T... = (number, string) -``` - -But this doesn't allow users to define type aliases where they only care about a few types and use the rest as a 'tail': - -```lua -type X = (T, U, Rest...) -> Rest... - -type A = X -- forced to use a type pack when there are no tail elements -``` - -It also makes it harder to change the type parameter count without fixing up the instantiations. - -### Combining types together with the following type pack into a single argument - -Earlier version of the proposal allowed types to be combined together with a type pack as a tail: -```lua -type X = -- - -type A = X --- T... = (number, S...) -``` - -But this syntax resulted in some confusing behavior when multiple type pack arguments are expected: -```lua -type Y = -- - -type B = Y -- not enough type pack parameters -``` diff --git a/rfcs/syntax-type-ascription-bidi.md b/rfcs/syntax-type-ascription-bidi.md deleted file mode 100644 index 0831aba51..000000000 --- a/rfcs/syntax-type-ascription-bidi.md +++ /dev/null @@ -1,36 +0,0 @@ -# Relaxing type assertions - -**Status**: Implemented - -## Summary - -The way `::` works today is really strange. The best solution we can come up with is to allow `::` to convert between any two related types. - -## Motivation - -Due to an accident of the implementation, the Luau `::` operator can only be used for downcasts and casts to `any`. - -Because of this property, `::` works as users expect in a great many cases, but doesn't actually make a whole lot of sense when scrutinized. - -```lua -local t = {x=0, y=0} - -local a = t :: {x: number, y: number, z: number} -- OK -local a = t :: {x: number} -- Error: This is an upcast! -``` - -Originally, we intended for type assertions to only be useful for upcasts. This would make it consistent with the way annotations work in OCaml and Haskell and would never break soundness. However, users have yet to report this oddity! It is working correctly for them! - -From this, we conclude that users are actually much more interested in having a convenient way to write a downcast. We should bless this use and clean up the rules so they make more sense. - -## Design - -I propose that we change the meaning of the `::` operator to permit conversions between any two types for which either is a subtype of the other. - -## Drawbacks - -`::` was originally envisioned to be a way for users to make the type inference engine work smarter and better for them. The fact of the matter is, though, that downcasts are useful to our users. We should be responsive to that. - -## Alternatives - -We initially discussed allowing `::` to coerce anything to anything else, acting as a full bypass of the type system. We are not doing this because it is really just not that hard to implement: All we need to do is to succeed if unification works between the two types in either direction. Additionally, requiring one type to be subtype of another catches mistakes when two types are completely unrelated, e.g. casting a `string` to a table will still produce an error when this proposal is in effect - this will make sure that `::` is as safe of a bypass as it can be in practice. diff --git a/rfcs/syntax-type-ascription.md b/rfcs/syntax-type-ascription.md deleted file mode 100644 index e48b723a7..000000000 --- a/rfcs/syntax-type-ascription.md +++ /dev/null @@ -1,68 +0,0 @@ -# Type ascriptions - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Implement syntax for type ascriptions using `::` - -## Motivation - -Luau would like to provide a mechanism for requiring a value to be of a specific type: - -``` --- Asserts that the result of a + b is a number. --- Emits a type error if it isn't. -local foo = (a + b) as number -``` - -This syntax was proposed in the original Luau syntax proposal. Unfortunately, we discovered that there is a syntactical ambiguity with `as`: - -``` --- Two function calls or a type assertion? -foo() as (bar) -``` - -## Design - -To provide this functionality without introducing syntactical confusion, we want to change this syntax to use the `::` symbol instead of `as`: - -``` -local foo = (a + b) :: number -``` - -This syntax is borrowed from Haskell, where it performs the same function. - -The `::` operator will bind very tightly, like `as`: - -``` --- type assertion applies to c, not (b + c). -local a = b + c :: number -``` - -Note that `::` can only cast a *single* value to a type - not a type pack (multiple values). This means that in the following context, `::` changes runtime behavior: - -``` -foo(1, bar()) -- passes all values returned by bar() to foo() -foo(1, bar() :: any) -- passes just the first value returned by bar() to foo() -``` - -## Drawbacks - -It's somewhat unusual for Lua to use symbols as operators, with the exception of arithmetics (and `..`). Also a lot of Luau users may be familiar with TypeScript, where the equivalent concept uses `as`. - -`::` may make it more difficult for us to use Turbofish (`::<>`) in the future. - -## Alternatives - -We considered requiring `as` to be wrapped in parentheses, and then relaxing this restriction where there's no chance of syntactical ambiguity: - -``` -local foo: SomeType = (fn() as SomeType) --- Parentheses not needed: unambiguous! -bar(foo as number) -``` - -We decided to not go with this due to concerns about the complexity of the grammar - it requires users to internalize knowledge of our parser to know when they need to surround an `as` expression with parentheses. The rules for when you can leave the parentheses out are somewhat nonintuitive. diff --git a/rfcs/syntax-typed-variadics.md b/rfcs/syntax-typed-variadics.md deleted file mode 100644 index 2988787e7..000000000 --- a/rfcs/syntax-typed-variadics.md +++ /dev/null @@ -1,45 +0,0 @@ -# Typed variadics - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -**Status**: Implemented - -## Summary - -Add syntax for ascribing a type to variadic pack (`...`). - -## Motivation - -Luau's type checker internally can represent a typed variadic: any number of values of the same type. Developers should be able to describe this construct in their own code, for cases where they have a function that accepts an arbitrary number of `string`s, for example. - -## Design - -We think that the postfix `...: T` syntax is the best balance of readability and simplicity. In function type annotations, we will use `...T`: - -``` -function math.max(...: number): number -end - -type fn = (...string) -> string - -type fn2 = () -> ...string -``` - -This doesn't introduce syntactical ambiguity and should cover all cases where we need to represent this construct. Like `...` itself, this syntax is only legal as the last parameter to a function. - -Like all type annotations, the `...: T` syntax has no effect on runtime behavior versus an unannotated `...`. - -There are currently no plans to introduce named variadics, but this proposal leaves room to adopt them with the form `...name: Type` in function declarations in the future. - -## Drawbacks - -The mismatch between the type of `...` in function declaration (`number`) and type declaration (`...number`) is a bit awkward. This also gets more complicated when we introduce generic variadic packs. - -## Alternatives - -We considered several other syntaxes for this construct: - -* `...T`: leaves no room to introduce named variadics -* `...: T...`: redundant `...` -* `... : ...T`: feels redundant, same as above -* `...: T*`: potentially confusing for users with C knowledge, where `T*` is a pointer type diff --git a/rfcs/unsealed-table-assign-optional-property.md b/rfcs/unsealed-table-assign-optional-property.md deleted file mode 100644 index 477399c28..000000000 --- a/rfcs/unsealed-table-assign-optional-property.md +++ /dev/null @@ -1,60 +0,0 @@ -# Unsealed table assignment creates an optional property - -**Status**: Implemented - -## Summary - -In Luau, tables have a state, which can, among others, be "unsealed". -An unsealed table is one that we are still constructing. Currently -assigning a table literal to an unsealed table does not introduce new -properties, so it is a type error if they are read. -We would like to change this so that assigning a table -literal to an unsealed table creates an optional property. - -## Motivation - -In lua-apps, there is testing code which (simplified) looks like: - -```lua -local t = { u = {} } -t = { u = { p = 37 } } -t = { u = { q = "hi" } } -local x: number? = t.u.p -local y: string? = t.u.q -``` - -Currently, this code doesn't typecheck, due to `p` and `q` being unknown properties of `t.u`. - -## Design - -In order to support this idiom, we propose that assigning a table -to an unsealed table should add an optional property. - -For example, before this change the type of `t` is `{ u: {} }`, -and after this change is `{ u: { p: number?, q: number? } }`. - -This is implemented by adding a case to unification where the supertype -is an unsealed table, and the subtype is a table with extra properties. -Currently the extra properties are ignored, but with this change we would -add the property to the unsealed table (making it optional if necessary). - -Since tables with optional properties of the same type are subtypes of -tables with indexers, this allows table literals to be used as dictionaries, -for example the type of `t` is a subtype of `{ u: { [string]: number } }`. - -Note that we need to add an optional property, otherwise the example above will not typecheck. -```lua -local t = { u = {} } -t = { u = { p = 37 } } -t = { u = { q = "hi" } } -- fails because there's no u.p -``` - -## Drawbacks - -The implementation of this proposal introduces optional types during unification, -and so needs access to an allocator. - -## Alternatives - -Rather than introducing optional properties, we could introduce an indexer. For example we could infer the type of -`t` as `{ u: { [string]: number } }`. diff --git a/rfcs/unsealed-table-literals.md b/rfcs/unsealed-table-literals.md deleted file mode 100644 index 669b67d4e..000000000 --- a/rfcs/unsealed-table-literals.md +++ /dev/null @@ -1,78 +0,0 @@ -# Unsealed table literals - -**Status**: Implemented - -## Summary - -Currently the only way to create an unsealed table is as an empty table literal `{}`. -This RFC proposes making all table literals unsealed. - -## Motivation - -Table types can be *sealed* or *unsealed*. These are different in that: - -* Unsealed table types are *precise*: if a table has unsealed type `{ p: number, q: string }` - then it is guaranteed to have only properties `p` and `q`. - -* Sealed tables support *width subtyping*: if a table has sealed type `{ p: number }` - then it is guaranteed to have at least property `p`, so we allow `{ p: number, q: string }` - to be treated as a subtype of `{ p: number }` - -* Unsealed tables can have properties added to them: if `t` has unsealed type - `{ p: number }` then after the assignment `t.q = "hi"`, `t`'s type is updated to be - `{ p: number, q: string }`. - -* Unsealed tables are subtypes of sealed tables. - -Currently the only way to create an unsealed table is using an empty table literal, so -```lua - local t = {} - t.p = 5 - t.q = "hi" -``` -typechecks, but -```lua - local t = { p = 5 } - t.q = "hi" -``` -does not. - -This causes problems in examples, in particular developers -may initialize properties but not methods: -```lua - local t = { p = 5 } - function t.f() return t.p end -``` - -## Design - -The proposed change is straightforward: make all table literals unsealed. - -## Drawbacks - -Making all table literals unsealed is a conservative change, it only removes type errors. - -It does encourage developers to add new properties to tables during initialization, which -may be considered poor style. - -It does mean that some spelling mistakes will not be caught, for example -```lua -local t = {x = 1, y = 2} -if foo then - t.z = 3 -- is z a typo or intentional 2-vs-3 choice? -end -``` - -In particular, we no longer warn about adding properties to array-like tables. -```lua -local a = {1,2,3} -a.p = 5 -``` - -## Alternatives - -We could introduce a new table state for unsealed-but-precise -tables. The trade-off is that that would be more precise, at the cost -of adding user-visible complexity to the type system. - -We could continue to treat array-like tables as sealed. diff --git a/rfcs/unsealed-table-subtyping-strips-optional-properties.md b/rfcs/unsealed-table-subtyping-strips-optional-properties.md deleted file mode 100644 index d99c1f818..000000000 --- a/rfcs/unsealed-table-subtyping-strips-optional-properties.md +++ /dev/null @@ -1,68 +0,0 @@ -# Only strip optional properties from unsealed tables during subtyping - -**Status**: Implemented - -## Summary - -Currently subtyping allows optional properties to be stripped from table types during subtyping. -This RFC proposes only allowing that when the subtype is unsealed and the supertype is sealed. - -## Motivation - -Table types can be *sealed* or *unsealed*. These are different in that: - -* Unsealed table types are *precise*: if a table has unsealed type `{ p: number, q: string }` - then it is guaranteed to have only properties `p` and `q`. - -* Sealed tables support *width subtyping*: if a table has sealed type `{ p: number }` - then it is guaranteed to have at least property `p`, so we allow `{ p: number, q: string }` - to be treated as a subtype of `{ p: number }` - -* Unsealed tables can have properties added to them: if `t` has unsealed type - `{ p: number }` then after the assignment `t.q = "hi"`, `t`'s type is updated to be - `{ p: number, q: string }`. - -* Unsealed tables are subtypes of sealed tables. - -Currently we allow subtyping to strip away optional fields -as long as the supertype is sealed. -This is necessary for examples, for instance: -```lua - local t : { p: number, q: string? } = { p = 5, q = "hi" } - t = { p = 7 } -``` -typechecks because `{ p : number }` is a subtype of -`{ p : number, q : string? }`. Unfortunately this is not sound, -since sealed tables support width subtyping: -```lua - local t : { p: number, q: string? } = { p = 5, q = "hi" } - local u : { p: number } = { p = 5, q = false } - t = u -``` - -## Design - -The fix for this source of unsoundness is twofold: - -1. make all table literals unsealed, and -2. only allow stripping optional properties from when the - supertype is sealed and the subtype is unsealed. - -This RFC is for (2). There is a [separate RFC](unsealed-table-literals.md) for (1). - -## Drawbacks - -This introduces new type errors (it has to, since it is fixing a source of -unsoundness). This means that there are now false positives such as: -```lua - local t : { p: number, q: string? } = { p = 5, q = "hi" } - local u : { p: number } = { p = 5, q = "lo" } - t = u -``` -These false positives are so similar to sources of unsoundness -that it is difficult to see how to allow them soundly. - -## Alternatives - -We could just live with unsoundness. - diff --git a/tests/AnyTypeSummary.test.cpp b/tests/AnyTypeSummary.test.cpp new file mode 100644 index 000000000..5c3b4aa39 --- /dev/null +++ b/tests/AnyTypeSummary.test.cpp @@ -0,0 +1,989 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/RequireTracer.h" + +#include "Fixture.h" + +#include "ScopedFlags.h" +#include "doctest.h" + +#include + +using namespace Luau; + +using Pattern = AnyTypeSummary::Pattern; + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(StudioReportLuauAny2) + + +struct ATSFixture : BuiltinsFixture +{ + + ATSFixture() + { + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); + } +}; + +TEST_SUITE_BEGIN("AnyTypeSummaryTest"); + +TEST_CASE_FIXTURE(ATSFixture, "var_typepack_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +type A = (number, string) -> ...any +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[0].node == "type A = (number, string)->( ...any)"); +} + +TEST_CASE_FIXTURE(ATSFixture, "export_alias") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +export type t8 = t0 &((true | any)->('')) +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(1, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &((true | any)->(''))"); +} + +TEST_CASE_FIXTURE(ATSFixture, "typepacks") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local function fallible(t: number): ...any + if t > 0 then + return true, t -- should catch this + end + return false, "must be positive" -- should catch this +end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 3); + LUAU_ASSERT(module->ats.typeInfo[1].code == Pattern::TypePk); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function fallible(t: number): ...any\n if t > 0 then\n return true, t\n end\n return false, 'must be positive'\nend"); +} + +TEST_CASE_FIXTURE(ATSFixture, "typepacks_no_ret") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +-- TODO: if partially typed, we'd want to know too +local function fallible(t: number) + if t > 0 then + return true, t + end + return false, "must be positive" +end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(1, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 0); +} + +TEST_CASE_FIXTURE(ATSFixture, "var_typepack_any_gen_table") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +type Pair = {first: T, second: any} +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[0].node == "type Pair = {first: T, second: any}"); +} + +TEST_CASE_FIXTURE(ATSFixture, "assign_uneq") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/B"] = R"( +local function greetings(name: string) + return "Hello, " .. name, nil +end + +local x, y = greetings("Dibri") +local x, y = greetings("Dibri"), nil +local x, y, z = greetings("Dibri") -- mismatch +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/B"); + LUAU_ASSERT(module->ats.typeInfo.size() == 0); +} + +TEST_CASE_FIXTURE(ATSFixture, "var_typepack_any_gen") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +-- type Pair = (boolean, string, ...any) -> {T} -- type aliases with generics/pack do not seem to be processed? +type Pair = (boolean, T) -> ...any +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[0].node == "type Pair = (boolean, T)->( ...any)"); +} + +TEST_CASE_FIXTURE(ATSFixture, "typeof_any_in_func") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local function f() + local a: any = 1 + local b: typeof(a) = 1 + end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 2); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::VarAnnot); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function f()\n local a: any = 1\n local b: typeof(a) = 1\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "generic_types") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local function foo(a: (...A) -> any, ...: A) + return a(...) +end + +local function addNumbers(num1, num2) + local result = num1 + num2 + return result +end + +foo(addNumbers) + )"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 3); + LUAU_ASSERT(module->ats.typeInfo[1].code == Pattern::FuncApp); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function foo(a: (...A)->( any),...: A)\n return a(...)\nend"); +} + +TEST_CASE_FIXTURE(ATSFixture, "no_annot") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local character = script.Parent +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 0); +} + +TEST_CASE_FIXTURE(ATSFixture, "if_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +function f(x: any) +if not x then +x = { + y = math.random(0, 2^31-1), + left = nil, + right = nil +} +else + local expected = x * 5 +end +end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); + LUAU_ASSERT( + module->ats.typeInfo[0].node == "function f(x: any)\nif not x then\nx = {\n y = math.random(0, 2^31-1),\n left = nil,\n right = " + "nil\n}\nelse\n local expected = x * 5\nend\nend" + ); +} + +TEST_CASE_FIXTURE(ATSFixture, "variadic_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local function f(): (number, ...any) + return 1, 5 --catching this + end + + local x, y, z = f() -- not catching this any because no annot +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 2); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncRet); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function f(): (number, ...any)\n return 1, 5\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "type_alias_intersection") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + type XCoord = {x: number} + type YCoord = {y: any} + type Vector2 = XCoord & YCoord -- table type intersections do not get normalized + local vec2: Vector2 = {x = 1, y = 2} +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 3); + LUAU_ASSERT(module->ats.typeInfo[2].code == Pattern::VarAnnot); + LUAU_ASSERT(module->ats.typeInfo[2].node == "local vec2: Vector2 = {x = 1, y = 2}"); +} + +TEST_CASE_FIXTURE(ATSFixture, "var_func_arg") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local function f(...: any) + end + local function f(x: number?, y, z: any) + end + function f(x: number?, y, z: any) + end + function f(...: any) + end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 4); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::VarAny); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function f(...: any)\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "var_func_apps") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local function f(...: any) + end + f("string", 123) + f("string") +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 3); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::VarAny); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function f(...: any)\n end"); +} + + +TEST_CASE_FIXTURE(ATSFixture, "CannotExtendTable") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local CAR_COLLISION_GROUP = "Car" + +-- Set the car collision group +for _, descendant in carTemplate:GetDescendants() do + if descendant:IsA("BasePart") then + descendant.CollisionGroup = CAR_COLLISION_GROUP + end +end + +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(3, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 0); +} + +TEST_CASE_FIXTURE(ATSFixture, "unknown_symbol") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local function manageRace(raceContainer: Model) + RaceManager.new(raceContainer) +end + +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(2, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 2); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function manageRace(raceContainer: Model)\n RaceManager.new(raceContainer)\nend"); +} + +TEST_CASE_FIXTURE(ATSFixture, "racing_3_short") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + +local CollectionService = game:GetService("CollectionService") + +local RaceManager = require(script.RaceManager) + +local RACE_TAG = "Race" + +local function manageRace(raceContainer: Model) + RaceManager.new(raceContainer) +end + +local function initialize() + CollectionService:GetInstanceAddedSignal(RACE_TAG):Connect(manageRace) + + for _, raceContainer in CollectionService:GetTagged(RACE_TAG) do + manageRace(raceContainer) + end +end + +initialize() + +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(2, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 5); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local function manageRace(raceContainer: Model)\n RaceManager.new(raceContainer)\nend"); +} + +TEST_CASE_FIXTURE(ATSFixture, "racing_collision_2") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local PhysicsService = game:GetService("PhysicsService") +local ReplicatedStorage = game:GetService("ReplicatedStorage") + +local safePlayerAdded = require(script.safePlayerAdded) + +local CAR_COLLISION_GROUP = "Car" +local CHARACTER_COLLISION_GROUP = "Character" + +local carTemplate = ReplicatedStorage.Car + +local function onCharacterAdded(character: Model) + -- Set the collision group for any parts that are added to the character + character.DescendantAdded:Connect(function(descendant) + if descendant:IsA("BasePart") then + descendant.CollisionGroup = CHARACTER_COLLISION_GROUP + end + end) + + -- Set the collision group for any parts currently in the character + for _, descendant in character:GetDescendants() do + if descendant:IsA("BasePart") then + descendant.CollisionGroup = CHARACTER_COLLISION_GROUP + end + end +end + +local function onPlayerAdded(player: Player) + player.CharacterAdded:Connect(onCharacterAdded) + + if player.Character then + onCharacterAdded(player.Character) + end +end + +local function initialize() + -- Setup collision groups + PhysicsService:RegisterCollisionGroup(CAR_COLLISION_GROUP) + PhysicsService:RegisterCollisionGroup(CHARACTER_COLLISION_GROUP) + + -- Stop the collision groups from colliding with each other + PhysicsService:CollisionGroupSetCollidable(CAR_COLLISION_GROUP, CAR_COLLISION_GROUP, false) + PhysicsService:CollisionGroupSetCollidable(CHARACTER_COLLISION_GROUP, CHARACTER_COLLISION_GROUP, false) + PhysicsService:CollisionGroupSetCollidable(CAR_COLLISION_GROUP, CHARACTER_COLLISION_GROUP, false) + + -- Set the car collision group + for _, descendant in carTemplate:GetDescendants() do + if descendant:IsA("BasePart") then + descendant.CollisionGroup = CAR_COLLISION_GROUP + end + end + + -- Set character collision groups for all players + safePlayerAdded(onPlayerAdded) +end + +initialize() + +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(5, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 11); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); + LUAU_ASSERT( + module->ats.typeInfo[0].node == + "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " + "descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " + "character:GetDescendants()do\n if descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " + "end\nend" + ); +} + +TEST_CASE_FIXTURE(ATSFixture, "racing_spawning_1") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local CollectionService = game:GetService("CollectionService") +local Players = game:GetService("Players") + +local spawnCar = require(script.spawnCar) +local destroyPlayerCars = require(script.destroyPlayerCars) + +local spawnPromptTemplate = script.SpawnPrompt + +local KIOSK_TAG = "CarSpawnKiosk" + +local function setupKiosk(kiosk: Model) + local spawnLocation = kiosk:FindFirstChild("SpawnLocation") + assert(spawnLocation, `{kiosk:GetFullName()} has no SpawnLocation part`) + local promptPart = kiosk:FindFirstChild("Prompt") + assert(promptPart, `{kiosk:GetFullName()} has no Prompt part`) + + -- Hide the car spawn location + spawnLocation.Transparency = 1 + + -- Create a new prompt to spawn the car + local spawnPrompt = spawnPromptTemplate:Clone() + spawnPrompt.Parent = promptPart + + spawnPrompt.Triggered:Connect(function(player: Player) + -- Remove any existing cars the player has spawned + destroyPlayerCars(player) + -- Spawn a new car at the spawnLocation, owned by the player + spawnCar(spawnLocation.CFrame, player) + end) +end + +local function initialize() + -- Remove cars owned by players whenever they leave + Players.PlayerRemoving:Connect(destroyPlayerCars) + + -- Setup all car spawning kiosks + CollectionService:GetInstanceAddedSignal(KIOSK_TAG):Connect(setupKiosk) + + for _, kiosk in CollectionService:GetTagged(KIOSK_TAG) do + setupKiosk(kiosk) + end +end + +initialize() + +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(5, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 7); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); + LUAU_ASSERT( + module->ats.typeInfo[0].node == + "local function setupKiosk(kiosk: Model)\n local spawnLocation = kiosk:FindFirstChild('SpawnLocation')\n assert(spawnLocation, " + "`{kiosk:GetFullName()} has no SpawnLocation part`)\n local promptPart = kiosk:FindFirstChild('Prompt')\n assert(promptPart, " + "`{kiosk:GetFullName()} has no Prompt part`)\n\n\n spawnLocation.Transparency = 1\n\n\n local spawnPrompt = " + "spawnPromptTemplate:Clone()\n spawnPrompt.Parent = promptPart\n\n spawnPrompt.Triggered:Connect(function(player: Player)\n\n " + "destroyPlayerCars(player)\n\n spawnCar(spawnLocation.CFrame, player)\n end)\nend" + ); +} + +TEST_CASE_FIXTURE(ATSFixture, "mutually_recursive_generic") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + --!strict + type T = { f: a, g: U } + type U = { h: a, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = "lo", i = nil } } + y.g.i = y + )"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(2, result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 0); +} + +TEST_CASE_FIXTURE(ATSFixture, "explicit_pack") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +type Foo = (T...) -> () -- also want to see how these are used. +type Bar = Foo<(number, any)> +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[0].node == "type Bar = Foo<(number, any)>"); +} + +TEST_CASE_FIXTURE(ATSFixture, "local_val") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local a, b, c = 1 :: any +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Casts); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local a, b, c = 1 :: any"); +} + +TEST_CASE_FIXTURE(ATSFixture, "var_any_local") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( +local x = 2 +local x: any = 2, 3 +local x: any, y = 1, 2 +local x: number, y: any, z, h: nil = 1, nil +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 3); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::VarAnnot); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local x: any = 2, 3"); +} + +TEST_CASE_FIXTURE(ATSFixture, "table_uses_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local x: any = 0 + local y: number + local z = {x=x, y=y} -- not catching this +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::VarAnnot); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local x: any = 0"); +} + +TEST_CASE_FIXTURE(ATSFixture, "typeof_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local x: any = 0 + function some1(x: typeof(x)) + end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 2); + LUAU_ASSERT(module->ats.typeInfo[1].code == Pattern::FuncArg); + LUAU_ASSERT(module->ats.typeInfo[0].node == "function some1(x: typeof(x))\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "table_type_assigned") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local x: { x: any?} = {x = 1} + local z: { x : any, y : number? } -- not catching this + z.x = "bigfatlongstring" + z.y = nil +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 2); + LUAU_ASSERT(module->ats.typeInfo[1].code == Pattern::Assign); + LUAU_ASSERT(module->ats.typeInfo[0].node == "local x: { x: any?} = {x = 1}"); +} + +TEST_CASE_FIXTURE(ATSFixture, "simple_func_wo_ret") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + function some(x: any) + end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); + LUAU_ASSERT(module->ats.typeInfo[0].node == "function some(x: any)\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "simple_func_w_ret") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + function other(y: number): any + return "gotcha!" + end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncRet); + LUAU_ASSERT(module->ats.typeInfo[0].node == "function other(y: number): any\n return 'gotcha!'\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "nested_local") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + function cool(y: number): number + local g: any = "gratatataaa" + return y + end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::VarAnnot); + LUAU_ASSERT(module->ats.typeInfo[0].node == "function cool(y: number): number\n local g: any = 'gratatataaa'\n return y\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "generic_func") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + function reverse(a: {T}, b: any): {T} + return a + end +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 1); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); + LUAU_ASSERT(module->ats.typeInfo[0].node == "function reverse(a: {T}, b: any): {T}\n return a\n end"); +} + +TEST_CASE_FIXTURE(ATSFixture, "type_alias_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = R"( + type Clear = any + local z: Clear = "zip" +)"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result1); + + ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 2); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[0].node == "type Clear = any"); +} + +TEST_CASE_FIXTURE(ATSFixture, "multi_module_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/A"] = R"( + export type MyFunction = (number, string) -> (any) +)"; + + fileResolver.source["game/B"] = R"( + local MyFunc = require(script.Parent.A) + type Clear = any + local z: Clear = "zip" +)"; + + fileResolver.source["game/Gui/Modules/A"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {hello = B.hello} +)"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + ModulePtr module = frontend.moduleResolver.getModule("game/B"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 2); + LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[0].node == "type Clear = any"); +} + +TEST_CASE_FIXTURE(ATSFixture, "cast_on_cyclic_req") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::StudioReportLuauAny2, true}, + }; + + fileResolver.source["game/A"] = R"( + local a = require(script.Parent.B) -- not resolving this module + export type MyFunction = (number, string) -> (any) +)"; + + fileResolver.source["game/B"] = R"( + local MyFunc = require(script.Parent.A) :: any + type Clear = any + local z: Clear = "zip" +)"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(0, result); + + ModulePtr module = frontend.moduleResolver.getModule("game/B"); + + LUAU_ASSERT(module->ats.typeInfo.size() == 3); + LUAU_ASSERT(module->ats.typeInfo[1].code == Pattern::Alias); + LUAU_ASSERT(module->ats.typeInfo[1].node == "type Clear = any"); +} + + +TEST_SUITE_END(); diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index e23b965bc..2cd821b58 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AssemblyBuilderA64.h" #include "Luau/StringUtils.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -32,9 +33,9 @@ static std::string bytecodeAsArray(const std::vector& code) class AssemblyBuilderA64Fixture { public: - bool check(void (*f)(AssemblyBuilderA64& build), std::vector code, std::vector data = {}) + bool check(void (*f)(AssemblyBuilderA64& build), std::vector code, std::vector data = {}, unsigned int features = 0) { - AssemblyBuilderA64 build(/* logText= */ false); + AssemblyBuilderA64 build(/* logText= */ false, features); f(build); @@ -61,21 +62,25 @@ TEST_SUITE_BEGIN("A64Assembly"); #define SINGLE_COMPARE(inst, ...) \ CHECK(check( \ - [](AssemblyBuilderA64& build) { \ + [](AssemblyBuilderA64& build) \ + { \ build.inst; \ }, \ - {__VA_ARGS__})) + {__VA_ARGS__} \ + )) TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Unary") { SINGLE_COMPARE(neg(x0, x1), 0xCB0103E0); SINGLE_COMPARE(neg(w0, w1), 0x4B0103E0); - SINGLE_COMPARE(mvn(x0, x1), 0xAA2103E0); + SINGLE_COMPARE(mvn_(x0, x1), 0xAA2103E0); SINGLE_COMPARE(clz(x0, x1), 0xDAC01020); SINGLE_COMPARE(clz(w0, w1), 0x5AC01020); SINGLE_COMPARE(rbit(x0, x1), 0xDAC00020); SINGLE_COMPARE(rbit(w0, w1), 0x5AC00020); + SINGLE_COMPARE(rev(w0, w1), 0x5AC00820); + SINGLE_COMPARE(rev(x0, x1), 0xDAC00C20); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") @@ -84,8 +89,12 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(add(x0, x1, x2), 0x8B020020); SINGLE_COMPARE(add(w0, w1, w2), 0x0B020020); SINGLE_COMPARE(add(x0, x1, x2, 7), 0x8B021C20); + SINGLE_COMPARE(add(x0, x1, x2, -7), 0x8B421C20); SINGLE_COMPARE(sub(x0, x1, x2), 0xCB020020); SINGLE_COMPARE(and_(x0, x1, x2), 0x8A020020); + SINGLE_COMPARE(and_(x0, x1, x2, 7), 0x8A021C20); + SINGLE_COMPARE(and_(x0, x1, x2, -7), 0x8A421C20); + SINGLE_COMPARE(bic(x0, x1, x2), 0x8A220020); SINGLE_COMPARE(orr(x0, x1, x2), 0xAA020020); SINGLE_COMPARE(eor(x0, x1, x2), 0xCA020020); SINGLE_COMPARE(lsl(x0, x1, x2), 0x9AC22020); @@ -94,6 +103,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(asr(x0, x1, x2), 0x9AC22820); SINGLE_COMPARE(ror(x0, x1, x2), 0x9AC22C20); SINGLE_COMPARE(cmp(x0, x1), 0xEB01001F); + SINGLE_COMPARE(tst(x0, x1), 0xEA01001F); // reg, imm SINGLE_COMPARE(add(x3, x7, 78), 0x910138E3); @@ -102,6 +112,54 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(cmp(w0, 42), 0x7100A81F); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "BinaryExtended") +{ + // reg, reg + SINGLE_COMPARE(add(x0, x1, w2, 3), 0x8B224C20); + SINGLE_COMPARE(sub(x0, x1, w2, 3), 0xCB224C20); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "BinaryImm") +{ + // instructions + SINGLE_COMPARE(and_(w1, w2, 1), 0x12000041); + SINGLE_COMPARE(orr(w1, w2, 1), 0x32000041); + SINGLE_COMPARE(eor(w1, w2, 1), 0x52000041); + SINGLE_COMPARE(tst(w1, 1), 0x7200003f); + + // various mask forms + SINGLE_COMPARE(and_(w0, w0, 1), 0x12000000); + SINGLE_COMPARE(and_(w0, w0, 3), 0x12000400); + SINGLE_COMPARE(and_(w0, w0, 7), 0x12000800); + SINGLE_COMPARE(and_(w0, w0, 2147483647), 0x12007800); + SINGLE_COMPARE(and_(w0, w0, 6), 0x121F0400); + SINGLE_COMPARE(and_(w0, w0, 12), 0x121E0400); + SINGLE_COMPARE(and_(w0, w0, 2147483648), 0x12010000); + + // shifts + SINGLE_COMPARE(lsl(w1, w2, 1), 0x531F7841); + SINGLE_COMPARE(lsl(x1, x2, 1), 0xD37FF841); + SINGLE_COMPARE(lsr(w1, w2, 1), 0x53017C41); + SINGLE_COMPARE(lsr(x1, x2, 1), 0xD341FC41); + SINGLE_COMPARE(asr(w1, w2, 1), 0x13017C41); + SINGLE_COMPARE(asr(x1, x2, 1), 0x9341FC41); + SINGLE_COMPARE(ror(w1, w2, 1), 0x13820441); + SINGLE_COMPARE(ror(x1, x2, 1), 0x93C20441); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Bitfield") +{ + SINGLE_COMPARE(ubfiz(x1, x2, 37, 5), 0xD35B1041); + SINGLE_COMPARE(ubfx(x1, x2, 37, 5), 0xD365A441); + SINGLE_COMPARE(sbfiz(x1, x2, 37, 5), 0x935B1041); + SINGLE_COMPARE(sbfx(x1, x2, 37, 5), 0x9365A441); + + SINGLE_COMPARE(ubfiz(w1, w2, 17, 5), 0x530F1041); + SINGLE_COMPARE(ubfx(w1, w2, 17, 5), 0x53115441); + SINGLE_COMPARE(sbfiz(w1, w2, 17, 5), 0x130F1041); + SINGLE_COMPARE(sbfx(w1, w2, 17, 5), 0x13115441); +} + TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") { // address forms @@ -120,6 +178,18 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") SINGLE_COMPARE(ldrsh(x0, x1), 0x79800020); SINGLE_COMPARE(ldrsh(w0, x1), 0x79C00020); SINGLE_COMPARE(ldrsw(x0, x1), 0xB9800020); + + // load sizes x offset scaling + SINGLE_COMPARE(ldr(x0, mem(x1, 8)), 0xF9400420); + SINGLE_COMPARE(ldr(w0, mem(x1, 8)), 0xB9400820); + SINGLE_COMPARE(ldrb(w0, mem(x1, 8)), 0x39402020); + SINGLE_COMPARE(ldrh(w0, mem(x1, 8)), 0x79401020); + SINGLE_COMPARE(ldrsb(w0, mem(x1, 8)), 0x39C02020); + SINGLE_COMPARE(ldrsh(w0, mem(x1, 8)), 0x79C01020); + + // paired loads + SINGLE_COMPARE(ldp(x0, x1, mem(x2, 8)), 0xA9408440); + SINGLE_COMPARE(ldp(w0, w1, mem(x2, -8)), 0x297F0440); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") @@ -135,49 +205,120 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") SINGLE_COMPARE(str(w0, x1), 0xB9000020); SINGLE_COMPARE(strb(w0, x1), 0x39000020); SINGLE_COMPARE(strh(w0, x1), 0x79000020); + + // store sizes x offset scaling + SINGLE_COMPARE(str(x0, mem(x1, 8)), 0xF9000420); + SINGLE_COMPARE(str(w0, mem(x1, 8)), 0xB9000820); + SINGLE_COMPARE(strb(w0, mem(x1, 8)), 0x39002020); + SINGLE_COMPARE(strh(w0, mem(x1, 8)), 0x79001020); + + // paired stores + SINGLE_COMPARE(stp(x0, x1, mem(x2, 8)), 0xA9008440); + SINGLE_COMPARE(stp(w0, w1, mem(x2, -8)), 0x293F0440); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves") { SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0); SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0); - SINGLE_COMPARE(mov(x0, 42), 0xD2800540); - SINGLE_COMPARE(mov(w0, 42), 0x52800540); + SINGLE_COMPARE(mov(q0, q1), 0x4EA11C20); + + SINGLE_COMPARE(movz(x0, 42), 0xD2800540); + SINGLE_COMPARE(movz(w0, 42), 0x52800540); + SINGLE_COMPARE(movn(x0, 42), 0x92800540); + SINGLE_COMPARE(movn(w0, 42), 0x12800540); SINGLE_COMPARE(movk(x0, 42, 16), 0xF2A00540); + + CHECK(check( + [](AssemblyBuilderA64& build) + { + build.mov(x0, 42); + }, + {0xD2800540} + )); + + CHECK(check( + [](AssemblyBuilderA64& build) + { + build.mov(x0, 424242); + }, + {0xD28F2640, 0xF2A000C0} + )); + + CHECK(check( + [](AssemblyBuilderA64& build) + { + build.mov(x0, -42); + }, + {0x92800520} + )); + + CHECK(check( + [](AssemblyBuilderA64& build) + { + build.mov(x0, -424242); + }, + {0x928F2620, 0xF2BFFF20} + )); + + CHECK(check( + [](AssemblyBuilderA64& build) + { + build.mov(x0, -65536); + }, + {0x929FFFE0} + )); + + CHECK(check( + [](AssemblyBuilderA64& build) + { + build.mov(x0, -65537); + }, + {0x92800000, 0xF2BFFFC0} + )); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow") { // Jump back CHECK(check( - [](AssemblyBuilderA64& build) { + [](AssemblyBuilderA64& build) + { Label start = build.setLabel(); build.mov(x0, x1); build.b(ConditionA64::Equal, start); }, - {0xAA0103E0, 0x54FFFFE0})); + {0xAA0103E0, 0x54FFFFE0} + )); // Jump forward CHECK(check( - [](AssemblyBuilderA64& build) { + [](AssemblyBuilderA64& build) + { Label skip; build.b(ConditionA64::Equal, skip); build.mov(x0, x1); build.setLabel(skip); }, - {0x54000040, 0xAA0103E0})); + {0x54000040, 0xAA0103E0} + )); // Jumps CHECK(check( - [](AssemblyBuilderA64& build) { + [](AssemblyBuilderA64& build) + { Label skip; build.b(ConditionA64::Equal, skip); build.cbz(x0, skip); build.cbnz(x0, skip); + build.tbz(x0, 5, skip); + build.tbnz(x0, 5, skip); build.setLabel(skip); build.b(skip); + build.bl(skip); }, - {0x54000060, 0xB4000040, 0xB5000020, 0x5400000E})); + {0x540000A0, 0xB4000080, 0xB5000060, 0x36280040, 0x37280020, 0x14000000, 0x97ffffff} + )); // Basic control flow SINGLE_COMPARE(br(x0), 0xD61F0000); @@ -222,6 +363,185 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Constants") // clang-format on } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOfLabel") +{ + // clang-format off + CHECK(check( + [](AssemblyBuilderA64& build) { + Label label; + build.adr(x0, label); + build.add(x0, x0, x0); + build.setLabel(label); + }, + { + 0x10000040, 0x8b000000, + })); + // clang-format on +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPBasic") +{ + SINGLE_COMPARE(fmov(d0, d1), 0x1E604020); + SINGLE_COMPARE(fmov(d0, x1), 0x9E670020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") +{ + SINGLE_COMPARE(fabs(d1, d2), 0x1E60C041); + SINGLE_COMPARE(fadd(d1, d2, d3), 0x1E632841); + SINGLE_COMPARE(fadd(s29, s29, s28), 0x1E3C2BBD); + SINGLE_COMPARE(fdiv(d1, d2, d3), 0x1E631841); + SINGLE_COMPARE(fdiv(s29, s29, s28), 0x1E3C1BBD); + SINGLE_COMPARE(fmul(d1, d2, d3), 0x1E630841); + SINGLE_COMPARE(fmul(s29, s29, s28), 0x1E3C0BBD); + SINGLE_COMPARE(fneg(d1, d2), 0x1E614041); + SINGLE_COMPARE(fneg(s30, s30), 0x1E2143DE); + SINGLE_COMPARE(fsqrt(d1, d2), 0x1E61C041); + SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841); + SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD); + + SINGLE_COMPARE(frinta(d1, d2), 0x1E664041); + SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); + SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); + + SINGLE_COMPARE(fcvt(s1, d2), 0x1E624041); + SINGLE_COMPARE(fcvt(d1, s2), 0x1E22C041); + + SINGLE_COMPARE(fcvtzs(w1, d2), 0x1E780041); + SINGLE_COMPARE(fcvtzs(x1, d2), 0x9E780041); + SINGLE_COMPARE(fcvtzu(w1, d2), 0x1E790041); + SINGLE_COMPARE(fcvtzu(x1, d2), 0x9E790041); + + SINGLE_COMPARE(scvtf(d1, w2), 0x1E620041); + SINGLE_COMPARE(scvtf(d1, x2), 0x9E620041); + SINGLE_COMPARE(ucvtf(d1, w2), 0x1E630041); + SINGLE_COMPARE(ucvtf(d1, x2), 0x9E630041); + + CHECK(check( + [](AssemblyBuilderA64& build) + { + build.fjcvtzs(w1, d2); + }, + {0x1E7E0041}, + {}, + A64::Feature_JSCVT + )); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPLoadStore") +{ + // address forms + SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); + SINGLE_COMPARE(ldr(d0, mem(x1, 8)), 0xFD400420); + SINGLE_COMPARE(ldr(d0, mem(x1, x7)), 0xFC676820); + SINGLE_COMPARE(ldr(d0, mem(x1, -7)), 0xFC5F9020); + SINGLE_COMPARE(str(d0, x1), 0xFD000020); + SINGLE_COMPARE(str(d0, mem(x1, 8)), 0xFD000420); + SINGLE_COMPARE(str(d0, mem(x1, x7)), 0xFC276820); + SINGLE_COMPARE(str(d0, mem(x1, -7)), 0xFC1F9020); + + // load/store sizes + SINGLE_COMPARE(ldr(s0, x1), 0xBD400020); + SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); + SINGLE_COMPARE(ldr(q0, x1), 0x3DC00020); + SINGLE_COMPARE(str(s0, x1), 0xBD000020); + SINGLE_COMPARE(str(d0, x1), 0xFD000020); + SINGLE_COMPARE(str(q0, x1), 0x3D800020); + + // load/store sizes x offset scaling + SINGLE_COMPARE(ldr(q0, mem(x1, 16)), 0x3DC00420); + SINGLE_COMPARE(ldr(d0, mem(x1, 16)), 0xFD400820); + SINGLE_COMPARE(ldr(s0, mem(x1, 16)), 0xBD401020); + SINGLE_COMPARE(str(q0, mem(x1, 16)), 0x3D800420); + SINGLE_COMPARE(str(d0, mem(x1, 16)), 0xFD000820); + SINGLE_COMPARE(str(s0, mem(x1, 16)), 0xBD001020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPInsertExtract") +{ + SINGLE_COMPARE(ins_4s(q29, w17, 3), 0x4E1C1E3D); + SINGLE_COMPARE(ins_4s(q31, 0, q29, 0), 0x6E0407BF); + SINGLE_COMPARE(dup_4s(s29, q31, 2), 0x5E1407FD); + SINGLE_COMPARE(dup_4s(q29, q30, 0), 0x4E0407DD); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPCompare") +{ + SINGLE_COMPARE(fcmp(d0, d1), 0x1E612000); + SINGLE_COMPARE(fcmpz(d1), 0x1E602028); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPImm") +{ + SINGLE_COMPARE(fmov(d0, 0), 0x2F00E400); + SINGLE_COMPARE(fmov(d0, 0.125), 0x1E681000); + SINGLE_COMPARE(fmov(d0, -0.125), 0x1E781000); + SINGLE_COMPARE(fmov(d0, 1.9375), 0x1E6FF000); + + SINGLE_COMPARE(fmov(q0, 0), 0x4F000400); + SINGLE_COMPARE(fmov(q0, 0.125), 0x4F02F400); + SINGLE_COMPARE(fmov(q0, -0.125), 0x4F06F400); + SINGLE_COMPARE(fmov(q0, 1.9375), 0x4F03F7E0); + + CHECK(!AssemblyBuilderA64::isFmovSupported(-0.0)); + CHECK(!AssemblyBuilderA64::isFmovSupported(0.12389)); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOffsetSize") +{ + SINGLE_COMPARE(ldr(w0, mem(x1, 16)), 0xB9401020); + SINGLE_COMPARE(ldr(x0, mem(x1, 16)), 0xF9400820); + SINGLE_COMPARE(ldr(d0, mem(x1, 16)), 0xFD400820); + SINGLE_COMPARE(ldr(q0, mem(x1, 16)), 0x3DC00420); + + SINGLE_COMPARE(str(w0, mem(x1, 16)), 0xB9001020); + SINGLE_COMPARE(str(x0, mem(x1, 16)), 0xF9000820); + SINGLE_COMPARE(str(d0, mem(x1, 16)), 0xFD000820); + SINGLE_COMPARE(str(q0, mem(x1, 16)), 0x3D800420); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Conditionals") +{ + SINGLE_COMPARE(csel(x0, x1, x2, ConditionA64::Equal), 0x9A820020); + SINGLE_COMPARE(csel(w0, w1, w2, ConditionA64::Equal), 0x1A820020); + SINGLE_COMPARE(fcsel(d0, d1, d2, ConditionA64::Equal), 0x1E620C20); + + SINGLE_COMPARE(cset(x1, ConditionA64::Less), 0x9A9FA7E1); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Undefined") +{ + SINGLE_COMPARE(udf(), 0x00000000); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "PrePostIndexing") +{ + SINGLE_COMPARE(ldr(x0, mem(x1, 1)), 0xF8401020); + SINGLE_COMPARE(ldr(x0, mem(x1, 1, AddressKindA64::pre)), 0xF8401C20); + SINGLE_COMPARE(ldr(x0, mem(x1, 1, AddressKindA64::post)), 0xF8401420); + + SINGLE_COMPARE(ldr(q0, mem(x1, 1)), 0x3CC01020); + SINGLE_COMPARE(ldr(q0, mem(x1, 1, AddressKindA64::pre)), 0x3CC01C20); + SINGLE_COMPARE(ldr(q0, mem(x1, 1, AddressKindA64::post)), 0x3CC01420); + + SINGLE_COMPARE(str(x0, mem(x1, 1)), 0xF8001020); + SINGLE_COMPARE(str(x0, mem(x1, 1, AddressKindA64::pre)), 0xF8001C20); + SINGLE_COMPARE(str(x0, mem(x1, 1, AddressKindA64::post)), 0xF8001420); + + SINGLE_COMPARE(str(q0, mem(x1, 1)), 0x3C801020); + SINGLE_COMPARE(str(q0, mem(x1, 1, AddressKindA64::pre)), 0x3C801C20); + SINGLE_COMPARE(str(q0, mem(x1, 1, AddressKindA64::post)), 0x3C801420); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "SIMDMath") +{ + SINGLE_COMPARE(fadd(q0, q1, q2), 0x4E22D420); + SINGLE_COMPARE(fsub(q0, q1, q2), 0x4EA2D420); + SINGLE_COMPARE(fmul(q0, q1, q2), 0x6E22DC20); + SINGLE_COMPARE(fdiv(q0, q1, q2), 0x6E22FC20); + SINGLE_COMPARE(fneg(q0, q1), 0x6EA0F820); +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); @@ -229,6 +549,7 @@ TEST_CASE("LogTest") build.add(sp, sp, 4); build.add(w0, w1, w2); build.add(x0, x1, x2, 2); + build.add(x0, x1, x2, -2); build.add(w7, w8, 5); build.add(x7, x8, 5); build.ldr(x7, x8); @@ -243,6 +564,37 @@ TEST_CASE("LogTest") build.b(ConditionA64::Plus, l); build.cbz(x7, l); + build.ldp(x0, x1, mem(x8, 8)); + build.adr(x0, l); + + build.fabs(d1, d2); + build.ldr(q1, x2); + + build.csel(x0, x1, x2, ConditionA64::Equal); + build.cset(x0, ConditionA64::Equal); + + build.fcmp(d0, d1); + build.fcmpz(d0); + + build.fmov(d0, 0.25); + build.tbz(x0, 5, l); + + build.fcvt(s1, d2); + + build.ubfx(x1, x2, 37, 5); + + build.ldr(x0, mem(x1, 1)); + build.ldr(x0, mem(x1, 1, AddressKindA64::pre)); + build.ldr(x0, mem(x1, 1, AddressKindA64::post)); + + build.add(x1, x2, w3, 3); + + build.ins_4s(q29, w17, 3); + build.ins_4s(q31, 1, q29, 2); + build.dup_4s(s29, q31, 2); + build.dup_4s(q29, q30, 0); + build.fmul(q0, q1, q2); + build.setLabel(l); build.ret(); @@ -252,6 +604,7 @@ TEST_CASE("LogTest") add sp,sp,#4 add w0,w1,w2 add x0,x1,x2 LSL #2 + add x0,x1,x2 LSR #2 add w7,w8,#5 add x7,x8,#5 ldr x7,[x8] @@ -263,6 +616,27 @@ TEST_CASE("LogTest") blr x0 b.pl .L1 cbz x7,.L1 + ldp x0,x1,[x8,#8] + adr x0,.L1 + fabs d1,d2 + ldr q1,[x2] + csel x0,x1,x2,eq + cset x0,eq + fcmp d0,d1 + fcmp d0,#0 + fmov d0,#0.25 + tbz x0,#5,.L1 + fcvt s1,d2 + ubfx x1,x2,#3705 + ldr x0,[x1,#1] + ldr x0,[x1,#1]! + ldr x0,[x1]!,#1 + add x1,x2,w3 UXTW #3 + ins v29.s[3],w17 + ins v31.s[1],v29.s[2] + dup s29,v31.s[2] + dup v29.4s,v30.s[0] + fmul v0.4s,v1.4s,v2.4s .L1: ret )"; diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 6aa7aa561..655fa8f19 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -3,6 +3,7 @@ #include "Luau/StringUtils.h" #include "doctest.h" +#include "ScopedFlags.h" #include @@ -50,10 +51,12 @@ TEST_SUITE_BEGIN("x64Assembly"); #define SINGLE_COMPARE(inst, ...) \ CHECK(check( \ - [](AssemblyBuilderX64& build) { \ + [](AssemblyBuilderX64& build) \ + { \ build.inst; \ }, \ - {__VA_ARGS__})) + {__VA_ARGS__} \ + )) TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms") { @@ -67,6 +70,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms") SINGLE_COMPARE(add(rax, 0x7f), 0x48, 0x83, 0xc0, 0x7f); SINGLE_COMPARE(add(rax, 0x80), 0x48, 0x81, 0xc0, 0x80, 0x00, 0x00, 0x00); SINGLE_COMPARE(add(r10, 0x7fffffff), 0x49, 0x81, 0xc2, 0xff, 0xff, 0xff, 0x7f); + SINGLE_COMPARE(add(al, 3), 0x80, 0xc0, 0x03); + SINGLE_COMPARE(add(sil, 3), 0x48, 0x80, 0xc6, 0x03); + SINGLE_COMPARE(add(r11b, 3), 0x49, 0x80, 0xc3, 0x03); // reg, [reg] SINGLE_COMPARE(add(rax, qword[rax]), 0x48, 0x03, 0x00); @@ -191,6 +197,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMov") SINGLE_COMPARE(mov64(rcx, 0x1234567812345678ll), 0x48, 0xb9, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, 0x34, 0x12); SINGLE_COMPARE(mov(ecx, 2), 0xb9, 0x02, 0x00, 0x00, 0x00); SINGLE_COMPARE(mov(cl, 2), 0xb1, 0x02); + SINGLE_COMPARE(mov(sil, 2), 0x48, 0xb6, 0x02); + SINGLE_COMPARE(mov(r9b, 2), 0x49, 0xb1, 0x02); SINGLE_COMPARE(mov(rcx, qword[rdi]), 0x48, 0x8b, 0x0f); SINGLE_COMPARE(mov(dword[rax], 0xabcd), 0xc7, 0x00, 0xcd, 0xab, 0x00, 0x00); SINGLE_COMPARE(mov(r13, 1), 0x49, 0xbd, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); @@ -201,6 +209,13 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMov") SINGLE_COMPARE(mov(qword[rdx], r9), 0x4c, 0x89, 0x0a); SINGLE_COMPARE(mov(byte[rsi], 0x3), 0xc6, 0x06, 0x03); SINGLE_COMPARE(mov(byte[rsi], al), 0x88, 0x06); + SINGLE_COMPARE(mov(byte[rsi], dil), 0x48, 0x88, 0x3e); + SINGLE_COMPARE(mov(byte[rsi], r10b), 0x4c, 0x88, 0x16); + SINGLE_COMPARE(mov(wordReg(ebx), 0x3a3d), 0x66, 0xbb, 0x3d, 0x3a); + SINGLE_COMPARE(mov(word[rsi], 0x3a3d), 0x66, 0xc7, 0x06, 0x3d, 0x3a); + SINGLE_COMPARE(mov(word[rsi], wordReg(eax)), 0x66, 0x89, 0x06); + SINGLE_COMPARE(mov(word[rsi], wordReg(edi)), 0x66, 0x89, 0x3e); + SINGLE_COMPARE(mov(word[rsi], wordReg(r10)), 0x66, 0x44, 0x89, 0x16); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMovExtended") @@ -229,12 +244,18 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfShift") { SINGLE_COMPARE(shl(al, 1), 0xd0, 0xe0); SINGLE_COMPARE(shl(al, cl), 0xd2, 0xe0); + SINGLE_COMPARE(shl(sil, cl), 0x48, 0xd2, 0xe6); + SINGLE_COMPARE(shl(r10b, cl), 0x49, 0xd2, 0xe2); SINGLE_COMPARE(shr(al, 4), 0xc0, 0xe8, 0x04); SINGLE_COMPARE(shr(eax, 1), 0xd1, 0xe8); SINGLE_COMPARE(sal(eax, cl), 0xd3, 0xe0); SINGLE_COMPARE(sal(eax, 4), 0xc1, 0xe0, 0x04); SINGLE_COMPARE(sar(rax, 4), 0x48, 0xc1, 0xf8, 0x04); SINGLE_COMPARE(sar(r11, 1), 0x49, 0xd1, 0xfb); + SINGLE_COMPARE(rol(eax, 1), 0xd1, 0xc0); + SINGLE_COMPARE(rol(eax, cl), 0xd3, 0xc0); + SINGLE_COMPARE(ror(eax, 1), 0xd1, 0xc8); + SINGLE_COMPARE(ror(eax, cl), 0xd3, 0xc8); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") @@ -247,9 +268,18 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfSetcc") { SINGLE_COMPARE(setcc(ConditionX64::NotEqual, bl), 0x0f, 0x95, 0xc3); + SINGLE_COMPARE(setcc(ConditionX64::NotEqual, dil), 0x48, 0x0f, 0x95, 0xc7); SINGLE_COMPARE(setcc(ConditionX64::BelowEqual, byte[rcx]), 0x0f, 0x96, 0x01); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfCmov") +{ + SINGLE_COMPARE(cmov(ConditionX64::LessEqual, ebx, eax), 0x0f, 0x4e, 0xd8); + SINGLE_COMPARE(cmov(ConditionX64::NotZero, rbx, qword[rax]), 0x48, 0x0f, 0x45, 0x18); + SINGLE_COMPARE(cmov(ConditionX64::Zero, rbx, qword[rax + rcx]), 0x48, 0x0f, 0x44, 0x1c, 0x08); + SINGLE_COMPARE(cmov(ConditionX64::BelowEqual, r14d, r15d), 0x45, 0x0f, 0x46, 0xf7); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") { SINGLE_COMPARE(jmp(rax), 0xff, 0xe0); @@ -290,33 +320,41 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "NopForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentForms") { CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { build.ret(); build.align(8, AlignmentDataX64::Nop); }, - {0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00})); + {0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00} + )); CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { build.ret(); build.align(32, AlignmentDataX64::Nop); }, - {0xc3, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00})); + {0xc3, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00} + )); CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { build.ret(); build.align(8, AlignmentDataX64::Int3); }, - {0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc})); + {0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc} + )); CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { build.ret(); build.align(8, AlignmentDataX64::Ud2); }, - {0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc})); + {0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc} + )); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow") @@ -359,28 +397,33 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") { // Jump back CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { Label start = build.setLabel(); build.add(rsi, 1); build.cmp(rsi, rdi); build.jcc(ConditionX64::Equal, start); }, - {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff})); + {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff} + )); // Jump back, but the label is set before use CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { Label start; build.add(rsi, 1); build.setLabel(start); build.cmp(rsi, rdi); build.jcc(ConditionX64::Equal, start); }, - {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff})); + {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff} + )); // Jump forward CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { Label skip; build.cmp(rsi, rdi); @@ -388,24 +431,28 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") build.or_(rdi, 0x3e); build.setLabel(skip); }, - {0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e})); + {0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e} + )); // Regular jump CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { Label skip; build.jmp(skip); build.and_(rdi, 0x3e); build.setLabel(skip); }, - {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e})); + {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e} + )); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall") { CHECK(check( - [](AssemblyBuilderX64& build) { + [](AssemblyBuilderX64& build) + { Label fnB; build.and_(rcx, 0x3e); @@ -416,7 +463,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall") build.lea(rax, addr[rcx + 0x1f]); build.ret(); }, - {0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3})); + {0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3} + )); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") @@ -442,8 +490,13 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") SINGLE_COMPARE(vmulsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x59, 0xc6); SINGLE_COMPARE(vdivsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5e, 0xc6); + SINGLE_COMPARE(vsubps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x28, 0x5c, 0xc6); + SINGLE_COMPARE(vmulps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x28, 0x59, 0xc6); + SINGLE_COMPARE(vdivps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x28, 0x5e, 0xc6); + SINGLE_COMPARE(vorpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x56, 0xc6); SINGLE_COMPARE(vxorpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x57, 0xc6); + SINGLE_COMPARE(vorps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x28, 0x56, 0xc6); SINGLE_COMPARE(vandpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x54, 0xc6); SINGLE_COMPARE(vandnpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x55, 0xc6); @@ -507,20 +560,51 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, dword[rcx + rdx]), 0xc4, 0xe1, 0x23, 0x2a, 0x34, 0x11); SINGLE_COMPARE(vcvtsi2sd(xmm5, xmm10, r13), 0xc4, 0xc1, 0xab, 0x2a, 0xed); SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x2a, 0x34, 0x11); + SINGLE_COMPARE(vcvtsd2ss(xmm5, xmm10, xmm11), 0xc4, 0xc1, 0x2b, 0x5a, 0xeb); + SINGLE_COMPARE(vcvtsd2ss(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x5a, 0x34, 0x11); + SINGLE_COMPARE(vcvtss2sd(xmm3, xmm8, xmm12), 0xc4, 0xc1, 0x3a, 0x5a, 0xdc); + SINGLE_COMPARE(vcvtss2sd(xmm4, xmm9, dword[rcx + rsi]), 0xc4, 0xe1, 0x32, 0x5a, 0x24, 0x31); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") { SINGLE_COMPARE(vroundsd(xmm7, xmm12, xmm3, RoundingModeX64::RoundToNegativeInfinity), 0xc4, 0xe3, 0x19, 0x0b, 0xfb, 0x09); SINGLE_COMPARE( - vroundsd(xmm8, xmm13, xmmword[r13 + rdx], RoundingModeX64::RoundToPositiveInfinity), 0xc4, 0x43, 0x11, 0x0b, 0x44, 0x15, 0x00, 0x0a); + vroundsd(xmm8, xmm13, xmmword[r13 + rdx], RoundingModeX64::RoundToPositiveInfinity), 0xc4, 0x43, 0x11, 0x0b, 0x44, 0x15, 0x00, 0x0a + ); SINGLE_COMPARE(vroundsd(xmm9, xmm14, xmmword[rcx + r10], RoundingModeX64::RoundToZero), 0xc4, 0x23, 0x09, 0x0b, 0x0c, 0x11, 0x0b); SINGLE_COMPARE(vblendvpd(xmm7, xmm12, xmmword[rcx + r10], xmm5), 0xc4, 0xa3, 0x19, 0x4b, 0x3c, 0x11, 0x50); + + SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4); + SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") { SINGLE_COMPARE(int3(), 0xcc); + SINGLE_COMPARE(ud2(), 0x0f, 0x0b); + SINGLE_COMPARE(bsr(eax, edx), 0x0f, 0xbd, 0xc2); + SINGLE_COMPARE(bsf(eax, edx), 0x0f, 0xbc, 0xc2); + SINGLE_COMPARE(bswap(eax), 0x0f, 0xc8); + SINGLE_COMPARE(bswap(r12d), 0x41, 0x0f, 0xcc); + SINGLE_COMPARE(bswap(rax), 0x48, 0x0f, 0xc8); + SINGLE_COMPARE(bswap(r12), 0x49, 0x0f, 0xcc); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelLea") +{ + CHECK(check( + [](AssemblyBuilderX64& build) + { + Label fn; + build.lea(rax, fn); + build.ret(); + + build.setLabel(fn); + build.ret(); + }, + {0x48, 0x8d, 0x05, 0x01, 0x00, 0x00, 0x00, 0xc3, 0xc3} + )); } TEST_CASE("LogTest") @@ -542,6 +626,8 @@ TEST_CASE("LogTest") Label start = build.setLabel(); build.cmp(rsi, rdi); build.jcc(ConditionX64::Equal, start); + build.lea(rcx, start); + build.lea(rcx, addr[rdx]); build.jmp(qword[rdx]); build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]); @@ -556,6 +642,7 @@ TEST_CASE("LogTest") build.vroundsd(xmm1, xmm2, xmm3, RoundingModeX64::RoundToNearestEven); build.add(rdx, qword[rcx - 12]); build.pop(r12); + build.cmov(ConditionX64::AboveEqual, rax, rbx); build.ret(); build.int3(); @@ -586,6 +673,8 @@ TEST_CASE("LogTest") .L1: cmp rsi,rdi je .L1 + lea rcx,.L1 + lea rcx,[rdx] jmp qword ptr [rdx] vaddps ymm9,ymm12,ymmword ptr [rbp+0Ch] vaddpd ymm2,ymm7,qword ptr [.start-8] @@ -599,6 +688,7 @@ TEST_CASE("LogTest") vroundsd xmm1,xmm2,xmm3,8 add rdx,qword ptr [rcx-0Ch] pop r12 + cmovae rax,rbx ret int3 nop @@ -662,19 +752,51 @@ TEST_CASE("ConstantStorage") AssemblyBuilderX64 build(/* logText= */ false); for (int i = 0; i <= 3000; i++) - build.vaddss(xmm0, xmm0, build.f32(1.0f)); + build.vaddss(xmm0, xmm0, build.i32(i)); build.finalize(); - LUAU_ASSERT(build.data.size() == 12004); + CHECK(build.data.size() == 12004); for (int i = 0; i <= 3000; i++) { - LUAU_ASSERT(build.data[i * 4 + 0] == 0x00); - LUAU_ASSERT(build.data[i * 4 + 1] == 0x00); - LUAU_ASSERT(build.data[i * 4 + 2] == 0x80); - LUAU_ASSERT(build.data[i * 4 + 3] == 0x3f); + CHECK(build.data[i * 4 + 0] == ((3000 - i) & 0xff)); + CHECK(build.data[i * 4 + 1] == ((3000 - i) >> 8)); + CHECK(build.data[i * 4 + 2] == 0x00); + CHECK(build.data[i * 4 + 3] == 0x00); } } +TEST_CASE("ConstantStorageDedup") +{ + AssemblyBuilderX64 build(/* logText= */ false); + + for (int i = 0; i <= 3000; i++) + build.vaddss(xmm0, xmm0, build.f32(1.0f)); + + build.finalize(); + + CHECK(build.data.size() == 4); + + CHECK(build.data[0] == 0x00); + CHECK(build.data[1] == 0x00); + CHECK(build.data[2] == 0x80); + CHECK(build.data[3] == 0x3f); +} + +TEST_CASE("ConstantCaching") +{ + AssemblyBuilderX64 build(/* logText= */ false); + + OperandX64 two = build.f64(2); + + // Force data relocation + for (int i = 0; i < 4096; i++) + build.f64(i); + + CHECK(build.f64(2).imm == two.imm); + + build.finalize(); +} + TEST_SUITE_END(); diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index a0127eef7..e170e9bc4 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -103,9 +103,11 @@ TEST_CASE("encode_AstStatBlock") AstStatBlock block{Location(), bodyArray}; - CHECK_EQ( - (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":[{"luauType":null,"name":"a_local","type":"AstLocal","location":"0,0 - 0,0"}],"values":[]}]})"), - toJson(&block)); + CHECK( + toJson(&block) == + (R"({"type":"AstStatBlock","location":"0,0 - 0,0","hasEnd":true,"body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":[{"luauType":null,"name":"a_local","type":"AstLocal","location":"0,0 - 0,0"}],"values":[]}]})" + ) + ); } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_tables") @@ -123,7 +125,8 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_tables") CHECK( json == - R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"luauType":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","type":"AstTableProp","location":"2,12 - 2,15","propType":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":null},"name":"x","type":"AstLocal","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"type":"AstExprTableItem","kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); + R"({"type":"AstStatBlock","location":"0,0 - 6,4","hasEnd":true,"body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"luauType":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","type":"AstTableProp","location":"2,12 - 2,15","propType":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","nameLocation":"2,17 - 2,23","parameters":[]}}],"indexer":null},"name":"x","type":"AstLocal","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"type":"AstExprTableItem","kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})" + ); } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_array") @@ -135,7 +138,8 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_array") CHECK( json == - R"({"type":"AstStatBlock","location":"0,0 - 0,17","body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","parameters":[]}}},"exported":false}]})"); + R"({"type":"AstStatBlock","location":"0,0 - 0,17","hasEnd":true,"body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","nameLocation":"0,10 - 0,16","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","nameLocation":"0,10 - 0,16","parameters":[]}}},"exported":false}]})" + ); } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_indexer") @@ -147,7 +151,8 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_indexer") CHECK( json == - R"({"type":"AstStatBlock","location":"0,0 - 0,17","body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","parameters":[]}}},"exported":false}]})"); + R"({"type":"AstStatBlock","location":"0,0 - 0,17","hasEnd":true,"body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","nameLocation":"0,10 - 0,16","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","nameLocation":"0,10 - 0,16","parameters":[]}}},"exported":false}]})" + ); } TEST_CASE("encode_AstExprGroup") @@ -198,8 +203,10 @@ TEST_CASE("encode_AstExprLocal") AstLocal local{AstName{"foo"}, Location{}, nullptr, 0, 0, nullptr}; AstExprLocal exprLocal{Location{}, &local, false}; - CHECK(toJson(&exprLocal) == - R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"luauType":null,"name":"foo","type":"AstLocal","location":"0,0 - 0,0"}})"); + CHECK( + toJson(&exprLocal) == + R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"luauType":null,"name":"foo","type":"AstLocal","location":"0,0 - 0,0"}})" + ); } TEST_CASE("encode_AstExprVarargs") @@ -243,7 +250,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprFunction") AstExpr* expr = expectParseExpr("function (a) return a end"); std::string_view expected = - R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":"","hasEnd":true})"; + R"({"type":"AstExprFunction","location":"0,4 - 0,29","attributes":[],"generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":""})"; CHECK(toJson(expr) == expected); } @@ -283,7 +290,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTypeAssertion") AstExpr* expr = expectParseExpr("b :: any"); std::string_view expected = - R"({"type":"AstExprTypeAssertion","location":"0,4 - 0,12","expr":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"annotation":{"type":"AstTypeReference","location":"0,9 - 0,12","name":"any","parameters":[]}})"; + R"({"type":"AstExprTypeAssertion","location":"0,4 - 0,12","expr":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"annotation":{"type":"AstTypeReference","location":"0,9 - 0,12","name":"any","nameLocation":"0,9 - 0,12","parameters":[]}})"; CHECK(toJson(expr) == expected); } @@ -311,7 +318,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatIf") AstStat* statement = expectParseStatement("if true then else end"); std::string_view expected = - R"({"type":"AstStatIf","location":"0,0 - 0,21","condition":{"type":"AstExprConstantBool","location":"0,3 - 0,7","value":true},"thenbody":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"elsebody":{"type":"AstStatBlock","location":"0,17 - 0,18","body":[]},"hasThen":true,"hasEnd":true})"; + R"({"type":"AstStatIf","location":"0,0 - 0,21","condition":{"type":"AstExprConstantBool","location":"0,3 - 0,7","value":true},"thenbody":{"type":"AstStatBlock","location":"0,12 - 0,13","hasEnd":true,"body":[]},"elsebody":{"type":"AstStatBlock","location":"0,17 - 0,18","hasEnd":true,"body":[]},"hasThen":true})"; CHECK(toJson(statement) == expected); } @@ -321,7 +328,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatWhile") AstStat* statement = expectParseStatement("while true do end"); std::string_view expected = - R"({"type":"AstStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","hasEnd":true,"body":[]},"hasDo":true})"; CHECK(toJson(statement) == expected); } @@ -331,7 +338,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatRepeat") AstStat* statement = expectParseStatement("repeat until true"); std::string_view expected = - R"({"type":"AstStatRepeat","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,13 - 0,17","value":true},"body":{"type":"AstStatBlock","location":"0,6 - 0,7","body":[]},"hasUntil":true})"; + R"({"type":"AstStatRepeat","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,13 - 0,17","value":true},"body":{"type":"AstStatBlock","location":"0,6 - 0,7","hasEnd":true,"body":[]}})"; CHECK(toJson(statement) == expected); } @@ -341,7 +348,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatBreak") AstStat* statement = expectParseStatement("while true do break end"); std::string_view expected = - R"({"type":"AstStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","hasEnd":true,"body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true})"; CHECK(toJson(statement) == expected); } @@ -351,7 +358,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatContinue") AstStat* statement = expectParseStatement("while true do continue end"); std::string_view expected = - R"({"type":"AstStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","hasEnd":true,"body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true})"; CHECK(toJson(statement) == expected); } @@ -361,7 +368,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatFor") AstStat* statement = expectParseStatement("for a=0,1 do end"); std::string_view expected = - R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","hasEnd":true,"body":[]},"hasDo":true})"; CHECK(toJson(statement) == expected); } @@ -371,7 +378,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatForIn") AstStat* statement = expectParseStatement("for a in b do end"); std::string_view expected = - R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasIn":true,"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","hasEnd":true,"body":[]},"hasIn":true,"hasDo":true})"; CHECK(toJson(statement) == expected); } @@ -391,7 +398,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatLocalFunction") AstStat* statement = expectParseStatement("local function a(b) return end"); std::string_view expected = - R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"luauType":null,"name":"a","type":"AstLocal","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a","hasEnd":true}})"; + R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"luauType":null,"name":"a","type":"AstLocal","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","attributes":[],"generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a"}})"; CHECK(toJson(statement) == expected); } @@ -401,7 +408,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias") AstStat* statement = expectParseStatement("type A = B"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"type":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","parameters":[]},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"value":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","nameLocation":"0,9 - 0,10","parameters":[]},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -411,11 +418,31 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") AstStat* statement = expectParseStatement("declare function foo(x: number): string"); std::string_view expected = - R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","parameters":[]}]},"retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","parameters":[]}]},"generics":[],"genericPacks":[]})"; + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","attributes":[],"name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}]},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":false,"varargLocation":"0,0 - 0,0","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","nameLocation":"0,33 - 0,39","parameters":[]}]},"generics":[],"genericPacks":[]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction2") +{ + AstStat* statement = expectParseStatement("declare function foo(x: number, ...: string): string"); + + std::string_view expected = + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,52","attributes":[],"name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}],"tailType":{"type":"AstTypePackVariadic","location":"0,37 - 0,43","variadicType":{"type":"AstTypeReference","location":"0,37 - 0,43","name":"string","nameLocation":"0,37 - 0,43","parameters":[]}}},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":true,"varargLocation":"0,32 - 0,35","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,46 - 0,52","name":"string","nameLocation":"0,46 - 0,52","parameters":[]}]},"generics":[],"genericPacks":[]})"; CHECK(toJson(statement) == expected); } +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstAttr") +{ + AstStat* expr = expectParseStatement("@checked function a(b) return c end"); + + std::string_view expected = + R"({"type":"AstStatFunction","location":"0,9 - 0,35","name":{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"a"},"func":{"type":"AstExprFunction","location":"0,9 - 0,35","attributes":[{"type":"AstAttr","location":"0,0 - 0,8","name":"checked"}],"generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,20 - 0,21"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,22 - 0,32","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,23 - 0,31","list":[{"type":"AstExprGlobal","location":"0,30 - 0,31","global":"c"}]}]},"functionDepth":1,"debugname":"a"}})"; + + CHECK(toJson(expr) == expected); +} + TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") { AstStatBlock* root = expectParse(R"( @@ -432,11 +459,11 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") REQUIRE(2 == root->body.size); std::string_view expected1 = - R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","parameters":[]}]}}}]})"; + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","nameLocation":"2,12 - 2,16","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]},"location":"2,12 - 2,24"},{"name":"method","nameLocation":"3,21 - 3,27","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,12 - 3,54","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"argNames":[{"type":"AstArgumentName","name":"foo","location":"3,34 - 3,37"}],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}},"location":"3,12 - 3,54"}],"indexer":null})"; CHECK(toJson(root->body.data[0]) == expected1); std::string_view expected2 = - R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","parameters":[]}}]})"; + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","nameLocation":"7,12 - 7,17","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]},"location":"7,12 - 7,25"}],"indexer":null})"; CHECK(toJson(root->body.data[1]) == expected2); } @@ -445,7 +472,39 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_type_literal") +{ + AstStat* statement = expectParseStatement(R"(type Action = { strings: "A" | "B" | "C", mixed: "This" | "That" | true })"); + + auto json = toJson(statement); + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,73","name":"Action","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,14 - 0,73","props":[{"name":"strings","type":"AstTableProp","location":"0,16 - 0,23","propType":{"type":"AstTypeUnion","location":"0,25 - 0,40","types":[{"type":"AstTypeSingletonString","location":"0,25 - 0,28","value":"A"},{"type":"AstTypeSingletonString","location":"0,31 - 0,34","value":"B"},{"type":"AstTypeSingletonString","location":"0,37 - 0,40","value":"C"}]}},{"name":"mixed","type":"AstTableProp","location":"0,42 - 0,47","propType":{"type":"AstTypeUnion","location":"0,49 - 0,71","types":[{"type":"AstTypeSingletonString","location":"0,49 - 0,55","value":"This"},{"type":"AstTypeSingletonString","location":"0,58 - 0,64","value":"That"},{"type":"AstTypeSingletonBool","location":"0,67 - 0,71","value":true}]}}],"indexer":null},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_indexed_type_literal") +{ + AstStat* statement = expectParseStatement(R"(type StringSet = { [string]: true })"); + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,35","name":"StringSet","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,17 - 0,35","props":[],"indexer":{"location":"0,19 - 0,33","indexType":{"type":"AstTypeReference","location":"0,20 - 0,26","name":"string","nameLocation":"0,20 - 0,26","parameters":[]},"resultType":{"type":"AstTypeSingletonBool","location":"0,29 - 0,33","value":true}}},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeFunction") +{ + AstStat* statement = expectParseStatement(R"(type fun = (string, bool, named: number) -> ())"); + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,46","name":"fun","generics":[],"genericPacks":[],"value":{"type":"AstTypeFunction","location":"0,11 - 0,46","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,12 - 0,18","name":"string","nameLocation":"0,12 - 0,18","parameters":[]},{"type":"AstTypeReference","location":"0,20 - 0,24","name":"bool","nameLocation":"0,20 - 0,24","parameters":[]},{"type":"AstTypeReference","location":"0,33 - 0,39","name":"number","nameLocation":"0,33 - 0,39","parameters":[]}]},"argNames":[null,null,{"type":"AstArgumentName","name":"named","location":"0,26 - 0,31"}],"returnTypes":{"type":"AstTypeList","types":[]}},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -458,7 +517,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeError") AstStat* statement = parseResult.root->body.data[0]; std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -473,7 +532,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypePackExplicit") CHECK(2 == root->body.size); std::string_view expected = - R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"luauType":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","parameters":[]}]}}]},"name":"a","type":"AstLocal","location":"2,14 - 2,15"}],"values":[]})"; + R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"luauType":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","nameLocation":"2,17 - 2,18","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","nameLocation":"2,20 - 2,26","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","nameLocation":"2,28 - 2,34","parameters":[]}]}}]},"name":"a","type":"AstLocal","location":"2,14 - 2,15"}],"values":[]})"; CHECK(toJson(root->body.data[1]) == expected); } diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index a642334af..5fc51b39a 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauDocumentationAtPosition) + struct DocumentationSymbolFixture : BuiltinsFixture { std::optional getDocSymbol(const std::string& source, Position position) @@ -25,20 +27,24 @@ TEST_SUITE_BEGIN("AstQuery::getDocumentationSymbolAtPosition"); TEST_CASE_FIXTURE(DocumentationSymbolFixture, "binding") { - std::optional global = getDocSymbol(R"( + std::optional global = getDocSymbol( + R"( local a = string.sub() )", - Position(1, 21)); + Position(1, 21) + ); CHECK_EQ(global, "@luau/global/string"); } TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") { - std::optional substring = getDocSymbol(R"( + std::optional substring = getDocSymbol( + R"( local a = string.sub() )", - Position(1, 27)); + Position(1, 27) + ); CHECK_EQ(substring, "@luau/global/string.sub"); } @@ -49,11 +55,13 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") declare function Connect(fn: (string) -> ()) )"); - std::optional substring = getDocSymbol(R"( + std::optional substring = getDocSymbol( + R"( Connect(function(abc) end) )", - Position(1, 27)); + Position(1, 27) + ); CHECK_EQ(substring, "@test/global/Connect/param/0/param/0"); } @@ -64,10 +72,12 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") declare foo: ((string) -> number) & ((number) -> string) )"); - std::optional symbol = getDocSymbol(R"( + std::optional symbol = getDocSymbol( + R"( foo("asdf") )", - Position(1, 10)); + Position(1, 10) + ); CHECK_EQ(symbol, "@test/global/foo/overload/(string) -> number"); } @@ -78,13 +88,19 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "class_method") declare class Foo function bar(self, x: string): number end + + declare Foo: { + new: () -> Foo + } )"); - std::optional symbol = getDocSymbol(R"( - local x: Foo + std::optional symbol = getDocSymbol( + R"( + local x: Foo = Foo.new() x:bar("asdf") )", - Position(2, 11)); + Position(2, 11) + ); CHECK_EQ(symbol, "@test/globaltype/Foo.bar"); } @@ -96,13 +112,19 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_class_method") function bar(self, x: string): number function bar(self, x: number): string end + + declare Foo: { + new: () -> Foo + } )"); - std::optional symbol = getDocSymbol(R"( - local x: Foo + std::optional symbol = getDocSymbol( + R"( + local x: Foo = Foo.new() x:bar("asdf") )", - Position(2, 11)); + Position(2, 11) + ); CHECK_EQ(symbol, "@test/globaltype/Foo.bar/overload/(Foo, string) -> number"); } @@ -115,10 +137,12 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_function_prop") } )"); - std::optional symbol = getDocSymbol(R"( + std::optional symbol = getDocSymbol( + R"( Foo.new("asdf") )", - Position(1, 13)); + Position(1, 13) + ); CHECK_EQ(symbol, "@test/global/Foo.new"); } @@ -131,20 +155,62 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop") } )"); - std::optional symbol = getDocSymbol(R"( + std::optional symbol = getDocSymbol( + R"( Foo.new("asdf") )", - Position(1, 13)); + Position(1, 13) + ); CHECK_EQ(symbol, "@test/global/Foo.new/overload/(string) -> number"); } +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "string_metatable_method") +{ + ScopedFastFlag sff{FFlag::LuauDocumentationAtPosition, true}; + std::optional symbol = getDocSymbol( + R"( + local x: string = "Foo" + x:rep(2) + )", + Position(2, 12) + ); + + CHECK_EQ(symbol, "@luau/global/string.rep"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "parent_class_method") +{ + ScopedFastFlag sff{FFlag::LuauDocumentationAtPosition, true}; + loadDefinition(R"( + declare class Foo + function bar(self, x: string): number + end + + declare class Bar extends Foo + function notbar(self, x: string): number + end + )"); + + std::optional symbol = getDocSymbol( + R"( + local x: Bar = Bar.new() + x:bar("asdf") + )", + Position(2, 11) + ); + + CHECK_EQ(symbol, "@test/globaltype/Foo.bar"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("AstQuery"); TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + check(R"( local function foo() return 2 end local function bar(a: number) return -a end @@ -282,4 +348,87 @@ TEST_CASE_FIXTURE(Fixture, "Luau_selectively_query_for_a_different_boolean_2") REQUIRE(snd->value == true); } +TEST_CASE_FIXTURE(Fixture, "include_types_ancestry") +{ + check("local x: number = 4;"); + const Position pos(0, 10); + + std::vector ancestryNoTypes = findAstAncestryOfPosition(*getMainSourceModule(), pos); + std::vector ancestryTypes = findAstAncestryOfPosition(*getMainSourceModule(), pos, true); + + CHECK(ancestryTypes.size() > ancestryNoTypes.size()); + CHECK(!ancestryNoTypes.back()->asType()); + CHECK(ancestryTypes.back()->asType()); +} + +TEST_CASE_FIXTURE(Fixture, "find_name_ancestry") +{ + check(R"( + local tbl = {} + function tbl:abc() end + )"); + const Position pos(2, 18); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), pos); + + REQUIRE(!ancestry.empty()); + CHECK(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "find_expr_ancestry") +{ + check(R"( + local tbl = {} + function tbl:abc() end + )"); + const Position pos(2, 29); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), pos); + + REQUIRE(!ancestry.empty()); + CHECK(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_binding_at_position_global_start_of_file") +{ + + check("local x = string.char(1)"); + const Position pos(0, 12); + + std::optional binding = findBindingAtPosition(*getMainModule(), *getMainSourceModule(), pos); + + REQUIRE(binding); + CHECK_EQ(binding->location, Location{Position{0, 0}, Position{0, 0}}); +} + +TEST_CASE_FIXTURE(Fixture, "interior_binding_location_is_consistent_with_exterior_binding") +{ + CheckResult result = check(R"( + local function abcd(arg) + abcd(arg) + end + + abcd(0) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // FIXME CLI-114385: findBindingByPosition does not properly handle AstStatLocalFunction. + + // std::optional declBinding = findBindingAtPosition(*getMainModule(), *getMainSourceModule(), {1, 26}); + // REQUIRE(declBinding); + + // CHECK(declBinding->location == Location{{1, 25}, {1, 28}}); + + std::optional innerCallBinding = findBindingAtPosition(*getMainModule(), *getMainSourceModule(), {2, 15}); + REQUIRE(innerCallBinding); + + CHECK(innerCallBinding->location == Location{{1, 23}, {1, 27}}); + + std::optional outerCallBinding = findBindingAtPosition(*getMainModule(), *getMainSourceModule(), {5, 8}); + REQUIRE(outerCallBinding); + + CHECK(outerCallBinding->location == Location{{1, 23}, {1, 27}}); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 85bd55077..de4049a97 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,8 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauFixAutocompleteInWhile) -LUAU_FASTFLAG(LuauFixAutocompleteInFor) +LUAU_FASTINT(LuauTypeInferRecursionLimit) using namespace Luau; @@ -29,20 +28,40 @@ template struct ACFixtureImpl : BaseType { ACFixtureImpl() - : BaseType(true, true) + : BaseType(true) { } AutocompleteResult autocomplete(unsigned row, unsigned column) { + FrontendOptions opts; + opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; + this->frontend.check("MainModule", opts); + return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker, StringCompletionCallback callback = nullCallback) { + FrontendOptions opts; + opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; + this->frontend.check("MainModule", opts); + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); } + AutocompleteResult autocomplete(const ModuleName& name, Position pos, StringCompletionCallback callback = nullCallback) + { + FrontendOptions opts; + opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; + this->frontend.check(name, opts); + + return Luau::autocomplete(this->frontend, name, pos, callback); + } + CheckResult check(const std::string& source) { markerPosition.clear(); @@ -55,7 +74,7 @@ struct ACFixtureImpl : BaseType { if (prevChar == '@') { - LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Illegal marker character" && ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z'))); LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); markerPosition.insert(std::pair{c, curPos}); } @@ -85,10 +104,22 @@ struct ACFixtureImpl : BaseType LoadDefinitionFileResult loadDefinition(const std::string& source) { - TypeChecker& typeChecker = this->frontend.typeCheckerForAutocomplete; - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); - freeze(typeChecker.globalTypes); + GlobalTypes& globals = this->frontend.globalsForAutocomplete; + unfreeze(globals.globalTypes); + LoadDefinitionFileResult result = this->frontend.loadDefinitionFile( + globals, globals.globalScope, source, "@test", /* captureComments */ false, /* typeCheckForAutocomplete */ true + ); + freeze(globals.globalTypes); + + if (FFlag::LuauSolverV2) + { + GlobalTypes& globals = this->frontend.globals; + unfreeze(globals.globalTypes); + LoadDefinitionFileResult result = this->frontend.loadDefinitionFile( + globals, globals.globalScope, source, "@test", /* captureComments */ false, /* typeCheckForAutocomplete */ true + ); + freeze(globals.globalTypes); + } REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); return result; @@ -100,7 +131,6 @@ struct ACFixtureImpl : BaseType LUAU_ASSERT(i != markerPosition.end()); return i->second; } - // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; @@ -110,10 +140,10 @@ struct ACFixture : ACFixtureImpl ACFixture() : ACFixtureImpl() { - addGlobalBinding(frontend, "table", Binding{typeChecker.anyType}); - addGlobalBinding(frontend, "math", Binding{typeChecker.anyType}); - addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); - addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "math", Binding{builtinTypes->anyType}); } }; @@ -121,6 +151,40 @@ struct ACBuiltinsFixture : ACFixtureImpl { }; +#define LUAU_CHECK_HAS_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ + if (!count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + +#define LUAU_CHECK_HAS_NO_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ + if (count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -178,7 +242,7 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); - CHECK(!ac.entryMap.count("myInnerLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); ac = autocomplete('2'); CHECK(ac.entryMap.count("myLocal")); @@ -186,7 +250,7 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") ac = autocomplete('3'); CHECK(ac.entryMap.count("myLocal")); - CHECK(!ac.entryMap.count("myInnerLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); } TEST_CASE_FIXTURE(ACFixture, "recursive_function") @@ -273,7 +337,7 @@ TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); - CHECK(!ac.entryMap.count("abc")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "abc"); } TEST_CASE_FIXTURE(ACFixture, "function_parameters") @@ -300,7 +364,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_member_completions") CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); - CHECK(!ac.entryMap.count("math")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "math"); CHECK_EQ(ac.context, AutocompleteContext::Property); } @@ -446,7 +510,7 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") CHECK_NE(0, ac.entryMap.size()); - CHECK(!ac.entryMap.count("math")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "math"); CHECK_EQ(ac.context, AutocompleteContext::Property); } @@ -460,7 +524,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "method_call_inside_if_conditional") CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("concat")); - CHECK(!ac.entryMap.count("math")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "math"); CHECK_EQ(ac.context, AutocompleteContext::Property); } @@ -630,19 +694,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") )"); auto ac5 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInFor) - { - CHECK_EQ(ac5.entryMap.count("math"), 1); - CHECK_EQ(ac5.entryMap.count("do"), 0); - CHECK_EQ(ac5.entryMap.count("end"), 0); - CHECK_EQ(ac5.context, AutocompleteContext::Expression); - } - else - { - CHECK_EQ(ac5.entryMap.count("do"), 1); - CHECK_EQ(ac5.entryMap.count("end"), 0); - CHECK_EQ(ac5.context, AutocompleteContext::Keyword); - } + CHECK_EQ(ac5.entryMap.count("math"), 1); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("end"), 0); + CHECK_EQ(ac5.context, AutocompleteContext::Expression); check(R"( for x = 1, 2, 5 f@1 @@ -661,29 +716,26 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") CHECK_EQ(ac7.entryMap.count("end"), 1); CHECK_EQ(ac7.context, AutocompleteContext::Statement); - if (FFlag::LuauFixAutocompleteInFor) - { - check(R"(local Foo = 1 - for x = @11, @22, @35 - )"); + check(R"(local Foo = 1 + for x = @11, @22, @35 + )"); - for (int i = 0; i < 3; ++i) - { - auto ac8 = autocomplete('1' + i); - CHECK_EQ(ac8.entryMap.count("Foo"), 1); - CHECK_EQ(ac8.entryMap.count("do"), 0); - } + for (int i = 0; i < 3; ++i) + { + auto ac8 = autocomplete('1' + i); + CHECK_EQ(ac8.entryMap.count("Foo"), 1); + CHECK_EQ(ac8.entryMap.count("do"), 0); + } - check(R"(local Foo = 1 - for x = @11, @22 - )"); + check(R"(local Foo = 1 + for x = @11, @22 + )"); - for (int i = 0; i < 2; ++i) - { - auto ac9 = autocomplete('1' + i); - CHECK_EQ(ac9.entryMap.count("Foo"), 1); - CHECK_EQ(ac9.entryMap.count("do"), 0); - } + for (int i = 0; i < 2; ++i) + { + auto ac9 = autocomplete('1' + i); + CHECK_EQ(ac9.entryMap.count("Foo"), 1); + CHECK_EQ(ac9.entryMap.count("do"), 0); } } @@ -776,18 +828,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac2 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInWhile) - { - CHECK_EQ(3, ac2.entryMap.size()); - CHECK_EQ(ac2.entryMap.count("do"), 1); - CHECK_EQ(ac2.entryMap.count("and"), 1); - CHECK_EQ(ac2.entryMap.count("or"), 1); - } - else - { - CHECK_EQ(1, ac2.entryMap.size()); - CHECK_EQ(ac2.entryMap.count("do"), 1); - } + CHECK_EQ(3, ac2.entryMap.size()); + CHECK_EQ(ac2.entryMap.count("do"), 1); + CHECK_EQ(ac2.entryMap.count("and"), 1); + CHECK_EQ(ac2.entryMap.count("or"), 1); CHECK_EQ(ac2.context, AutocompleteContext::Keyword); check(R"( @@ -803,31 +847,20 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac4 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInWhile) - { - CHECK_EQ(3, ac4.entryMap.size()); - CHECK_EQ(ac4.entryMap.count("do"), 1); - CHECK_EQ(ac4.entryMap.count("and"), 1); - CHECK_EQ(ac4.entryMap.count("or"), 1); - } - else - { - CHECK_EQ(1, ac4.entryMap.size()); - CHECK_EQ(ac4.entryMap.count("do"), 1); - } + CHECK_EQ(3, ac4.entryMap.size()); + CHECK_EQ(ac4.entryMap.count("do"), 1); + CHECK_EQ(ac4.entryMap.count("and"), 1); + CHECK_EQ(ac4.entryMap.count("or"), 1); CHECK_EQ(ac4.context, AutocompleteContext::Keyword); - if (FFlag::LuauFixAutocompleteInWhile) - { - check(R"( - while t@1 - )"); + check(R"( + while t@1 + )"); - auto ac5 = autocomplete('1'); - CHECK_EQ(ac5.entryMap.count("do"), 0); - CHECK_EQ(ac5.entryMap.count("true"), 1); - CHECK_EQ(ac5.entryMap.count("false"), 1); - } + auto ac5 = autocomplete('1'); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("true"), 1); + CHECK_EQ(ac5.entryMap.count("false"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") @@ -838,8 +871,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("then"), 0); - CHECK_EQ(ac1.entryMap.count("function"), - 1); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. + CHECK_EQ( + ac1.entryMap.count("function"), + 1 + ); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. CHECK_EQ(ac1.entryMap.count("table"), 1); CHECK_EQ(ac1.entryMap.count("else"), 0); CHECK_EQ(ac1.entryMap.count("elseif"), 0); @@ -992,6 +1027,31 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") CHECK_EQ(ac.context, AutocompleteContext::Statement); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_of_do_block") +{ + check("do @1"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("end")); + + check(R"( + function f() + do + @1 + end + @2 + )"); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("end")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("end")); +} + TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") { check(R"( @@ -1311,7 +1371,7 @@ local a: nu@3 ac = autocomplete('3'); - CHECK(!ac.entryMap.count("num")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "num"); CHECK(ac.entryMap.count("number")); } @@ -1351,7 +1411,7 @@ local a: aa frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 11}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 11}); CHECK(ac.entryMap.count("aaa")); CHECK_EQ(ac.context, AutocompleteContext::Type); @@ -1374,7 +1434,7 @@ local a: aaa. frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 13}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 13}); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("A")); @@ -1548,6 +1608,9 @@ return target(a.@1 TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_table") { + if (FFlag::LuauSolverV2) // CLI-116815 Autocomplete cannot suggest keys while autocompleting inside of a table + return; + check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } @@ -2031,9 +2094,9 @@ ex.a(function(x: frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 16}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 16}); - CHECK(!ac.entryMap.count("done")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "done"); fileResolver.source["Module/C"] = R"( local ex = require(script.Parent.A) @@ -2042,9 +2105,9 @@ ex.b(function(x: frontend.check("Module/C"); - ac = Luau::autocomplete(frontend, "Module/C", Position{2, 16}, nullCallback); + ac = autocomplete("Module/C", Position{2, 16}); - CHECK(!ac.entryMap.count("(done) -> number")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "(done) -> number"); } TEST_CASE_FIXTURE(ACBuiltinsFixture, "suggest_external_module_type") @@ -2065,9 +2128,9 @@ ex.a(function(x: frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 16}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 16}); - CHECK(!ac.entryMap.count("done")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "done"); CHECK(ac.entryMap.count("ex.done")); CHECK(ac.entryMap["ex.done"].typeCorrect == TypeCorrectKind::Correct); @@ -2078,9 +2141,9 @@ ex.b(function(x: frontend.check("Module/C"); - ac = Luau::autocomplete(frontend, "Module/C", Position{2, 16}, nullCallback); + ac = autocomplete("Module/C", Position{2, 16}); - CHECK(!ac.entryMap.count("(done) -> number")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "(done) -> number"); CHECK(ac.entryMap.count("(ex.done) -> number")); CHECK(ac.entryMap["(ex.done) -> number"].typeCorrect == TypeCorrectKind::Correct); } @@ -2094,7 +2157,7 @@ local bar: @1= foo auto ac = autocomplete('1'); - CHECK(!ac.entryMap.count("foo")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "foo"); } TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") @@ -2151,7 +2214,10 @@ local fp: @1= f auto ac = autocomplete('1'); - REQUIRE_EQ("({| x: number, y: number |}) -> number", toString(requireType("f"))); + if (FFlag::LuauSolverV2) + REQUIRE_EQ("({ x: number, y: number }) -> number", toString(requireType("f"))); + else + REQUIRE_EQ("({| x: number, y: number |}) -> number", toString(requireType("f"))); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } @@ -2199,6 +2265,9 @@ local ec = e(f@5) TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_for_overloads") { + if (FFlag::LuauSolverV2) // CLI-116814 Autocomplete needs to populate expected types for function arguments correctly + // (overloads and singletons) + return; check(R"( local target: ((number) -> string) & ((string) -> number)) @@ -2316,7 +2385,7 @@ local name = na@1 auto ac = autocomplete('1'); - CHECK(!ac.entryMap.count("name")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "name"); CHECK(ac.entryMap.count("other")); check(R"( @@ -2326,8 +2395,8 @@ local name, test = na@1 ac = autocomplete('1'); - CHECK(!ac.entryMap.count("name")); - CHECK(!ac.entryMap.count("test")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "name"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "test"); CHECK(ac.entryMap.count("other")); } @@ -2392,7 +2461,7 @@ local a: aaa.do frontend.check("Module/B"); - auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 15}, nullCallback); + auto ac = autocomplete("Module/B", Position{2, 15}); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2404,7 +2473,7 @@ TEST_CASE_FIXTURE(ACFixture, "comments") { fileResolver.source["Comments"] = "--!str"; - auto ac = Luau::autocomplete(frontend, "Comments", Position{0, 6}, nullCallback); + auto ac = autocomplete("Comments", Position{0, 6}); CHECK_EQ(0, ac.entryMap.size()); } @@ -2423,7 +2492,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod -- | Column 20 )"; - auto ac = Luau::autocomplete(frontend, "Module/A", Position{9, 20}, nullCallback); + auto ac = autocomplete("Module/A", Position{9, 20}); REQUIRE_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); } @@ -2516,8 +2585,8 @@ TEST_CASE_FIXTURE(ACFixture, "not_the_var_we_are_defining") { fileResolver.source["Module/A"] = "abc,de"; - auto ac = Luau::autocomplete(frontend, "Module/A", Position{0, 6}, nullCallback); - CHECK(!ac.entryMap.count("de")); + auto ac = autocomplete("Module/A", Position{0, 6}); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "de"); } TEST_CASE_FIXTURE(ACFixture, "recursive_function_global") @@ -2527,7 +2596,7 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function_global") end )"; - auto ac = Luau::autocomplete(frontend, "global", Position{1, 0}, nullCallback); + auto ac = autocomplete("global", Position{1, 0}); CHECK(ac.entryMap.count("abc")); } @@ -2540,12 +2609,16 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function_local") end )"; - auto ac = Luau::autocomplete(frontend, "local", Position{1, 0}, nullCallback); + auto ac = autocomplete("local", Position{1, 0}); CHECK(ac.entryMap.count("abc")); } TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { + if (FFlag::LuauSolverV2) // CLI-116812 AutocompleteTest.suggest_table_keys needs to populate expected types for nested + // tables without an annotation + return; + check(R"( type Test = { first: number, second: number } local t: Test = { f@1 } @@ -2575,8 +2648,8 @@ local t: Test = { s@1 } ac = autocomplete('1'); CHECK(ac.entryMap.count("second")); - CHECK(!ac.entryMap.count("first")); - CHECK(!ac.entryMap.count("third")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "first"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "third"); CHECK_EQ(ac.context, AutocompleteContext::Property); // No parenthesis suggestion @@ -2619,8 +2692,8 @@ local t: Test = { "f@1" } )"); ac = autocomplete('1'); - CHECK(!ac.entryMap.count("first")); - CHECK(!ac.entryMap.count("second")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "first"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "second"); CHECK_EQ(ac.context, AutocompleteContext::String); // Skip keys that are already defined @@ -2630,7 +2703,7 @@ local t: Test = { first = 2, s@1 } )"); ac = autocomplete('1'); - CHECK(!ac.entryMap.count("first")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "first"); CHECK(ac.entryMap.count("second")); CHECK_EQ(ac.context, AutocompleteContext::Property); @@ -2671,6 +2744,48 @@ local t = { CHECK_EQ(ac.context, AutocompleteContext::Property); } +TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys_no_initial_character") +{ + check(R"( +type Test = { first: number, second: number } +local t: Test = { @1 } + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); +} + +TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys_no_initial_character_2") +{ + check(R"( +type Test = { first: number, second: number } +local t: Test = { first = 1, @1 } + )"); + + auto ac = autocomplete('1'); + CHECK_EQ(ac.entryMap.count("first"), 0); + CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); +} + +TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys_no_initial_character_3") +{ + check(R"( +type Properties = { TextScaled: boolean, Text: string } +local function create(props: Properties) end + +create({ @1 }) + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.size() > 0); + CHECK(ac.entryMap.count("TextScaled")); + CHECK(ac.entryMap.count("Text")); + CHECK_EQ(ac.context, AutocompleteContext::Property); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( @@ -2987,6 +3102,10 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { + if (FFlag::LuauSolverV2) // CLI-116814 Autocomplete needs to populate expected types for function arguments correctly + // (overloads and singletons) + return; + check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -3027,8 +3146,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") { - ScopedFastFlag sff{"LuauCompleteTableKeysBetter", true}; - check(R"( type Direction = "up" | "down" @@ -3060,8 +3177,8 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") ac = autocomplete('4'); - CHECK(!ac.entryMap.count("up")); - CHECK(!ac.entryMap.count("down")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "up"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "down"); CHECK(ac.entryMap.count("\"up\"")); CHECK(ac.entryMap.count("\"down\"")); @@ -3083,13 +3200,183 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") ac = autocomplete('8'); - CHECK(!ac.entryMap.count("up")); - CHECK(!ac.entryMap.count("down")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "up"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "down"); CHECK(ac.entryMap.count("\"up\"")); CHECK(ac.entryMap.count("\"down\"")); } +// https://github.com/Roblox/luau/issues/858 +TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement") +{ + ScopedFastFlag sff[]{ + {FFlag::LuauSolverV2, true}, + }; + + check(R"( + --!strict + + type Direction = "left" | "right" + + local dir: Direction = "left" + + if dir == @1"@2"@3 then end + local a: {[Direction]: boolean} = {[@4"@5"@6]} + + if dir == @7`@8`@9 then end + local a: {[Direction]: boolean} = {[@A`@B`@C]} + )"); + + Luau::AutocompleteResult ac; + + ac = autocomplete('1'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('2'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_KEY(ac.entryMap, "right"); + + ac = autocomplete('3'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('4'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('5'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_KEY(ac.entryMap, "right"); + + ac = autocomplete('6'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('7'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('8'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_KEY(ac.entryMap, "right"); + + ac = autocomplete('9'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('A'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('B'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_KEY(ac.entryMap, "right"); + + ac = autocomplete('C'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); +} + +// https://github.com/Roblox/luau/issues/858 +TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement2") +{ + // don't run this when the DCR flag isn't set + if (!FFlag::LuauSolverV2) + return; + + check(R"( + --!strict + + type Direction = "left" | "right" + + local dir: Direction + -- typestate here means dir is actually typed as `"left"` + dir = "left" + + if dir == @1"@2"@3 then end + local a: {[Direction]: boolean} = {[@4"@5"@6]} + + if dir == @7`@8`@9 then end + local a: {[Direction]: boolean} = {[@A`@B`@C]} + )"); + + Luau::AutocompleteResult ac; + + ac = autocomplete('1'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('2'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('3'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('4'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('5'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_KEY(ac.entryMap, "right"); + + ac = autocomplete('6'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('7'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('8'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('9'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('A'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); + + ac = autocomplete('B'); + + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_KEY(ac.entryMap, "right"); + + ac = autocomplete('C'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") { check(R"( @@ -3171,8 +3458,9 @@ end { check(R"( -local t: Foo -t:@1 +local function f(t: Foo) + t:@1 +end )"); auto ac = autocomplete('1'); @@ -3181,12 +3469,15 @@ t:@1 REQUIRE(ac.entryMap.count("two")); CHECK(!ac.entryMap["one"].wrongIndexType); CHECK(ac.entryMap["two"].wrongIndexType); + CHECK(ac.entryMap["one"].indexedWithSelf); + CHECK(ac.entryMap["two"].indexedWithSelf); } { check(R"( -local t: Foo -t.@1 +local function f(t: Foo) + t.@1 +end )"); auto ac = autocomplete('1'); @@ -3195,6 +3486,8 @@ t.@1 REQUIRE(ac.entryMap.count("two")); CHECK(ac.entryMap["one"].wrongIndexType); CHECK(!ac.entryMap["two"].wrongIndexType); + CHECK(!ac.entryMap["one"].indexedWithSelf); + CHECK(!ac.entryMap["two"].indexedWithSelf); } } @@ -3224,6 +3517,7 @@ t:@1 REQUIRE(ac.entryMap.count("m")); CHECK(!ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") @@ -3238,6 +3532,7 @@ t:@1 REQUIRE(ac.entryMap.count("m")); CHECK(ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") @@ -3253,6 +3548,7 @@ t:@1 REQUIRE(ac.entryMap.count("f")); CHECK(ac.entryMap["f"].wrongIndexType); + CHECK(ac.entryMap["f"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "do_wrong_compatible_self_calls") @@ -3268,6 +3564,26 @@ t:@1 REQUIRE(ac.entryMap.count("m")); // We can make changes to mark this as a wrong way to call even though it's compatible CHECK(!ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); +} + +TEST_CASE_FIXTURE(ACFixture, "do_wrong_compatible_nonself_calls") +{ + check(R"( +local t = {} +function t:m(x: string) end +t.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + + if (FFlag::LuauSolverV2) + CHECK(ac.entryMap["m"].wrongIndexType); + else + CHECK(!ac.entryMap["m"].wrongIndexType); + CHECK(!ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "no_wrong_compatible_self_calls_with_generics") @@ -3283,6 +3599,7 @@ t:@1 REQUIRE(ac.entryMap.count("m")); // While this call is compatible with the type, this requires instantiation of a generic type which we don't perform CHECK(ac.entryMap["m"].wrongIndexType); + CHECK(ac.entryMap["m"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") @@ -3296,10 +3613,13 @@ s:@1 REQUIRE(ac.entryMap.count("byte")); CHECK(ac.entryMap["byte"].wrongIndexType == false); + CHECK(ac.entryMap["byte"].indexedWithSelf); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == true); + CHECK(ac.entryMap["char"].indexedWithSelf); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + CHECK(ac.entryMap["sub"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") @@ -3313,8 +3633,10 @@ s.@1 REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); + CHECK(!ac.entryMap["char"].indexedWithSelf); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == true); + CHECK(!ac.entryMap["sub"].indexedWithSelf); } TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") @@ -3327,10 +3649,13 @@ string.@1 REQUIRE(ac.entryMap.count("byte")); CHECK(ac.entryMap["byte"].wrongIndexType == false); + CHECK(!ac.entryMap["byte"].indexedWithSelf); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); + CHECK(!ac.entryMap["char"].indexedWithSelf); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + CHECK(!ac.entryMap["sub"].indexedWithSelf); check(R"( table.@1 @@ -3340,10 +3665,13 @@ table.@1 REQUIRE(ac.entryMap.count("remove")); CHECK(ac.entryMap["remove"].wrongIndexType == false); + CHECK(!ac.entryMap["remove"].indexedWithSelf); REQUIRE(ac.entryMap.count("getn")); CHECK(ac.entryMap["getn"].wrongIndexType == false); + CHECK(!ac.entryMap["getn"].indexedWithSelf); REQUIRE(ac.entryMap.count("insert")); CHECK(ac.entryMap["insert"].wrongIndexType == false); + CHECK(!ac.entryMap["insert"].indexedWithSelf); } TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") @@ -3356,13 +3684,16 @@ string:@1 REQUIRE(ac.entryMap.count("byte")); CHECK(ac.entryMap["byte"].wrongIndexType == true); + CHECK(ac.entryMap["byte"].indexedWithSelf); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == true); + CHECK(ac.entryMap["char"].indexedWithSelf); // We want the next test to evaluate to 'true', but we have to allow function defined with 'self' to be callable with ':' // We may change the definition of the string metatable to not use 'self' types in the future (like byte/char/pack/unpack) REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + CHECK(ac.entryMap["sub"].indexedWithSelf); } TEST_CASE_FIXTURE(ACFixture, "source_module_preservation_and_invalidation") @@ -3376,6 +3707,8 @@ a.@1 auto ac = autocomplete('1'); + CHECK(2 == ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); CHECK(ac.entryMap.count("y")); @@ -3422,49 +3755,19 @@ TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent") CHECK(ac.entryMap.count("abc1")); } -TEST_CASE_FIXTURE(ACFixture, "type_reduction_is_hooked_up_to_autocomplete") -{ - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; - - check(R"( - type T = { x: (number & string)? } - - function f(thingamabob: T) - thingamabob.@1 - end - - function g(thingamabob: T) - thingama@2 - end - )"); - - ToStringOptions opts; - opts.exhaustive = true; - - auto ac1 = autocomplete('1'); - REQUIRE(ac1.entryMap.count("x")); - std::optional ty1 = ac1.entryMap.at("x").type; - REQUIRE(ty1); - CHECK("nil" == toString(*ty1, opts)); - - auto ac2 = autocomplete('2'); - REQUIRE(ac2.entryMap.count("thingamabob")); - std::optional ty2 = ac2.entryMap.at("thingamabob").type; - REQUIRE(ty2); - CHECK("{| x: nil |}" == toString(*ty2, opts)); -} - TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") { loadDefinition(R"( declare function require(path: string): any )"); - std::optional require = frontend.typeCheckerForAutocomplete.globalScope->linearSearchForBinding("require"); + GlobalTypes& globals = FFlag::LuauSolverV2 ? frontend.globals : frontend.globalsForAutocomplete; + + std::optional require = globals.globalScope->linearSearchForBinding("require"); REQUIRE(require); - Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::unfreeze(globals.globalTypes); attachTag(require->typeId, "RequireCall"); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(globals.globalTypes); check(R"( local x = require("testing/@1") @@ -3472,17 +3775,21 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") bool isCorrect = false; auto ac1 = autocomplete( - '1', [&isCorrect](std::string, std::optional, std::optional contents) -> std::optional { + '1', + [&isCorrect](std::string, std::optional, std::optional contents) -> std::optional + { isCorrect = contents && *contents == "testing/"; return std::nullopt; - }); + } + ); CHECK(isCorrect); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0.5)) { - ScopedFastFlag luauAutocompleteSkipNormalization{"LuauAutocompleteSkipNormalization", true}; + if (FFlag::LuauSolverV2) + return; // FIXME: This test is just barely at the threshhold which makes it very flaky under the new solver // Build a function type with a large overload set const int parts = 100; @@ -3509,4 +3816,527 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0. CHECK(ac.entryMap.count("Instance")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_subtyping_recursion_limit") +{ + // TODO: in old solver, type resolve can't handle the type in this test without a stack overflow + if (!FFlag::LuauSolverV2) + return; + + ScopedFastInt luauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 10}; + + const int parts = 100; + std::string source; + + source += "function f()\n"; + + std::string prefix; + for (int i = 0; i < parts; i++) + formatAppend(prefix, "(nil|({a%d:number}&", i); + formatAppend(prefix, "(nil|{a%d:number})", parts); + for (int i = 0; i < parts; i++) + formatAppend(prefix, "))"); + + source += "local x1 : " + prefix + "\n"; + source += "local y : {a1:number} = x@1\n"; + + source += "end\n"; + + check(source); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("x1")); +} + +TEST_CASE_FIXTURE(ACFixture, "strict_mode_force") +{ + check(R"( +--!nonstrict +local a: {x: number} = {x=1} +local b = a +local c = b.@1 + )"); + + auto ac = autocomplete('1'); + + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); +} + +TEST_CASE_FIXTURE(ACFixture, "suggest_exported_types") +{ + check(R"( +export type Type = {a: number} +local a: T@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("Type")); + CHECK_EQ(ac.context, AutocompleteContext::Type); +} + +TEST_CASE_FIXTURE(ACFixture, "frontend_use_correct_global_scope") +{ + loadDefinition(R"( + declare class Instance + Name: string + end + )"); + + CheckResult result = check(R"( + local a: unknown = nil + if typeof(a) == "Instance" then + local b = a.@1 + end + )"); + auto ac = autocomplete('1'); + + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("Name")); +} + +TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") +{ + loadDefinition(R"( + declare function require(path: string): any + )"); + + GlobalTypes& globals = FFlag::LuauSolverV2 ? frontend.globals : frontend.globalsForAutocomplete; + + std::optional require = globals.globalScope->linearSearchForBinding("require"); + REQUIRE(require); + Luau::unfreeze(globals.globalTypes); + attachTag(require->typeId, "RequireCall"); + Luau::freeze(globals.globalTypes); + + check(R"( + local x = require(@1"@2"@3) + )"); + + StringCompletionCallback callback = [](std::string, std::optional, std::optional contents + ) -> std::optional + { + Luau::AutocompleteEntryMap results = {{"test", Luau::AutocompleteEntry{Luau::AutocompleteEntryKind::String, std::nullopt, false, false}}}; + return results; + }; + + auto ac = autocomplete('2', callback); + + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("test")); + + ac = autocomplete('1', callback); + + CHECK_EQ(0, ac.entryMap.size()); + + ac = autocomplete('3', callback); + + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_empty") +{ + check(R"( +local function foo(a: () -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function() end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args") +{ + check(R"( +local function foo(a: (number, string) -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, a1: string) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_single_return") +{ + check(R"( +local function foo(a: (number, string) -> (string)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, a1: string): string end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_multi_return") +{ + check(R"( +local function foo(a: (number, string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, a1: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__noargs_multi_return") +{ + check(R"( +local function foo(a: () -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__varargs_multi_return") +{ + check(R"( +local function foo(a: (...number) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(...: number): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_return") +{ + check(R"( +local function foo(a: (string, ...number) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: string, ...: number): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_varargs_return") +{ + check(R"( +local function foo(a: (string, ...number) -> ...number) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: string, ...: number): ...number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_varargs_return") +{ + check(R"( +local function foo(a: (string, ...number) -> (boolean, ...number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: string, ...: number): (boolean, ...number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_named_args") +{ + check(R"( +local function foo(a: (foo: number, bar: string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(foo: number, bar: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args") +{ + check(R"( +local function foo(a: (number, bar: string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(a0: number, bar: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args_last") +{ + check(R"( +local function foo(a: (foo: number, string) -> (string, number)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(foo: number, a1: string): (string, number) end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_args") +{ + check(R"( +local t = { a = 1, b = 2 } + +local function foo(a: (foo: typeof(t)) -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(foo) end"; // Cannot utter this type. + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") +{ + check(R"( +local function foo(a: (tbl: { x: number, y: number }) -> number) return a({x=2, y = 3}) end +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(tbl: { x: number, y: number }): number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_returns") +{ + check(R"( +local t = { a = 1, b = 2 } + +local function foo(a: () -> typeof(t)) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function() end"; // Cannot utter this type. + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") +{ + check(R"( +local function foo(a: () -> { x: number, y: number }) return {x=2, y = 3} end +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(): { x: number, y: number } end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_vararg") +{ + check(R"( +local t = { a = 1, b = 2 } + +local function foo(a: (...typeof(t)) -> ()) + a() +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(...) end"; // Cannot utter this type. + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_type_pack_vararg") +{ + // CLI-116932 - Autocomplete on a anonymous function in a function argument should not recommend a function with a generic parameter. + if (FFlag::LuauSolverV2) + return; + check(R"( +local function foo(a: (...A) -> number, ...: A) + return a(...) +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = "function(...): number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + +TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg") +{ + check(R"( +local function foo(a: (...: T...) -> number) + return a(4, 5, 6) +end + +foo(@1) + )"); + + const std::optional EXPECTED_INSERT = + FFlag::LuauSolverV2 ? "function(...: number): number end" : "function(...): number end"; + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count(kGeneratedAnonymousFunctionEntryName) == 1); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].kind == Luau::AutocompleteEntryKind::GeneratedFunction); + CHECK(ac.entryMap[kGeneratedAnonymousFunctionEntryName].typeCorrect == Luau::TypeCorrectKind::Correct); + REQUIRE(ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); + CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); +} + TEST_SUITE_END(); diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index 188f2190f..08505c8de 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -12,15 +12,16 @@ TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") { - CHECK(!typeChecker.globalScope->bindings.empty()); + CHECK(!frontend.globals.globalScope->bindings.empty()); - for (const auto& [name, binding] : typeChecker.globalScope->bindings) + for (const auto& [name, binding] : frontend.globals.globalScope->bindings) { std::string nameString(name.c_str()); std::string expectedRootSymbol = "@luau/global/" + nameString; std::optional actualRootSymbol = binding.documentationSymbol; CHECK_MESSAGE( - actualRootSymbol == expectedRootSymbol, "expected symbol ", expectedRootSymbol, " for global ", nameString, ", got ", actualRootSymbol); + actualRootSymbol == expectedRootSymbol, "expected symbol ", expectedRootSymbol, " for global ", nameString, ", got ", actualRootSymbol + ); const TableType::Props* props = nullptr; if (const TableType* ttv = get(binding.typeId)) @@ -39,8 +40,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") std::string fullPropName = nameString + "." + propName; std::string expectedPropSymbol = expectedRootSymbol + "." + propName; std::optional actualPropSymbol = prop.documentationSymbol; - CHECK_MESSAGE(actualPropSymbol == expectedPropSymbol, "expected symbol ", expectedPropSymbol, " for ", fullPropName, ", got ", - actualPropSymbol); + CHECK_MESSAGE( + actualPropSymbol == expectedPropSymbol, "expected symbol ", expectedPropSymbol, " for ", fullPropName, ", got ", actualPropSymbol + ); } } } diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 087b88d53..a9bf95963 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -11,100 +11,128 @@ namespace Luau ClassFixture::ClassFixture() { - TypeArena& arena = typeChecker.globalTypes; - TypeId numberType = typeChecker.numberType; + GlobalTypes& globals = frontend.globals; + TypeArena& arena = globals.globalTypes; + TypeId numberType = builtinTypes->numberType; + TypeId stringType = builtinTypes->stringType; unfreeze(arena); - TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId connectionType = arena.addType(ClassType{"Connection", {}, nullopt, nullopt, {}, {}, "Connection", {}}); + + TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(baseClassInstanceType)->props = { - {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, + {"BaseMethod", Property::readonly(makeFunction(arena, baseClassInstanceType, {numberType}, {}))}, {"BaseField", {numberType}}, + + {"Touched", Property::readonly(connectionType)}, + }; + + getMutable(connectionType)->props = { + {"Connect", {makeFunction(arena, connectionType, {makeFunction(arena, nullopt, {baseClassInstanceType}, {})}, {})}} }; - TypeId baseClassType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId baseClassType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(baseClassType)->props = { {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); + globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + addGlobalBinding(globals, "BaseClass", baseClassType, "@test"); - TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test", {}}); getMutable(childClassInstanceType)->props = { - {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, childClassInstanceType, {}, {stringType})}}, }; - TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); + TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test", {}}); getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + addGlobalBinding(globals, "ChildClass", childClassType, "@test"); - TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test", {}}); getMutable(grandChildInstanceType)->props = { - {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {stringType})}}, }; - TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test", {}}); getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; - addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; + addGlobalBinding(globals, "GrandChild", childClassType, "@test"); - TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test", {}}); getMutable(anotherChildInstanceType)->props = { - {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {stringType})}}, }; - TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test", {}}); getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; - addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; + addGlobalBinding(globals, "AnotherChild", childClassType, "@test"); - TypeId unrelatedClassInstanceType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId unrelatedClassInstanceType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); - TypeId unrelatedClassType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId unrelatedClassType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(unrelatedClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {unrelatedClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; - addGlobalBinding(frontend, "UnrelatedClass", unrelatedClassType, "@test"); + globals.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; + addGlobalBinding(globals, "UnrelatedClass", unrelatedClassType, "@test"); TypeId vector2MetaType = arena.addType(TableType{}); - TypeId vector2InstanceType = arena.addType(ClassType{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); + vector2InstanceType = arena.addType(ClassType{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test", {}}); getMutable(vector2InstanceType)->props = { {"X", {numberType}}, {"Y", {numberType}}, }; - TypeId vector2Type = arena.addType(ClassType{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); + vector2Type = arena.addType(ClassType{"Vector2", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(vector2Type)->props = { {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, }; getMutable(vector2MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, + {"__mul", + {arena.addType(IntersectionType{{ + makeFunction(arena, vector2InstanceType, {vector2InstanceType}, {vector2InstanceType}), + makeFunction(arena, vector2InstanceType, {builtinTypes->numberType}, {vector2InstanceType}), + }})}} }; - typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; - addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); + globals.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; + addGlobalBinding(globals, "Vector2", vector2Type, "@test"); TypeId callableClassMetaType = arena.addType(TableType{}); - TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); + TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test", {}}); getMutable(callableClassMetaType)->props = { - {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, + {"__call", {makeFunction(arena, nullopt, {callableClassType, stringType}, {numberType})}}, + }; + globals.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; + + auto addIndexableClass = [&arena, &globals](const char* className, TypeId keyType, TypeId returnType) + { + TypeId indexableClassMetaType = arena.addType(TableType{}); + TypeId indexableClassType = + arena.addType(ClassType{className, {}, nullopt, indexableClassMetaType, {}, {}, "Test", {}, TableIndexer{keyType, returnType}}); + globals.globalScope->exportedTypeBindings[className] = TypeFun{{}, indexableClassType}; }; - typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; - for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + // IndexableClass has a table indexer with a key type of 'number | string' and a return type of 'number' + addIndexableClass("IndexableClass", arena.addType(Luau::UnionType{{stringType, numberType}}), numberType); + // IndexableNumericKeyClass has a table indexer with a key type of 'number' and a return type of 'number' + addIndexableClass("IndexableNumericKeyClass", numberType, numberType); + + for (const auto& [name, tf] : globals.globalScope->exportedTypeBindings) persist(tf.type); freeze(arena); diff --git a/tests/ClassFixture.h b/tests/ClassFixture.h index c46697a26..4d8275c10 100644 --- a/tests/ClassFixture.h +++ b/tests/ClassFixture.h @@ -9,6 +9,9 @@ namespace Luau struct ClassFixture : BuiltinsFixture { ClassFixture(); + + TypeId vector2Type; + TypeId vector2InstanceType; }; } // namespace Luau diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index a6ed96f02..058a1100f 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -47,6 +47,56 @@ TEST_CASE("CodeAllocation") CHECK(nativeEntry == nativeData + kCodeAlignment); } +TEST_CASE("CodeAllocationCallbacks") +{ + struct AllocationData + { + size_t bytesAllocated = 0; + size_t bytesFreed = 0; + }; + + AllocationData allocationData{}; + + const auto allocationCallback = [](void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize) + { + AllocationData& allocationData = *static_cast(context); + if (oldPointer != nullptr) + { + CHECK(oldSize != 0); + + allocationData.bytesFreed += oldSize; + } + + if (newPointer != nullptr) + { + CHECK(newSize != 0); + + allocationData.bytesAllocated += newSize; + } + }; + + const size_t blockSize = 1024 * 1024; + const size_t maxTotalSize = 1024 * 1024; + + { + CodeAllocator allocator(blockSize, maxTotalSize, allocationCallback, &allocationData); + + uint8_t* nativeData = nullptr; + size_t sizeNativeData = 0; + uint8_t* nativeEntry = nullptr; + + std::vector code; + code.resize(128); + + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(allocationData.bytesAllocated == blockSize); + CHECK(allocationData.bytesFreed == 0); + } + + CHECK(allocationData.bytesAllocated == blockSize); + CHECK(allocationData.bytesFreed == blockSize); +} + TEST_CASE("CodeAllocationFailure") { size_t blockSize = 3000; @@ -97,7 +147,8 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") data.resize(8); allocator.context = &info; - allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) -> void* { + allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) -> void* + { Info& info = *(Info*)context; CHECK(info.unwind.size() == 8); @@ -108,7 +159,8 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") return new int(7); }; - allocator.destroyBlockUnwindInfo = [](void* context, void* unwindData) { + allocator.destroyBlockUnwindInfo = [](void* context, void* unwindData) + { Info& info = *(Info*)context; info.destroyCalled = true; @@ -135,27 +187,18 @@ TEST_CASE("WindowsUnwindCodesX64") UnwindBuilderWin unwind; - unwind.start(); - unwind.spill(16, rdx); - unwind.spill(8, rcx); - unwind.save(rdi); - unwind.save(rsi); - unwind.save(rbx); - unwind.save(rbp); - unwind.save(r12); - unwind.save(r13); - unwind.save(r14); - unwind.save(r15); - unwind.allocStack(72); - unwind.setupFrameReg(rbp, 48); - unwind.finish(); + unwind.startInfo(UnwindBuilder::X64); + unwind.startFunction(); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}, {}); + unwind.finishFunction(0x11223344, 0x55443322); + unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); - unwind.finalize(data.data(), nullptr, 0); + data.resize(unwind.getUnwindInfoSize()); + unwind.finalize(data.data(), 0, nullptr, 0); - std::vector expected{0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, - 0x30, 0x0e, 0x60, 0x0c, 0x70}; + std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x17, 0x0a, 0x05, 0x17, 0x82, + 0x13, 0xf0, 0x11, 0xe0, 0x0f, 0xd0, 0x0d, 0xc0, 0x0b, 0x30, 0x09, 0x60, 0x07, 0x70, 0x05, 0x03, 0x02, 0x50}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -168,34 +211,53 @@ TEST_CASE("Dwarf2UnwindCodesX64") UnwindBuilderDwarf2 unwind; - unwind.start(); - unwind.save(rdi); - unwind.save(rsi); - unwind.save(rbx); - unwind.save(rbp); - unwind.save(r12); - unwind.save(r13); - unwind.save(r14); - unwind.save(r15); - unwind.allocStack(72); - unwind.setupFrameReg(rbp, 48); - unwind.finish(); + unwind.startInfo(UnwindBuilder::X64); + unwind.startFunction(); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}, {}); + unwind.finishFunction(0, 0); + unwind.finishInfo(); + + std::vector data; + data.resize(unwind.getUnwindInfoSize()); + unwind.finalize(data.data(), 0, nullptr, 0); + + std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x90, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x86, 0x02, + 0x02, 0x03, 0x02, 0x02, 0x0e, 0x18, 0x85, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x84, 0x04, 0x02, 0x02, 0x0e, 0x28, + 0x83, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, + 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +TEST_CASE("Dwarf2UnwindCodesA64") +{ + using namespace A64; + + UnwindBuilderDwarf2 unwind; + + unwind.startInfo(UnwindBuilder::A64); + unwind.startFunction(); + unwind.prologueA64(/* prologueSize= */ 28, /* stackSize= */ 64, {x29, x30, x19, x20, x21, x22, x23, x24}); + unwind.finishFunction(0, 32); + unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); - unwind.finalize(data.data(), nullptr, 0); + data.resize(unwind.getUnwindInfoSize()); + unwind.finalize(data.data(), 0, nullptr, 0); - std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, - 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, - 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x02, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00}; + std::vector expected{0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x1e, 0x0c, 0x1f, 0x00, 0x2c, + 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x04, 0x0e, 0x40, 0x02, 0x18, 0x9d, 0x08, 0x9e, 0x07, 0x93, + 0x06, 0x94, 0x05, 0x95, 0x04, 0x96, 0x03, 0x97, 0x02, 0x98, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); } -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) #if defined(_WIN32) // Windows x64 ABI @@ -211,9 +273,14 @@ constexpr X64::RegisterX64 rArg3 = X64::rdx; constexpr X64::RegisterX64 rNonVol1 = X64::r12; constexpr X64::RegisterX64 rNonVol2 = X64::rbx; +constexpr X64::RegisterX64 rNonVol3 = X64::r13; +constexpr X64::RegisterX64 rNonVol4 = X64::r14; TEST_CASE("GeneratedCodeExecutionX64") { + if (!Luau::CodeGen::isSupported()) + return; + using namespace X64; AssemblyBuilderX64 build(/* logText= */ false); @@ -241,15 +308,23 @@ TEST_CASE("GeneratedCodeExecutionX64") CHECK(result == 210); } -void throwing(int64_t arg) +static void throwing(int64_t arg) { CHECK(arg == 25); throw std::runtime_error("testing"); } +static void nonthrowing(int64_t arg) +{ + CHECK(arg == 25); +} + TEST_CASE("GeneratedCodeExecutionWithThrowX64") { + if (!Luau::CodeGen::isSupported()) + return; + using namespace X64; AssemblyBuilderX64 build(/* logText= */ false); @@ -260,26 +335,25 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->start(); + unwind->startInfo(UnwindBuilder::X64); + + Label functionBegin = build.setLabel(); + unwind->startFunction(); // Prologue + build.push(rbp); + build.mov(rbp, rsp); build.push(rNonVol1); - unwind->save(rNonVol1); build.push(rNonVol2); - unwind->save(rNonVol2); - build.push(rbp); - unwind->save(rbp); int stackSize = 32; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location; - unwind->finish(); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}, {}); // Body build.mov(rNonVol1, rArg1); @@ -290,14 +364,18 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") build.call(rNonVol2); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(rNonVol2); build.pop(rNonVol1); + build.pop(rbp); build.ret(); + unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); + build.finalize(); + unwind->finishInfo(); + size_t blockSize = 1024 * 1024; size_t maxTotalSize = 1024 * 1024; CodeAllocator allocator(blockSize, maxTotalSize); @@ -315,6 +393,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") using FunctionType = int64_t(int64_t, void (*)(int64_t)); FunctionType* f = (FunctionType*)nativeEntry; + f(10, nonthrowing); + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here try { @@ -326,8 +406,266 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") } } +static void obscureThrowCase(int64_t (*f)(int64_t, void (*)(int64_t))) +{ + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + +TEST_CASE("GeneratedCodeExecutionWithThrowX64Simd") +{ + // This test requires AVX + if (!Luau::CodeGen::isSupported()) + return; + + using namespace X64; + + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->startInfo(UnwindBuilder::X64); + + Label functionBegin = build.setLabel(); + unwind->startFunction(); + + int stackSize = 32 + 64; + int localsSize = 16; + + // Prologue + build.push(rNonVol1); + build.push(rNonVol2); + build.push(rbp); + build.sub(rsp, stackSize + localsSize); + + if (build.abi == ABIX64::Windows) + { + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x40)], xmm6); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x30)], xmm7); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x20)], xmm8); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x10)], xmm9); + } + + uint32_t prologueSize = build.setLabel().location; + + if (build.abi == ABIX64::Windows) + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rbp}, {xmm6, xmm7, xmm8, xmm9}); + else + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rbp}, {}); + + // Body + build.vxorpd(xmm0, xmm0, xmm0); + build.vmovsd(xmm6, xmm0, xmm0); + build.vmovsd(xmm7, xmm0, xmm0); + build.vmovsd(xmm8, xmm0, xmm0); + build.vmovsd(xmm9, xmm0, xmm0); + + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + if (build.abi == ABIX64::Windows) + { + build.vmovaps(xmm6, xmmword[rsp + ((stackSize + localsSize) - 0x40)]); + build.vmovaps(xmm7, xmmword[rsp + ((stackSize + localsSize) - 0x30)]); + build.vmovaps(xmm8, xmmword[rsp + ((stackSize + localsSize) - 0x20)]); + build.vmovaps(xmm9, xmmword[rsp + ((stackSize + localsSize) - 0x10)]); + } + + build.add(rsp, stackSize + localsSize); + build.pop(rbp); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + f(10, nonthrowing); + + obscureThrowCase(f); +} + +TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") +{ + if (!Luau::CodeGen::isSupported()) + return; + + using namespace X64; + + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->startInfo(UnwindBuilder::X64); + + Label start1; + Label start2; + + // First function + { + build.setLabel(start1); + unwind->startFunction(); + + // Prologue + build.push(rbp); + build.mov(rbp, rsp); + build.push(rNonVol1); + build.push(rNonVol2); + + int stackSize = 32; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + + uint32_t prologueSize = build.setLabel().location - start1.location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}, {}); + + // Body + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + build.add(rsp, stackSize + localsSize); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.pop(rbp); + build.ret(); + + Label end1 = build.setLabel(); + unwind->finishFunction(build.getLabelOffset(start1), build.getLabelOffset(end1)); + } + + // Second function with different layout and no frame + { + build.setLabel(start2); + unwind->startFunction(); + + // Prologue + build.push(rNonVol1); + build.push(rNonVol2); + build.push(rNonVol3); + build.push(rNonVol4); + + int stackSize = 32; + int localsSize = 24; + + build.sub(rsp, stackSize + localsSize); + + uint32_t prologueSize = build.setLabel().location - start2.location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4}, {}); + + // Body + build.mov(rNonVol3, rArg1); + build.mov(rNonVol4, rArg2); + + build.add(rNonVol3, 15); + build.mov(rArg1, rNonVol3); + build.call(rNonVol4); + + // Epilogue + build.add(rsp, stackSize + localsSize); + build.pop(rNonVol4); + build.pop(rNonVol3); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + unwind->finishFunction(build.getLabelOffset(start2), ~0u); + } + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f1 = (FunctionType*)(nativeEntry + start1.location); + FunctionType* f2 = (FunctionType*)(nativeEntry + start2.location); + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f1(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } + + try + { + f2(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") { + if (!Luau::CodeGen::isSupported()) + return; + using namespace X64; AssemblyBuilderX64 build(/* logText= */ false); @@ -338,36 +676,29 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->start(); + unwind->startInfo(UnwindBuilder::X64); + + Label functionBegin = build.setLabel(); + unwind->startFunction(); // Prologue (some of these registers don't have to be saved, but we want to have a big prologue) + build.push(rbp); + build.mov(rbp, rsp); build.push(r10); - unwind->save(r10); build.push(r11); - unwind->save(r11); build.push(r12); - unwind->save(r12); build.push(r13); - unwind->save(r13); build.push(r14); - unwind->save(r14); build.push(r15); - unwind->save(r15); - build.push(rbp); - unwind->save(rbp); int stackSize = 64; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); - unwind->finish(); + uint32_t prologueSize = build.setLabel().location; - size_t prologueSize = build.setLabel().location; + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15}, {}); // Body build.mov(rax, rArg1); @@ -377,18 +708,22 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") Label returnOffset = build.setLabel(); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(r15); build.pop(r14); build.pop(r13); build.pop(r12); build.pop(r11); build.pop(r10); + build.pop(rbp); build.ret(); + unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); + build.finalize(); + unwind->finishInfo(); + size_t blockSize = 4096; // Force allocate to create a new block each time size_t maxTotalSize = 1024 * 1024; CodeAllocator allocator(blockSize, maxTotalSize); @@ -400,8 +735,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") uint8_t* nativeData1; size_t sizeNativeData1; uint8_t* nativeEntry1; - REQUIRE( - allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData1, sizeNativeData1, nativeEntry1)); + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData1, sizeNativeData1, nativeEntry1) + ); REQUIRE(nativeEntry1); // Now we set the offset at the begining so that functions in new blocks will not overlay the locations @@ -424,8 +759,9 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") uint8_t* nativeData2; size_t sizeNativeData2; uint8_t* nativeEntry2; - REQUIRE(allocator.allocate( - build2.data.data(), build2.data.size(), build2.code.data(), build2.code.size(), nativeData2, sizeNativeData2, nativeEntry2)); + REQUIRE( + allocator.allocate(build2.data.data(), build2.data.size(), build2.code.data(), build2.code.size(), nativeData2, sizeNativeData2, nativeEntry2) + ); REQUIRE(nativeEntry2); // To simplify debugging, CHECK_THROWS_WITH_AS is not used here @@ -443,7 +779,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") #endif -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) TEST_CASE("GeneratedCodeExecutionA64") { @@ -465,6 +801,7 @@ TEST_CASE("GeneratedCodeExecutionA64") build.add(x1, x1, 2); build.add(x0, x0, x1, /* LSL */ 1); + build.ret(); build.finalize(); @@ -476,8 +813,15 @@ TEST_CASE("GeneratedCodeExecutionA64") uint8_t* nativeData; size_t sizeNativeData; uint8_t* nativeEntry; - REQUIRE(allocator.allocate(build.data.data(), build.data.size(), reinterpret_cast(build.code.data()), build.code.size() * 4, nativeData, - sizeNativeData, nativeEntry)); + REQUIRE(allocator.allocate( + build.data.data(), + build.data.size(), + reinterpret_cast(build.code.data()), + build.code.size() * 4, + nativeData, + sizeNativeData, + nativeEntry + )); REQUIRE(nativeEntry); using FunctionType = int64_t(int64_t, int*); @@ -487,6 +831,91 @@ TEST_CASE("GeneratedCodeExecutionA64") CHECK(result == 42); } +#if 0 +static void throwing(int64_t arg) +{ + CHECK(arg == 25); + + throw std::runtime_error("testing"); +} + +TEST_CASE("GeneratedCodeExecutionWithThrowA64") +{ + // macOS 12 doesn't support JIT frames without pointer authentication + if (!isUnwindSupported()) + return; + + using namespace A64; + + AssemblyBuilderA64 build(/* logText= */ false); + + std::unique_ptr unwind = std::make_unique(); + + unwind->startInfo(UnwindBuilder::A64); + + build.sub(sp, sp, 32); + build.stp(x29, x30, mem(sp)); + build.str(x28, mem(sp, 16)); + build.mov(x29, sp); + + Label prologueEnd = build.setLabel(); + + build.add(x0, x0, 15); + build.blr(x1); + + build.ldr(x28, mem(sp, 16)); + build.ldp(x29, x30, mem(sp)); + build.add(sp, sp, 32); + + build.ret(); + + Label functionEnd = build.setLabel(); + + unwind->startFunction(); + unwind->prologueA64(build.getLabelOffset(prologueEnd), 32, {x29, x30, x28}); + unwind->finishFunction(0, build.getLabelOffset(functionEnd)); + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate( + build.data.data(), + build.data.size(), + reinterpret_cast(build.code.data()), + build.code.size() * 4, + nativeData, + sizeNativeData, + nativeEntry + )); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} +#endif + #endif TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 135a555ab..73b8816f3 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -15,14 +15,27 @@ namespace Luau std::string rep(const std::string& s, size_t n); } +LUAU_FASTINT(LuauCompileInlineDepth) +LUAU_FASTINT(LuauCompileInlineThreshold) +LUAU_FASTINT(LuauCompileInlineThresholdMaxBoost) +LUAU_FASTINT(LuauCompileLoopUnrollThreshold) +LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost) +LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) + using namespace Luau; -static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1) +static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, bool enableVectors = false) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::CompileOptions options; options.optimizationLevel = optimizationLevel; + if (enableVectors) + { + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + } Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); @@ -49,8 +62,50 @@ static std::string compileFunction0Coverage(const char* source, int level) return bcb.dumpFunction(0); } +static std::string compileTypeTable(const char* source) +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + + Luau::CompileOptions opts; + opts.vectorType = "Vector3"; + opts.typeInfoLevel = 1; + Luau::compileOrThrow(bcb, source, opts); + + return bcb.dumpTypeInfo(); +} + TEST_SUITE_BEGIN("Compiler"); +TEST_CASE("BytecodeIsStable") +{ + // As noted in Bytecode.h, all enums used for bytecode storage and serialization are order-sensitive + // Adding entries in the middle will typically pass the tests but break compatibility + // This test codifies this by validating that in each enum, the last (or close-to-last) entry has a fixed encoding + + // This test will need to get occasionally revised to "move" the checked enum entries forward as we ship newer versions + // When doing so, please add *new* checks for more recent bytecode versions and keep existing checks in place. + + // Bytecode ops (serialized & in-memory) + CHECK(LOP_FASTCALL2K == 75); // bytecode v1 + CHECK(LOP_JUMPXEQKS == 80); // bytecode v3 + + // Bytecode fastcall ids (serialized & in-memory) + // Note: these aren't strictly bound to specific bytecode versions, but must monotonically increase to keep backwards compat + CHECK(LBF_VECTOR == 54); + CHECK(LBF_TOSTRING == 63); + + // Bytecode capture type (serialized & in-memory) + CHECK(LCT_UPVAL == 2); // bytecode v1 + + // Bytecode constants (serialized) + CHECK(LBC_CONSTANT_CLOSURE == 6); // bytecode v1 + + // Bytecode type encoding (serialized & in-memory) + // Note: these *can* change retroactively *if* type version is bumped, but probably shouldn't + LUAU_ASSERT(LBC_TYPE_BUFFER == 9); // type version 1 +} + TEST_CASE("CompileToBytecode") { Luau::BytecodeBuilder bcb; @@ -156,7 +211,8 @@ RETURN R0 0 TEST_CASE("ReflectionBytecode") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local part = Instance.new('Part', workspace) part.Size = Vector3.new(1, 2, 3) return part.Size.Z * part:GetMass() @@ -178,7 +234,8 @@ NAMECALL R3 R0 K10 ['GetMass'] CALL R3 1 1 MUL R1 R2 R3 RETURN R1 1 -)"); +)" + ); } TEST_CASE("ImportCall") @@ -636,7 +693,8 @@ RETURN R0 1 TEST_CASE("TableLiteralsIndexConstant") { // validate that we use SETTTABLEKS for constant variable keys - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = "key", "value" return {[a] = 42, [b] = 0} )"), @@ -647,10 +705,12 @@ SETTABLEKS R1 R0 K0 ['key'] LOADN R1 0 SETTABLEKS R1 R0 K1 ['value'] RETURN R0 1 -)"); +)" + ); // validate that we use SETTABLEN for constant variable keys *and* that we predict array size - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = 1, 2 return {[a] = 42, [b] = 0} )"), @@ -661,12 +721,14 @@ SETTABLEN R1 R0 1 LOADN R1 0 SETTABLEN R1 R0 2 RETURN R0 1 -)"); +)" + ); } TEST_CASE("TableSizePredictionBasic") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = {} t.a = 1 t.b = 1 @@ -699,9 +761,11 @@ SETTABLEKS R1 R0 K7 ['h'] LOADN R1 1 SETTABLEKS R1 R0 K8 ['i'] RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = {} t.x = 1 t.x = 2 @@ -734,9 +798,11 @@ SETTABLEKS R1 R0 K0 ['x'] LOADN R1 9 SETTABLEKS R1 R0 K0 ['x'] RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = {} t[1] = 1 t[2] = 1 @@ -772,12 +838,15 @@ SETTABLEN R1 R0 9 LOADN R1 1 SETTABLEN R1 R0 10 RETURN R0 0 -)"); +)" + ); } TEST_CASE("TableSizePredictionObject") { - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} t.field = 1 function t:getfield() @@ -785,7 +854,8 @@ function t:getfield() end return t )", - 1), + 1 + ), R"( NEWTABLE R0 2 0 LOADN R1 1 @@ -793,12 +863,14 @@ SETTABLEKS R1 R0 K0 ['field'] DUPCLOSURE R1 K1 ['getfield'] SETTABLEKS R1 R0 K2 ['getfield'] RETURN R0 1 -)"); +)" + ); } TEST_CASE("TableSizePredictionSetMetatable") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = setmetatable({}, nil) t.field1 = 1 t.field2 = 2 @@ -815,12 +887,14 @@ SETTABLEKS R1 R0 K3 ['field1'] LOADN R1 2 SETTABLEKS R1 R0 K4 ['field2'] RETURN R0 1 -)"); +)" + ); } TEST_CASE("TableSizePredictionLoop") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = {} for i=1,4 do t[i] = 0 @@ -837,7 +911,8 @@ L0: LOADN R4 0 SETTABLE R4 R0 R3 FORNLOOP R1 L0 L1: RETURN R0 1 -)"); +)" + ); } TEST_CASE("ReflectionEnums") @@ -1101,33 +1176,27 @@ L0: RETURN R1 1 TEST_CASE("AndOrFoldLeft") { // constant folding and/or expression is possible even if just the left hand is constant - CHECK_EQ("\n" + compileFunction0("local a = false if a and b then b() end"), R"( -RETURN R0 0 + CHECK_EQ("\n" + compileFunction0("local a = false return a and b"), R"( +LOADB R0 0 +RETURN R0 1 )"); - CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() end"), R"( -GETIMPORT R0 1 [b] -CALL R0 0 0 -RETURN R0 0 + CHECK_EQ("\n" + compileFunction0("local a = true return a or b"), R"( +LOADB R0 1 +RETURN R0 1 )"); - // however, if right hand side is constant we can't constant fold the entire expression - // (note that we don't need to evaluate the right hand side, but we do need a branch) - CHECK_EQ("\n" + compileFunction0("local a = false if b and a then b() end"), R"( -GETIMPORT R0 1 [b] -JUMPIFNOT R0 L0 -RETURN R0 0 -GETIMPORT R0 1 [b] -CALL R0 0 0 -L0: RETURN R0 0 + // if right hand side is constant we can't constant fold the entire expression + CHECK_EQ("\n" + compileFunction0("local a = false return b and a"), R"( +GETIMPORT R1 2 [b] +ANDK R0 R1 K0 [false] +RETURN R0 1 )"); - CHECK_EQ("\n" + compileFunction0("local a = true if b or a then b() end"), R"( -GETIMPORT R0 1 [b] -JUMPIF R0 L0 -L0: GETIMPORT R0 1 [b] -CALL R0 0 0 -RETURN R0 0 + CHECK_EQ("\n" + compileFunction0("local a = true return b or a"), R"( +GETIMPORT R1 2 [b] +ORK R0 R1 K0 [true] +RETURN R0 1 )"); } @@ -1141,23 +1210,22 @@ TEST_CASE("AndOrChainCodegen") )"; CHECK_EQ("\n" + compileFunction0(source), R"( -LOADN R2 1 -GETIMPORT R3 1 [verticalGradientTurbulence] -SUB R1 R2 R3 -GETIMPORT R3 4 [waterLevel] -ADDK R2 R3 K2 [0.014999999999999999] +GETIMPORT R2 2 [verticalGradientTurbulence] +SUBRK R1 K0 [1] R2 +GETIMPORT R3 5 [waterLevel] +ADDK R2 R3 K3 [0.014999999999999999] JUMPIFNOTLT R1 R2 L0 -GETIMPORT R0 8 [Enum.Material.Sand] +GETIMPORT R0 9 [Enum.Material.Sand] JUMPIF R0 L2 -L0: GETIMPORT R1 10 [sandbank] +L0: GETIMPORT R1 11 [sandbank] LOADN R2 0 JUMPIFNOTLT R2 R1 L1 -GETIMPORT R1 10 [sandbank] +GETIMPORT R1 11 [sandbank] LOADN R2 1 JUMPIFNOTLT R1 R2 L1 -GETIMPORT R0 8 [Enum.Material.Sand] +GETIMPORT R0 9 [Enum.Material.Sand] JUMPIF R0 L2 -L1: GETIMPORT R0 12 [Enum.Material.Sandstone] +L1: GETIMPORT R0 13 [Enum.Material.Sandstone] L2: RETURN R0 1 )"); } @@ -1280,7 +1348,8 @@ TEST_CASE("InterpStringWithNoExpressions") TEST_CASE("InterpStringZeroCost") { - CHECK_EQ("\n" + compileFunction0(R"(local _ = `hello, {"world"}!`)"), + CHECK_EQ( + "\n" + compileFunction0(R"(local _ = `hello, {"world"}!`)"), R"( LOADK R1 K0 ['hello, %*!'] LOADK R3 K1 ['world'] @@ -1288,12 +1357,14 @@ NAMECALL R1 R1 K2 ['format'] CALL R1 2 1 MOVE R0 R1 RETURN R0 0 -)"); +)" + ); } TEST_CASE("InterpStringRegisterCleanup") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b, c = nil, "um", "uh oh" a = `foo{"bar"}` print(a) @@ -1312,7 +1383,8 @@ GETIMPORT R3 6 [print] MOVE R4 R0 CALL R3 1 0 RETURN R0 0 -)"); +)" + ); } TEST_CASE("InterpStringRegisterLimit") @@ -1677,8 +1749,6 @@ RETURN R0 0 TEST_CASE("LoopBreak") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // default codegen: compile breaks as unconditional jumps CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"( L0: GETIMPORT R0 2 [math.random] @@ -1703,8 +1773,6 @@ L1: RETURN R0 0 TEST_CASE("LoopContinue") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // default codegen: compile continue as unconditional jumps CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"( L0: GETIMPORT R0 2 [math.random] @@ -1768,13 +1836,15 @@ until rr < 0.5 { CHECK_EQ(e.getLocation().begin.line + 1, 8); CHECK_EQ( - std::string(e.what()), "Local rr used in the repeat..until condition is undefined because continue statement on line 5 jumps over it"); + std::string(e.what()), "Local rr used in the repeat..until condition is undefined because continue statement on line 5 jumps over it" + ); } // but it's okay if continue is inside a non-repeat..until loop, or inside a loop that doesn't use the local (here `continue` just terminates // inner loop) - CHECK_EQ("\n" + compileFunction0( - "repeat local r = math.random() repeat if r > 0.5 then continue end r = r - 0.1 until true r = r + 0.3 until r < 0.5"), + CHECK_EQ( + "\n" + + compileFunction0("repeat local r = math.random() repeat if r > 0.5 then continue end r = r - 0.1 until true r = r + 0.3 until r < 0.5"), R"( L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 @@ -1786,29 +1856,30 @@ LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L2 JUMPBACK L0 L2: RETURN R0 0 -)"); +)" + ); // and it's also okay to use a local defined in the until expression as long as it's inside a function! CHECK_EQ( "\n" + compileFunction( - "repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until (function() local a = r return a < 0.5 end)()", 1), + "repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until (function() local a = r return a < 0.5 end)()", 1 + ), R"( L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 LOADK R1 K3 [0.5] -JUMPIFNOTLT R1 R0 L1 -CLOSEUPVALS R0 -JUMP L2 -L1: ADDK R0 R0 K4 [0.29999999999999999] -L2: NEWCLOSURE R1 P0 +JUMPIFLT R1 R0 L1 +ADDK R0 R0 K4 [0.29999999999999999] +L1: NEWCLOSURE R1 P0 CAPTURE REF R0 CALL R1 0 1 -JUMPIF R1 L3 +JUMPIF R1 L2 CLOSEUPVALS R0 JUMPBACK L0 -L3: CLOSEUPVALS R0 +L2: CLOSEUPVALS R0 RETURN R0 0 -)"); +)" + ); // but not if the function just refers to an upvalue try @@ -1830,12 +1901,14 @@ until (function() return rr end)() < 0.5 { CHECK_EQ(e.getLocation().begin.line + 1, 8); CHECK_EQ( - std::string(e.what()), "Local rr used in the repeat..until condition is undefined because continue statement on line 5 jumps over it"); + std::string(e.what()), "Local rr used in the repeat..until condition is undefined because continue statement on line 5 jumps over it" + ); } // unless that upvalue is from an outer scope - CHECK_EQ("\n" + compileFunction0("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then " - "continue end r = r + 0.3 until stop or r < 0.5 end"), + CHECK_EQ( + "\n" + compileFunction0("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then " + "continue end r = r + 0.3 until stop or r < 0.5 end"), R"( L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 @@ -1848,92 +1921,335 @@ LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L2 JUMPBACK L0 L2: RETURN R0 0 -)"); +)" + ); // including upvalue references from a function expression - CHECK_EQ("\n" + compileFunction("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then continue " - "end r = r + 0.3 until (function() return stop or r < 0.5 end)() end", - 1), + CHECK_EQ( + "\n" + compileFunction( + "local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then continue " + "end r = r + 0.3 until (function() return stop or r < 0.5 end)() end", + 1 + ), R"( L0: GETIMPORT R0 2 [math.random] CALL R0 0 1 LOADK R1 K3 [0.5] -JUMPIFNOTLT R1 R0 L1 -CLOSEUPVALS R0 -JUMP L2 -L1: ADDK R0 R0 K4 [0.29999999999999999] -L2: NEWCLOSURE R1 P0 +JUMPIFLT R1 R0 L1 +ADDK R0 R0 K4 [0.29999999999999999] +L1: NEWCLOSURE R1 P0 CAPTURE UPVAL U0 CAPTURE REF R0 CALL R1 0 1 -JUMPIF R1 L3 +JUMPIF R1 L2 CLOSEUPVALS R0 JUMPBACK L0 -L3: CLOSEUPVALS R0 +L2: CLOSEUPVALS R0 RETURN R0 0 -)"); +)" + ); } -TEST_CASE("LoopContinueUntilOops") +TEST_CASE("LoopContinueIgnoresImplicitConstant") { // this used to crash the compiler :( + CHECK_EQ( + "\n" + compileFunction0(R"( +local _ +repeat +continue +until not _ +)"), + R"( +RETURN R0 0 +RETURN R0 0 +)" + ); +} + +TEST_CASE("LoopContinueIgnoresExplicitConstant") +{ + // Constants do not allocate locals and 'continue' validation should skip them if their lifetime already started + CHECK_EQ( + "\n" + compileFunction0(R"( +local c = true +repeat + continue +until c +)"), + R"( +RETURN R0 0 +RETURN R0 0 +)" + ); +} + +TEST_CASE("LoopContinueRespectsExplicitConstant") +{ + // If local lifetime hasn't started, even if it's a constant that will not receive an allocation, it cannot be jumped over try { Luau::BytecodeBuilder bcb; Luau::compileOrThrow(bcb, R"( -local _ repeat -continue -until not _ + do continue end + + local c = true +until c )"); + + CHECK(!"Expected CompileError"); + } + catch (Luau::CompileError& e) + { + CHECK_EQ(e.getLocation().begin.line + 1, 6); + CHECK_EQ( + std::string(e.what()), "Local c used in the repeat..until condition is undefined because continue statement on line 3 jumps over it" + ); + } +} + +TEST_CASE("LoopContinueIgnoresImplicitConstantAfterInline") +{ + // Inlining might also replace some locals with constants instead of allocating them + CHECK_EQ( + "\n" + compileFunction( + R"( +local function inline(f) + repeat + continue + until f +end + +local function test(...) + inline(true) +end + +test() +)", + 1, + 2 + ), + R"( +RETURN R0 0 +RETURN R0 0 +)" + ); +} + +TEST_CASE("LoopContinueCorrectlyHandlesImplicitConstantAfterUnroll") +{ + ScopedFastInt sfi(FInt::LuauCompileLoopUnrollThreshold, 200); + + // access to implicit constant that depends on the unrolled loop constant is still invalid even though we can constant-propagate it + try + { + compileFunction( + R"( +for i = 1, 2 do + s() + repeat + if i == 2 then + continue + end + local x = i == 1 or a + until f(x) +end +)", + 0, + 2 + ); + + CHECK(!"Expected CompileError"); } catch (Luau::CompileError& e) { + CHECK_EQ(e.getLocation().begin.line + 1, 9); CHECK_EQ( - std::string(e.what()), "Local _ used in the repeat..until condition is undefined because continue statement on line 4 jumps over it"); + std::string(e.what()), "Local x used in the repeat..until condition is undefined because continue statement on line 6 jumps over it" + ); } } +TEST_CASE("LoopContinueUntilCapture") +{ + // validate continue upvalue closing behavior: continue must close locals defined in the nested scopes + // but can't close locals defined in the loop scope - these are visible to the condition and will be closed + // when evaluating the condition instead. + CHECK_EQ( + "\n" + compileFunction( + R"( +local a a = 0 +repeat + local b b = 0 + if a then + local c + print(function() c = 0 end) + if a then + continue -- must close c but not a/b + end + -- must close c + end + -- must close b but not a +until function() a = 0 b = 0 end +-- must close b on loop exit +-- must close a +)", + 2 + ), + R"( +LOADNIL R0 +LOADN R0 0 +L0: LOADNIL R1 +LOADN R1 0 +JUMPIFNOT R0 L2 +LOADNIL R2 +GETIMPORT R3 1 [print] +NEWCLOSURE R4 P0 +CAPTURE REF R2 +CALL R3 1 0 +JUMPIFNOT R0 L1 +CLOSEUPVALS R2 +JUMP L2 +L1: CLOSEUPVALS R2 +L2: NEWCLOSURE R2 P1 +CAPTURE REF R0 +CAPTURE REF R1 +JUMPIF R2 L3 +CLOSEUPVALS R1 +JUMPBACK L0 +L3: CLOSEUPVALS R1 +CLOSEUPVALS R0 +RETURN R0 0 +)" + ); + + // a simpler version of the above test doesn't need to close anything when evaluating continue + CHECK_EQ( + "\n" + compileFunction( + R"( +local a a = 0 +repeat + local b b = 0 + if a then + continue -- must not close a/b + end + -- must close b but not a +until function() a = 0 b = 0 end +-- must close b on loop exit +-- must close a +)", + 1 + ), + R"( +LOADNIL R0 +LOADN R0 0 +L0: LOADNIL R1 +LOADN R1 0 +JUMPIF R0 L1 +L1: NEWCLOSURE R2 P0 +CAPTURE REF R0 +CAPTURE REF R1 +JUMPIF R2 L2 +CLOSEUPVALS R1 +JUMPBACK L0 +L2: CLOSEUPVALS R1 +CLOSEUPVALS R0 +RETURN R0 0 +)" + ); +} + +TEST_CASE("LoopContinueEarlyCleanup") +{ + // locals after a potential 'continue' are not accessible inside the condition and can be closed at the end of a block + CHECK_EQ( + "\n" + compileFunction( + R"( +local y +repeat + local a, b + do continue end + local c, d + local function x() + return a + b + c + d + end + + c = 2 + a = 4 + + y = x +until a +)", + 1 + ), + R"( +LOADNIL R0 +L0: LOADNIL R1 +LOADNIL R2 +JUMP L1 +LOADNIL R3 +LOADNIL R4 +NEWCLOSURE R5 P0 +CAPTURE REF R1 +CAPTURE REF R3 +LOADN R3 2 +LOADN R1 4 +MOVE R0 R5 +CLOSEUPVALS R3 +L1: JUMPIF R1 L2 +CLOSEUPVALS R1 +JUMPBACK L0 +L2: CLOSEUPVALS R1 +RETURN R0 0 +)" + ); +} + TEST_CASE("AndOrOptimizations") { // the OR/ORK optimization triggers for cutoff since lhs is simple - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function advancedRidgedFilter(value, cutoff) local cutoff = cutoff or .5 value = value - cutoff return 1 - (value < 0 and -value or value) * 1 / (1 - cutoff) end )", - 0), + 0 + ), R"( ORK R2 R1 K0 [0.5] SUB R0 R0 R2 -LOADN R4 1 -LOADN R8 0 -JUMPIFNOTLT R0 R8 L0 -MINUS R7 R0 -JUMPIF R7 L1 -L0: MOVE R7 R0 -L1: MULK R6 R7 K1 [1] -LOADN R8 1 -SUB R7 R8 R2 -DIV R5 R6 R7 -SUB R3 R4 R5 +LOADN R7 0 +JUMPIFNOTLT R0 R7 L0 +MINUS R6 R0 +JUMPIF R6 L1 +L0: MOVE R6 R0 +L1: MULK R5 R6 K1 [1] +SUBRK R6 K1 [1] R2 +DIV R4 R5 R6 +SUBRK R3 K1 [1] R4 RETURN R3 1 -)"); +)" + ); // sometimes we need to compute a boolean; this uses LOADB with an offset - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function thinSurface(surfaceGradient, surfaceThickness) return surfaceGradient > .5 - surfaceThickness*.4 and surfaceGradient < .5 + surfaceThickness*.4 end )", - 0), + 0 + ), R"( LOADB R2 0 -LOADK R4 K0 [0.5] -MULK R5 R1 K1 [0.40000000000000002] -SUB R3 R4 R5 +MULK R4 R1 K1 [0.40000000000000002] +SUBRK R3 K0 [0.5] R4 JUMPIFNOTLT R3 R0 L1 LOADK R4 K0 [0.5] MULK R5 R1 K1 [0.40000000000000002] @@ -1942,20 +2258,23 @@ JUMPIFLT R0 R3 L0 LOADB R2 0 +1 L0: LOADB R2 1 L1: RETURN R2 1 -)"); +)" + ); // sometimes we need to compute a boolean; this uses LOADB with an offset for the last op, note that first op is compiled better - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function thickSurface(surfaceGradient, surfaceThickness) return surfaceGradient < .5 - surfaceThickness*.4 or surfaceGradient > .5 + surfaceThickness*.4 end )", - 0), + 0 + ), R"( LOADB R2 1 -LOADK R4 K0 [0.5] -MULK R5 R1 K1 [0.40000000000000002] -SUB R3 R4 R5 +MULK R4 R1 K1 [0.40000000000000002] +SUBRK R3 K0 [0.5] R4 JUMPIFLT R0 R3 L1 LOADK R4 K0 [0.5] MULK R5 R1 K1 [0.40000000000000002] @@ -1964,30 +2283,38 @@ JUMPIFLT R3 R0 L0 LOADB R2 0 +1 L0: LOADB R2 1 L1: RETURN R2 1 -)"); +)" + ); // trivial ternary if with constants - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function testSurface(surface) return surface and 1 or 0 end )", - 0), + 0 + ), R"( JUMPIFNOT R0 L0 LOADN R1 1 RETURN R1 1 L0: LOADN R1 0 RETURN R1 1 -)"); +)" + ); // canonical saturate - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function saturate(x) return x < 0 and 0 or x > 1 and 1 or x end )", - 0), + 0 + ), R"( LOADN R2 0 JUMPIFNOTLT R0 R2 L0 @@ -1999,7 +2326,8 @@ LOADN R1 1 RETURN R1 1 L1: MOVE R1 R0 RETURN R1 1 -)"); +)" + ); } TEST_CASE("JumpFold") @@ -2044,7 +2372,9 @@ L0: RETURN R0 0 // in this example, we do *not* have a JUMP after RETURN in the if branch // this is important since, even though this jump is never reached, jump folding needs to be able to analyze it - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function getPerlin(x, y, z, seed, scale, raw) local seed = seed or 0 local scale = scale or 1 @@ -2055,7 +2385,8 @@ return math.noise(x / scale + (seed * 17) + masterSeed, y / scale - masterSeed, end end )", - 0), + 0 + ), R"( ORK R6 R3 K0 [0] ORK R7 R4 K1 [1] @@ -2090,7 +2421,8 @@ MUL R13 R6 R6 SUB R11 R12 R13 CALL R8 3 -1 RETURN R8 -1 -)"); +)" + ); } TEST_CASE("RecursionParse") @@ -2098,9 +2430,9 @@ TEST_CASE("RecursionParse") // The test forcibly pushes the stack limit during compilation; in NoOpt, the stack consumption is much larger so we need to reduce the limit to // not overflow the C stack. When ASAN is enabled, stack consumption increases even more. #if defined(LUAU_ENABLE_ASAN) - ScopedFastInt flag("LuauRecursionLimit", 200); + ScopedFastInt flag(FInt::LuauRecursionLimit, 200); #elif defined(_NOOPT) || defined(_DEBUG) - ScopedFastInt flag("LuauRecursionLimit", 300); + ScopedFastInt flag(FInt::LuauRecursionLimit, 300); #endif Luau::BytecodeBuilder bcb; @@ -2214,19 +2546,59 @@ TEST_CASE("RecursionParse") { CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); } -} -TEST_CASE("ArrayIndexLiteral") -{ - CHECK_EQ("\n" + compileFunction0("local arr = {} return arr[0], arr[1], arr[256], arr[257]"), R"( -NEWTABLE R0 0 0 -LOADN R2 0 -GETTABLE R1 R0 R2 -GETTABLEN R2 R0 1 -GETTABLEN R3 R0 256 -LOADN R5 257 -GETTABLE R4 R0 R5 -RETURN R1 4 + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("(", 1500) + "nil" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: () " + rep("-> ()", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("{x:", 1500) + "nil" + rep("}", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("(nil & ", 1500) + "nil" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } +} + +TEST_CASE("ArrayIndexLiteral") +{ + CHECK_EQ("\n" + compileFunction0("local arr = {} return arr[0], arr[1], arr[256], arr[257]"), R"( +NEWTABLE R0 0 0 +LOADN R2 0 +GETTABLE R1 R0 R2 +GETTABLEN R2 R0 1 +GETTABLEN R3 R0 256 +LOADN R5 257 +GETTABLE R4 R0 R5 +RETURN R1 4 )"); CHECK_EQ("\n" + compileFunction0("local arr = {} local b = ... arr[0] = b arr[1] = b arr[256] = b arr[257] = b"), R"( @@ -2260,7 +2632,9 @@ L1: RETURN R3 -1 TEST_CASE("UpvaluesLoopsBytecode") { - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function test() for i=1,10 do i = i @@ -2272,7 +2646,8 @@ function test() return 0 end )", - 1), + 1 + ), R"( LOADN R2 1 LOADN R0 10 @@ -2291,9 +2666,12 @@ L1: CLOSEUPVALS R3 FORNLOOP R0 L0 L2: LOADN R0 0 RETURN R0 1 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function test() for i in ipairs(data) do i = i @@ -2305,7 +2683,8 @@ function test() return 0 end )", - 1), + 1 + ), R"( GETIMPORT R0 1 [ipairs] GETIMPORT R1 3 [data] @@ -2323,9 +2702,12 @@ L1: CLOSEUPVALS R3 L2: FORGLOOP R0 L0 1 [inext] L3: LOADN R0 0 RETURN R0 1 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function test() local i = 0 while i < 5 do @@ -2340,7 +2722,8 @@ function test() return 0 end )", - 1), + 1 + ), R"( LOADN R0 0 L0: LOADN R1 5 @@ -2360,9 +2743,12 @@ L1: CLOSEUPVALS R1 JUMPBACK L0 L2: LOADN R1 0 RETURN R1 1 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function test() local i = 0 repeat @@ -2377,7 +2763,8 @@ function test() return 0 end )", - 1), + 1 + ), R"( LOADN R0 0 L0: LOADNIL R1 @@ -2398,7 +2785,8 @@ JUMPBACK L0 L2: CLOSEUPVALS R1 L3: LOADN R1 0 RETURN R1 1 -)"); +)" + ); } TEST_CASE("TypeAliasing") @@ -2409,6 +2797,16 @@ TEST_CASE("TypeAliasing") CHECK_NOTHROW(Luau::compileOrThrow(bcb, "type A = number local a: A = 1", options, parseOptions)); } +TEST_CASE("TypeFunction") +{ + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + + Luau::BytecodeBuilder bcb; + Luau::CompileOptions options; + Luau::ParseOptions parseOptions; + CHECK_NOTHROW(Luau::compileOrThrow(bcb, "type function a() return types.any end", options, parseOptions)); +} + TEST_CASE("DebugLineInfo") { Luau::BytecodeBuilder bcb; @@ -2519,7 +2917,7 @@ end 6: GETIMPORT R1 2 [print] 6: LOADK R2 K3 ['done!'] 6: CALL R1 1 0 -10: RETURN R0 0 +7: RETURN R0 0 3: L1: JUMPBACK L0 10: RETURN R0 0 )"); @@ -2527,7 +2925,9 @@ end TEST_CASE("DebugLineInfoRepeatUntil") { - CHECK_EQ("\n" + compileFunction0Coverage(R"( + CHECK_EQ( + "\n" + compileFunction0Coverage( + R"( local f = 0 repeat f += 1 @@ -2538,7 +2938,8 @@ repeat end until f == 0 )", - 0), + 0 + ), R"( 2: LOADN R0 0 4: L0: ADDK R0 R0 K0 [1] @@ -2551,7 +2952,8 @@ until f == 0 10: L2: JUMPXEQKN R0 K3 L3 [0] 10: JUMPBACK L0 11: L3: RETURN R0 0 -)"); +)" + ); } TEST_CASE("DebugLineInfoSubTable") @@ -2842,6 +3244,71 @@ local 8: reg 3, start pc 35 line 21, end pc 35 line 21 )"); } +TEST_CASE("DebugLocals2") +{ + const char* source = R"( +function foo(x) + repeat + local a, b + until true +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.debugLevel = 2; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +local 0: reg 1, start pc 2 line 6, no live range +local 1: reg 2, start pc 2 line 6, no live range +local 2: reg 0, start pc 0 line 4, end pc 2 line 6 +4: LOADNIL R1 +4: LOADNIL R2 +6: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLocals3") +{ + const char* source = R"( +function foo(x) + repeat + local a, b + do continue end + local c, d = 2 + until true +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.debugLevel = 2; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +local 0: reg 3, start pc 5 line 8, no live range +local 1: reg 4, start pc 5 line 8, no live range +local 2: reg 1, start pc 2 line 5, end pc 4 line 6 +local 3: reg 2, start pc 2 line 5, end pc 4 line 6 +local 4: reg 0, start pc 0 line 4, end pc 5 line 8 +4: LOADNIL R1 +4: LOADNIL R2 +5: RETURN R0 0 +6: LOADN R3 2 +6: LOADNIL R4 +8: RETURN R0 0 +)"); +} + TEST_CASE("DebugRemarks") { Luau::BytecodeBuilder bcb; @@ -2869,6 +3336,76 @@ RETURN R0 0 )"); } +TEST_CASE("DebugTypes") +{ + const char* source = R"( +local up: number = 2 + +function foo(e: vector, f: mat3, g: sequence) + local h = e * e + + for i=1,3 do + print(i) + end + + print(e * f) + print(g) + print(h) + + up += a + return a +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Types); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.vectorCtor = "vector"; + options.vectorType = "vector"; + + options.typeInfoLevel = 1; + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + options.userdataTypes = kUserdataCompileTypes; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +R0: vector [argument] +R1: mat3 [argument] +R2: userdata [argument] +U0: number +R6: any from 1 to 9 +R3: vector from 0 to 30 +MUL R3 R0 R0 +LOADN R6 1 +LOADN R4 3 +LOADN R5 1 +FORNPREP R4 L1 +L0: GETIMPORT R7 1 [print] +MOVE R8 R6 +CALL R7 1 0 +FORNLOOP R4 L0 +L1: GETIMPORT R4 1 [print] +MUL R5 R0 R1 +CALL R4 1 0 +GETIMPORT R4 1 [print] +MOVE R5 R2 +CALL R4 1 0 +GETIMPORT R4 1 [print] +MOVE R5 R3 +CALL R4 1 0 +GETUPVAL R4 0 +GETIMPORT R5 3 [a] +ADD R4 R4 R5 +SETUPVAL R4 0 +GETIMPORT R4 3 [a] +RETURN R4 1 +)"); +} + TEST_CASE("SourceRemarks") { const char* source = R"( @@ -3058,6 +3595,33 @@ RETURN R1 -1 )"); } +TEST_CASE("Fastcall3") +{ + CHECK_EQ( + "\n" + compileFunction0(R"( +local a, b, c = ... +return math.min(a, b, c) + math.clamp(a, b, c) +)"), + R"( +GETVARARGS R0 3 +FASTCALL3 19 R0 R1 R2 L0 +MOVE R5 R0 +MOVE R6 R1 +MOVE R7 R2 +GETIMPORT R4 2 [math.min] +CALL R4 3 1 +L0: FASTCALL3 46 R0 R1 R2 L1 +MOVE R6 R0 +MOVE R7 R1 +MOVE R8 R2 +GETIMPORT R5 4 [math.clamp] +CALL R5 3 1 +L1: ADD R3 R4 R5 +RETURN R3 1 +)" + ); +} + TEST_CASE("FastcallSelect") { // select(_, ...) compiles to a builtin call @@ -3071,7 +3635,8 @@ L0: RETURN R0 1 )"); // more complex example: select inside a for loop bound + select from a iterator - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local sum = 0 for i=1, select('#', ...) do sum += select(i, ...) @@ -3097,7 +3662,8 @@ CALL R4 -1 1 L2: ADD R0 R0 R4 FORNLOOP R1 L1 L3: RETURN R0 1 -)"); +)" + ); // currently we assume a single value return to avoid dealing with stack resizing CHECK_EQ("\n" + compileFunction0("return select('#', ...)"), R"( @@ -3265,19 +3831,23 @@ RETURN R0 1 // multi-level recursive capture where function isn't top-level // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() local function bar() return function() return bar() end end end )", - 1), + 1 + ), R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 -)"); +)" + ); } TEST_CASE("OutOfLocals") @@ -3625,7 +4195,8 @@ TEST_CASE("CompileBytecode") TEST_CASE("NestedNamecall") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... return obj:Method(1):Method(2):Method(3) )"), @@ -3641,23 +4212,27 @@ LOADN R3 3 NAMECALL R1 R1 K0 ['Method'] CALL R1 2 -1 RETURN R1 -1 -)"); +)" + ); } TEST_CASE("ElideLocals") { // simple local elision: all locals are constant - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = 1, 2 return a + b )"), R"( LOADN R0 3 RETURN R0 1 -)"); +)" + ); // side effecting expressions block local elision - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a = g() return a )"), @@ -3665,10 +4240,12 @@ return a GETIMPORT R0 1 [g] CALL R0 0 1 RETURN R0 1 -)"); +)" + ); // ... even if they are not used - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a = 1, g() return a )"), @@ -3677,12 +4254,14 @@ LOADN R0 1 GETIMPORT R1 1 [g] CALL R1 0 1 RETURN R0 1 -)"); +)" + ); } TEST_CASE("ConstantJumpCompare") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... local b = obj == 1 )"), @@ -3692,9 +4271,11 @@ JUMPXEQKN R0 K0 L0 [1] LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... local b = 1 == obj )"), @@ -3704,9 +4285,11 @@ JUMPXEQKN R0 K0 L0 [1] LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... local b = "Hello, Sailor!" == obj )"), @@ -3716,9 +4299,11 @@ JUMPXEQKS R0 K0 L0 ['Hello, Sailor!'] LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... local b = nil == obj )"), @@ -3728,9 +4313,11 @@ JUMPXEQKNIL R0 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... local b = true == obj )"), @@ -3740,9 +4327,11 @@ JUMPXEQKB R0 1 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... local b = nil ~= obj )"), @@ -3752,10 +4341,12 @@ JUMPXEQKNIL R0 L0 NOT LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 -)"); +)" + ); // table literals should not generate IFEQK variants - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local obj = ... local b = obj == {} )"), @@ -3766,12 +4357,14 @@ JUMPIFEQ R0 R2 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 -)"); +)" + ); } TEST_CASE("TableConstantStringIndex") { - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = { a = 2 } return t['a'] )"), @@ -3781,9 +4374,11 @@ LOADN R1 2 SETTABLEKS R1 R0 K0 ['a'] GETTABLEKS R1 R0 K0 ['a'] RETURN R1 1 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = {} t['a'] = 2 )"), @@ -3792,17 +4387,21 @@ NEWTABLE R0 0 0 LOADN R1 2 SETTABLEKS R1 R0 K0 ['a'] RETURN R0 0 -)"); +)" + ); } TEST_CASE("Coverage") { // basic statement coverage - CHECK_EQ("\n" + compileFunction0Coverage(R"( + CHECK_EQ( + "\n" + compileFunction0Coverage( + R"( print(1) print(2) )", - 1), + 1 + ), R"( 2: COVERAGE 2: GETIMPORT R0 1 [print] @@ -3813,17 +4412,21 @@ print(2) 3: LOADN R1 2 3: CALL R0 1 0 4: RETURN R0 0 -)"); +)" + ); // branching - CHECK_EQ("\n" + compileFunction0Coverage(R"( + CHECK_EQ( + "\n" + compileFunction0Coverage( + R"( if x then print(1) else print(2) end )", - 1), + 1 + ), R"( 2: COVERAGE 2: GETIMPORT R0 1 [x] @@ -3832,17 +4435,20 @@ end 3: GETIMPORT R0 3 [print] 3: LOADN R1 1 3: CALL R0 1 0 -7: RETURN R0 0 +3: RETURN R0 0 5: L0: COVERAGE 5: GETIMPORT R0 3 [print] 5: LOADN R1 2 5: CALL R0 1 0 7: RETURN R0 0 -)"); +)" + ); // branching with comments // note that commented lines don't have COVERAGE insns! - CHECK_EQ("\n" + compileFunction0Coverage(R"( + CHECK_EQ( + "\n" + compileFunction0Coverage( + R"( if x then -- first print(1) @@ -3851,7 +4457,8 @@ else print(2) end )", - 1), + 1 + ), R"( 2: COVERAGE 2: GETIMPORT R0 1 [x] @@ -3860,17 +4467,20 @@ end 4: GETIMPORT R0 3 [print] 4: LOADN R1 1 4: CALL R0 1 0 -9: RETURN R0 0 +4: RETURN R0 0 7: L0: COVERAGE 7: GETIMPORT R0 3 [print] 7: LOADN R1 2 7: CALL R0 1 0 9: RETURN R0 0 -)"); +)" + ); // expression coverage for table literals // note: duplicate COVERAGE instructions are there since we don't deduplicate expr/stat - CHECK_EQ("\n" + compileFunction0Coverage(R"( + CHECK_EQ( + "\n" + compileFunction0Coverage( + R"( local c = ... local t = { a = 1, @@ -3878,7 +4488,8 @@ local t = { c = c } )", - 2), + 2 + ), R"( 2: COVERAGE 2: COVERAGE @@ -3897,52 +4508,68 @@ local t = { 6: COVERAGE 6: SETTABLEKS R0 R1 K2 ['c'] 8: RETURN R0 0 -)"); +)" + ); } TEST_CASE("ConstantClosure") { // closures without upvalues are created when bytecode is loaded - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return function() end )", - 1), + 1 + ), R"( DUPCLOSURE R0 K0 [] RETURN R0 1 -)"); +)" + ); // they can access globals just fine - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return function() print("hi") end )", - 1), + 1 + ), R"( DUPCLOSURE R0 K0 [] RETURN R0 1 -)"); +)" + ); // if they need upvalues, we can't create them before running the code (but see SharedClosure test) - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function test() local print = print return function() print("hi") end end )", - 1), + 1 + ), R"( GETIMPORT R0 1 [print] NEWCLOSURE R1 P0 CAPTURE VAL R0 RETURN R1 1 -)"); +)" + ); // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( setfenv(1, {}) return function() print("hi") end )", - 1), + 1 + ), R"( GETIMPORT R0 1 [setfenv] LOADN R1 1 @@ -3950,39 +4577,50 @@ NEWTABLE R2 0 0 CALL R0 2 0 NEWCLOSURE R0 P0 RETURN R0 1 -)"); +)" + ); // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( if false then setfenv(1, {}) end return function() print("hi") end )", - 1), + 1 + ), R"( NEWCLOSURE R0 P0 RETURN R0 1 -)"); +)" + ); } TEST_CASE("SharedClosure") { // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local val = ... local function foo() return function() return val end end )", - 1), + 1 + ), R"( DUPCLOSURE R0 K0 [] CAPTURE UPVAL U0 RETURN R0 1 -)"); +)" + ); // ... as long as the values aren't mutated. - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local val = ... local function foo() @@ -3991,28 +4629,36 @@ end val = 5 )", - 1), + 1 + ), R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 -)"); +)" + ); // making the upvalue non-toplevel disables the optimization since it's likely that it will change - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(val) return function() return val end end )", - 1), + 1 + ), R"( NEWCLOSURE R1 P0 CAPTURE VAL R0 RETURN R1 1 -)"); +)" + ); // the upvalue analysis is transitive through local functions, which allows for code reuse to not defeat the optimization - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local val = ... local function foo() @@ -4023,17 +4669,21 @@ local function foo() return function() return bar() end end )", - 2), + 2 + ), R"( DUPCLOSURE R0 K0 ['bar'] CAPTURE UPVAL U0 DUPCLOSURE R1 K1 [] CAPTURE VAL R0 RETURN R1 1 -)"); +)" + ); // as such, if the upvalue that we reach transitively isn't top-level we fall back to newclosure - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(val) local function bar() return val @@ -4042,14 +4692,16 @@ local function foo(val) return function() return bar() end end )", - 2), + 2 + ), R"( NEWCLOSURE R1 P0 CAPTURE VAL R0 NEWCLOSURE R2 P1 CAPTURE VAL R1 RETURN R2 1 -)"); +)" + ); // we also allow recursive function captures to share the object, even when it's not top-level CHECK_EQ("\n" + compileFunction("function test() local function foo() return foo() end end", 1), R"( @@ -4060,22 +4712,28 @@ RETURN R0 0 // multi-level recursive capture where function isn't top-level fails however. // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() local function bar() return function() return bar() end end end )", - 1), + 1 + ), R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 -)"); +)" + ); // top level upvalues inside loops should not be shared -- note that the bytecode below only uses NEWCLOSURE - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,10 do print(function() return i end) end @@ -4089,7 +4747,8 @@ for i=1,10 do print(function() return j end) end )", - 3), + 3 + ), R"( LOADN R2 1 LOADN R0 10 @@ -4119,7 +4778,8 @@ CAPTURE VAL R2 CALL R3 1 0 FORNLOOP R0 L4 L5: RETURN R0 0 -)"); +)" + ); } TEST_CASE("MutableGlobals") @@ -4165,7 +4825,7 @@ RETURN R0 0 bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::CompileOptions options; const char* mutableGlobals[] = {"Game", "Workspace", "game", "plugin", "script", "shared", "workspace", NULL}; - options.mutableGlobals = &mutableGlobals[0]; + options.mutableGlobals = mutableGlobals; Luau::compileOrThrow(bcb, source, options); CHECK_EQ("\n" + bcb.dumpFunction(0), R"( @@ -4240,37 +4900,101 @@ L0: RETURN R0 -1 )"); } -TEST_CASE("TypeAssertion") +TEST_CASE("VectorFastCall3") { - // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated - CHECK_EQ("\n" + compileFunction0(R"( -print(foo() :: typeof(error("compile time"))) -)"), - R"( -GETIMPORT R0 1 [print] -GETIMPORT R1 3 [foo] -CALL R1 0 1 -CALL R0 1 0 -RETURN R0 0 -)"); + const char* source = R"( +local a, b, c = ... +return Vector3.new(a, b, c) +)"; - // note that above, foo() is treated as single-arg function; removing type assertion changes the bytecode - CHECK_EQ("\n" + compileFunction0(R"( -print(foo()) -)"), - R"( -GETIMPORT R0 1 [print] -GETIMPORT R1 3 [foo] -CALL R1 0 -1 -CALL R0 -1 0 + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETVARARGS R0 3 +FASTCALL3 54 R0 R1 R2 L0 +MOVE R4 R0 +MOVE R5 R1 +MOVE R6 R2 +GETIMPORT R3 2 [Vector3.new] +CALL R3 3 -1 +L0: RETURN R3 -1 +)"); +} + +TEST_CASE("VectorLiterals") +{ + CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, /*enableVectors*/ true), R"( +LOADK R0 K0 [1, 2, 3] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, /*enableVectors*/ true), R"( +GETIMPORT R0 1 [print] +LOADK R1 K2 [1, 2, 3] +CALL R0 1 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, /*enableVectors*/ true), R"( +GETIMPORT R0 1 [print] +LOADK R1 K2 [1, 2, 3, 4] +CALL R0 1 0 RETURN R0 0 )"); + + CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, /*enableVectors*/ true), R"( +LOADK R0 K0 [0, 0, 0] +LOADK R1 K1 [-0, 0, 0] +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, /*enableVectors*/ true), R"( +LOADK R0 K0 ['vector'] +RETURN R0 1 +)"); +} + +TEST_CASE("TypeAssertion") +{ + // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated + CHECK_EQ( + "\n" + compileFunction0(R"( +print(foo() :: typeof(error("compile time"))) +)"), + R"( +GETIMPORT R0 1 [print] +GETIMPORT R1 3 [foo] +CALL R1 0 1 +CALL R0 1 0 +RETURN R0 0 +)" + ); + + // note that above, foo() is treated as single-arg function; removing type assertion changes the bytecode + CHECK_EQ( + "\n" + compileFunction0(R"( +print(foo()) +)"), + R"( +GETIMPORT R0 1 [print] +GETIMPORT R1 3 [foo] +CALL R1 0 -1 +CALL R0 -1 0 +RETURN R0 0 +)" + ); } TEST_CASE("Arithmetics") { // basic arithmetics codegen with non-constants - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = ... return a + b, a - b, a / b, a * b, a % b, a ^ b )"), @@ -4283,11 +5007,13 @@ MUL R5 R0 R1 MOD R6 R0 R1 POW R7 R0 R1 RETURN R2 6 -)"); +)" + ); // basic arithmetics codegen with constants on the right side // note that we don't simplify these expressions as we don't know the type of a - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a = ... return a + 1, a - 1, a / 1, a * 1, a % 1, a ^ 1 )"), @@ -4300,20 +5026,25 @@ MULK R4 R0 K0 [1] MODK R5 R0 K0 [1] POWK R6 R0 K0 [1] RETURN R1 6 -)"); +)" + ); } TEST_CASE("LoopUnrollBasic") { // forward loops - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=1,2 do t[i] = i end return t )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 2 LOADN R1 1 @@ -4321,17 +5052,22 @@ SETTABLEN R1 R0 1 LOADN R1 2 SETTABLEN R1 R0 2 RETURN R0 1 -)"); +)" + ); // backward loops - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=2,1,-1 do t[i] = i end return t )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 0 LOADN R1 2 @@ -4339,17 +5075,22 @@ SETTABLEN R1 R0 2 LOADN R1 1 SETTABLEN R1 R0 1 RETURN R0 1 -)"); +)" + ); // loops with step that doesn't divide to-from - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=1,4,2 do t[i] = i end return t )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 0 LOADN R1 1 @@ -4357,23 +5098,31 @@ SETTABLEN R1 R0 1 LOADN R1 3 SETTABLEN R1 R0 3 RETURN R0 1 -)"); +)" + ); // empty loops - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=2,1 do end )", - 0, 2), + 0, + 2 + ), R"( RETURN R0 0 -)"); +)" + ); } TEST_CASE("LoopUnrollNested") { // we can unroll nested loops just fine - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=0,1 do for j=0,1 do @@ -4381,7 +5130,9 @@ for i=0,1 do end end )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 0 LOADN R1 0 @@ -4393,10 +5144,13 @@ SETTABLEN R1 R0 3 LOADN R1 0 SETTABLEN R1 R0 4 RETURN R0 0 -)"); +)" + ); // if the inner loop is too expensive, we won't unroll the outer loop though, but we'll still unroll the inner loop! - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=0,3 do for j=0,3 do @@ -4404,7 +5158,9 @@ for i=0,3 do end end )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 0 LOADN R3 0 @@ -4429,10 +5185,13 @@ LOADN R5 0 SETTABLE R5 R0 R4 FORNLOOP R1 L0 L1: RETURN R0 0 -)"); +)" + ); // note, we sometimes can even unroll a loop with varying internal iterations - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=0,1 do for j=0,i do @@ -4440,7 +5199,9 @@ for i=0,1 do end end )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 0 LOADN R1 0 @@ -4450,17 +5211,22 @@ SETTABLEN R1 R0 3 LOADN R1 0 SETTABLEN R1 R0 4 RETURN R0 0 -)"); +)" + ); } TEST_CASE("LoopUnrollUnsupported") { // can't unroll loops with non-constant bounds - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=x,y,z do end )", - 0, 2), + 0, + 2 + ), R"( GETIMPORT R2 1 [x] GETIMPORT R0 3 [y] @@ -4468,14 +5234,19 @@ GETIMPORT R1 5 [z] FORNPREP R0 L1 L0: FORNLOOP R0 L0 L1: RETURN R0 0 -)"); +)" + ); // can't unroll loops with bounds where we can't compute trip count - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,1,0 do end )", - 0, 2), + 0, + 2 + ), R"( LOADN R2 1 LOADN R0 1 @@ -4483,14 +5254,19 @@ LOADN R1 0 FORNPREP R0 L1 L0: FORNLOOP R0 L0 L1: RETURN R0 0 -)"); +)" + ); // can't unroll loops with bounds that might be imprecise (non-integer) - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,2,0.1 do end )", - 0, 2), + 0, + 2 + ), R"( LOADN R2 1 LOADN R0 2 @@ -4498,14 +5274,19 @@ LOADK R1 K0 [0.10000000000000001] FORNPREP R0 L1 L0: FORNLOOP R0 L0 L1: RETURN R0 0 -)"); +)" + ); // can't unroll loops if the bounds are too large, as it might overflow trip count math - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=4294967295,4294967296 do end )", - 0, 2), + 0, + 2 + ), R"( LOADK R2 K0 [4294967295] LOADK R0 K1 [4294967296] @@ -4513,25 +5294,30 @@ LOADN R1 1 FORNPREP R0 L1 L0: FORNLOOP R0 L0 L1: RETURN R0 0 -)"); +)" + ); } TEST_CASE("LoopUnrollControlFlow") { ScopedFastInt sfis[] = { - {"LuauCompileLoopUnrollThreshold", 50}, - {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + {FInt::LuauCompileLoopUnrollThreshold, 50}, + {FInt::LuauCompileLoopUnrollThresholdMaxBoost, 300}, }; // break jumps to the end - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,3 do if math.random() < 0.5 then break end end )", - 0, 2), + 0, + 2 + ), R"( GETIMPORT R0 2 [math.random] CALL R0 0 1 @@ -4546,10 +5332,13 @@ CALL R0 0 1 LOADK R1 K3 [0.5] JUMPIFLT R0 R1 L0 L0: RETURN R0 0 -)"); +)" + ); // continue jumps to the next iteration - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,3 do if math.random() < 0.5 then continue @@ -4557,7 +5346,9 @@ for i=1,3 do print(i) end )", - 0, 2), + 0, + 2 + ), R"( GETIMPORT R0 2 [math.random] CALL R0 0 1 @@ -4581,10 +5372,13 @@ GETIMPORT R0 5 [print] LOADN R1 3 CALL R0 1 0 L2: RETURN R0 0 -)"); +)" + ); // continue needs to properly close upvalues - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,1 do local j = global(i) print(function() return j end) @@ -4594,7 +5388,9 @@ for i=1,1 do j += 1 end )", - 1, 2), + 1, + 2 + ), R"( GETIMPORT R0 1 [global] LOADN R1 1 @@ -4612,10 +5408,13 @@ RETURN R0 0 L0: ADDK R0 R0 K8 [1] CLOSEUPVALS R0 RETURN R0 0 -)"); +)" + ); // this weird contraption just disappears - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,1 do for j=1,1 do if i == 1 then @@ -4626,22 +5425,29 @@ for i=1,1 do end end )", - 0, 2), + 0, + 2 + ), R"( RETURN R0 0 RETURN R0 0 -)"); +)" + ); } TEST_CASE("LoopUnrollNestedClosure") { // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,2 do local x = function() return i end end )", - 1, 2), + 1, + 2 + ), R"( LOADN R1 1 NEWCLOSURE R0 P0 @@ -4650,27 +5456,30 @@ LOADN R1 2 NEWCLOSURE R0 P0 CAPTURE VAL R1 RETURN R0 0 -)"); +)" + ); } TEST_CASE("LoopUnrollCost") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - ScopedFastInt sfis[] = { - {"LuauCompileLoopUnrollThreshold", 25}, - {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + {FInt::LuauCompileLoopUnrollThreshold, 25}, + {FInt::LuauCompileLoopUnrollThresholdMaxBoost, 300}, }; // loops with short body - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=1,10 do t[i] = i end return t )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 10 LOADN R1 1 @@ -4694,17 +5503,22 @@ SETTABLEN R1 R0 9 LOADN R1 10 SETTABLEN R1 R0 10 RETURN R0 1 -)"); +)" + ); // loops with body that's too long - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=1,100 do t[i] = i end return t )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 0 LOADN R3 1 @@ -4714,17 +5528,22 @@ FORNPREP R1 L1 L0: SETTABLE R3 R0 R3 FORNLOOP R1 L0 L1: RETURN R0 1 -)"); +)" + ); // loops with body that's long but has a high boost factor due to constant folding - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=1,25 do t[i] = i * i * i end return t )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 0 LOADN R1 1 @@ -4778,17 +5597,22 @@ SETTABLEN R1 R0 24 LOADN R1 15625 SETTABLEN R1 R0 25 RETURN R0 1 -)"); +)" + ); // loops with body that's long and doesn't have a high boost factor - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local t = {} for i=1,10 do t[i] = math.abs(math.sin(i)) end return t )", - 0, 2), + 0, + 2 + ), R"( NEWTABLE R0 0 10 LOADN R3 1 @@ -4805,19 +5629,24 @@ CALL R4 1 1 L2: SETTABLE R4 R0 R3 FORNLOOP R1 L0 L3: RETURN R0 1 -)"); +)" + ); } TEST_CASE("LoopUnrollMutable") { // can't unroll loops that mutate iteration variable - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( for i=1,3 do i = 3 print(i) -- should print 3 three times in a row end )", - 0, 2), + 0, + 2 + ), R"( LOADN R2 1 LOADN R0 3 @@ -4830,25 +5659,30 @@ MOVE R5 R3 CALL R4 1 0 FORNLOOP R0 L0 L1: RETURN R0 0 -)"); +)" + ); } TEST_CASE("LoopUnrollCostBuiltins") { ScopedFastInt sfis[] = { - {"LuauCompileLoopUnrollThreshold", 25}, - {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + {FInt::LuauCompileLoopUnrollThreshold, 25}, + {FInt::LuauCompileLoopUnrollThresholdMaxBoost, 300}, }; // this loop uses builtins and is close to the cost budget so it's important that we model builtins as cheaper than regular calls - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function cipher(block, nonce) for i = 0,3 do block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff) end end )", - 0, 2), + 0, + 2 + ), R"( FASTCALL2K 39 R1 K0 L0 [0] MOVE R4 R1 @@ -4891,10 +5725,13 @@ GETIMPORT R2 6 [bit32.band] CALL R2 2 1 L7: SETTABLEN R2 R0 4 RETURN R0 0 -)"); +)" + ); // note that if we break compiler's ability to reason about bit32 builtin the loop is no longer unrolled as it's too expensive - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( bit32 = {} function cipher(block, nonce) @@ -4903,7 +5740,9 @@ function cipher(block, nonce) end end )", - 0, 2), + 0, + 2 + ), R"( LOADN R4 0 LOADN R2 3 @@ -4922,17 +5761,22 @@ CALL R6 2 1 SETTABLE R6 R0 R5 FORNLOOP R2 L0 L1: RETURN R0 0 -)"); +)" + ); // additionally, if we pass too many constants the builtin stops being cheap because of argument setup - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function cipher(block, nonce) for i = 0,3 do block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff, 0xff, 0xff, 0xff, 0xff) end end )", - 0, 2), + 0, + 2 + ), R"( LOADN R4 0 LOADN R2 3 @@ -4955,13 +5799,16 @@ CALL R6 6 1 L2: SETTABLE R6 R0 R5 FORNLOOP R2 L0 L3: RETURN R0 0 -)"); +)" + ); } TEST_CASE("InlineBasic") { // inline function that returns a constant - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() return 42 end @@ -4969,15 +5816,20 @@ end local x = foo() return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 -)"); +)" + ); // inline function that returns the argument - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -4985,15 +5837,20 @@ end local x = foo(42) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 -)"); +)" + ); // inline function that returns one of the two arguments - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a, b, c) if a then return b @@ -5005,7 +5862,9 @@ end local x = foo(true, math.random(), 5) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETIMPORT R2 3 [math.random] @@ -5013,10 +5872,13 @@ CALL R2 0 1 MOVE R1 R2 RETURN R1 1 RETURN R1 1 -)"); +)" + ); // inline function that returns one of the two arguments - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a, b, c) if a then return b @@ -5028,7 +5890,9 @@ end local x = foo(true, 5, math.random()) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETIMPORT R2 3 [math.random] @@ -5036,13 +5900,16 @@ CALL R2 0 1 LOADN R1 5 RETURN R1 1 RETURN R1 1 -)"); +)" + ); } -TEST_CASE("InlineBasicProhibited") +TEST_CASE("InlineProhibited") { // we can't inline variadic functions - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(...) return 42 end @@ -5050,16 +5917,21 @@ end local x = foo() return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 1 RETURN R1 1 -)"); +)" + ); // we can't inline any functions in modules with getfenv/setfenv - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() return 42 end @@ -5068,7 +5940,9 @@ local x = foo() getfenv() return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 @@ -5076,13 +5950,86 @@ CALL R1 0 1 GETIMPORT R2 2 [getfenv] CALL R2 0 0 RETURN R1 1 -)"); +)" + ); +} + +TEST_CASE("InlineProhibitedRecursion") +{ + // we can't inline recursive invocations of functions in the functions + // this is actually profitable in certain cases, but it complicates the compiler as it means a local has multiple registers/values + + // in this example, inlining is blocked because we're compiling fact() and we don't yet have the cost model / profitability data for fact() + CHECK_EQ( + "\n" + compileFunction( + R"( +local function fact(n) + return if n <= 1 then 1 else fact(n-1)*n +end + +return fact +)", + 0, + 2 + ), + R"( +LOADN R2 1 +JUMPIFNOTLE R0 R2 L0 +LOADN R1 1 +RETURN R1 1 +L0: GETUPVAL R2 0 +SUBK R3 R0 K0 [1] +CALL R2 1 1 +MUL R1 R2 R0 +RETURN R1 1 +)" + ); + + // in this example, inlining of fact() succeeds, but the nested call to fact() fails since fact is already on the inline stack + CHECK_EQ( + "\n" + compileFunction( + R"( +local function fact(n) + return if n <= 1 then 1 else fact(n-1)*n +end + +local function factsafe(n) + assert(n >= 1) + return fact(n) +end + +return factsafe +)", + 1, + 2 + ), + R"( +LOADN R3 1 +JUMPIFLE R3 R0 L0 +LOADB R2 0 +1 +L0: LOADB R2 1 +L1: FASTCALL1 1 R2 L2 +GETIMPORT R1 1 [assert] +CALL R1 1 0 +L2: LOADN R2 1 +JUMPIFNOTLE R0 R2 L3 +LOADN R1 1 +RETURN R1 1 +L3: GETUPVAL R2 0 +SUBK R3 R0 K2 [1] +CALL R2 1 1 +MUL R1 R2 R0 +RETURN R1 1 +)" + ); } TEST_CASE("InlineNestedLoops") { // functions with basic loops get inlined - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(t) for i=1,3 do t[i] = i @@ -5093,7 +6040,9 @@ end local x = foo({}) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] NEWTABLE R2 0 0 @@ -5105,10 +6054,13 @@ LOADN R3 3 SETTABLEN R3 R2 3 MOVE R1 R2 RETURN R1 1 -)"); +)" + ); // we can even unroll the loops based on inline argument - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(t, n) for i=1, n do t[i] = i @@ -5119,7 +6071,9 @@ end local x = foo({}, 3) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] NEWTABLE R2 0 0 @@ -5131,13 +6085,16 @@ LOADN R3 3 SETTABLEN R3 R2 3 MOVE R1 R2 RETURN R1 1 -)"); +)" + ); } TEST_CASE("InlineNestedClosures") { // we can inline functions that contain/return functions - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(x) return function(y) return x + y end end @@ -5145,7 +6102,9 @@ end local x = foo(1)(2) return x )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R2 1 @@ -5154,13 +6113,16 @@ CAPTURE VAL R2 LOADN R2 2 CALL R1 1 1 RETURN R1 1 -)"); +)" + ); } TEST_CASE("InlineMutate") { // if the argument is mutated, it gets a register even if the value is constant - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) a = a or 5 return a @@ -5169,17 +6131,22 @@ end local x = foo(42) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R2 42 ORK R2 R2 K1 [5] MOVE R1 R2 RETURN R1 1 -)"); +)" + ); // if the argument is a local, it can be used directly - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -5188,16 +6155,21 @@ local x = ... local y = foo(x) return y )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 MOVE R2 R1 RETURN R2 1 -)"); +)" + ); // ... but if it's mutated, we move it in case it is mutated through a capture during the inlined function - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -5207,7 +6179,9 @@ x = nil local y = foo(x) return y )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 @@ -5215,10 +6189,13 @@ LOADNIL R1 MOVE R3 R1 MOVE R2 R3 RETURN R2 1 -)"); +)" + ); // we also don't inline functions if they have been assigned to - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -5228,20 +6205,25 @@ foo = foo local x = foo(42) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 1 RETURN R1 1 -)"); +)" + ); } TEST_CASE("InlineUpval") { // if the argument is an upvalue, we naturally need to copy it to a local - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -5253,15 +6235,20 @@ function bar() return x end )", - 1, 2), + 1, + 2 + ), R"( GETUPVAL R1 0 MOVE R0 R1 RETURN R0 1 -)"); +)" + ); // if the function uses an upvalue it's more complicated, because the lexical upvalue may become a local - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local b = ... local function foo(a) @@ -5271,7 +6258,9 @@ end local x = foo(42) return x )", - 1, 2), + 1, + 2 + ), R"( GETVARARGS R0 1 DUPCLOSURE R1 K0 ['foo'] @@ -5279,10 +6268,13 @@ CAPTURE VAL R0 LOADN R3 42 ADD R2 R3 R0 RETURN R2 1 -)"); +)" + ); // sometimes the lexical upvalue is deep enough that it's still an upvalue though - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local b = ... function bar() @@ -5294,7 +6286,9 @@ function bar() return x end )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] CAPTURE UPVAL U0 @@ -5302,13 +6296,16 @@ LOADN R2 42 GETUPVAL R3 0 ADD R1 R2 R3 RETURN R1 1 -)"); +)" + ); } TEST_CASE("InlineCapture") { // if the argument is captured by a nested closure, normally we can rely on capture by value - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return function() return a end end @@ -5317,17 +6314,22 @@ local x = ... local y = foo(x) return y )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 NEWCLOSURE R2 P1 CAPTURE VAL R1 RETURN R2 1 -)"); +)" + ); // if the argument is a constant, we move it to a register so that capture by value can happen - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return function() return a end end @@ -5335,17 +6337,22 @@ end local y = foo(42) return y )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R2 42 NEWCLOSURE R1 P1 CAPTURE VAL R2 RETURN R1 1 -)"); +)" + ); // if the argument is an externally mutated variable, we copy it to an argument and capture it by value - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return function() return a end end @@ -5354,7 +6361,9 @@ local x x = 42 local y = foo(x) return y )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 @@ -5363,10 +6372,13 @@ MOVE R3 R1 NEWCLOSURE R2 P1 CAPTURE VAL R3 RETURN R2 1 -)"); +)" + ); // finally, if the argument is mutated internally, we must capture it by reference and close the upvalue - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) a = a or 42 return function() return a end @@ -5375,7 +6387,9 @@ end local y = foo() return y )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADNIL R2 @@ -5384,10 +6398,13 @@ NEWCLOSURE R1 P1 CAPTURE REF R2 CLOSEUPVALS R2 RETURN R1 1 -)"); +)" + ); // note that capture might need to be performed during the fallthrough block - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) a = a or 42 print(function() return a end) @@ -5397,7 +6414,9 @@ local x = ... local y = foo(x) return y )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 @@ -5410,12 +6429,15 @@ CALL R4 1 0 LOADNIL R2 CLOSEUPVALS R3 RETURN R2 1 -)"); +)" + ); // note that mutation and capture might be inside internal control flow // TODO: this has an oddly redundant CLOSEUPVALS after JUMP; it's not due to inlining, and is an artifact of how StatBlock/StatReturn interact // fixing this would reduce the number of redundant CLOSEUPVALS a bit but it only affects bytecode size as these instructions aren't executed - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) if not a then local b b = 42 @@ -5427,7 +6449,9 @@ local x = ... local y = foo(x) return y, x )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 @@ -5443,29 +6467,37 @@ L0: LOADNIL R2 L1: MOVE R3 R2 MOVE R4 R1 RETURN R3 2 -)"); +)" + ); } TEST_CASE("InlineFallthrough") { // if the function doesn't return, we still fill the results with nil - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() end local a, b = foo() return a, b )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 LOADNIL R2 RETURN R1 2 -)"); +)" + ); // this happens even if the function returns conditionally - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) if a then return 42 end end @@ -5473,29 +6505,37 @@ end local a, b = foo(false) return a, b )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 LOADNIL R2 RETURN R1 2 -)"); +)" + ); // note though that we can't inline a function like this in multret context // this is because we don't have a SETTOP instruction - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() end return foo() )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 -1 RETURN R1 -1 -)"); +)" + ); } TEST_CASE("InlineArgMismatch") @@ -5503,7 +6543,9 @@ TEST_CASE("InlineArgMismatch") // when inlining a function, we must respect all the usual rules // caller might not have enough arguments - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -5511,15 +6553,20 @@ end local x = foo() return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADNIL R1 RETURN R1 1 -)"); +)" + ); // caller might be using multret for arguments - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a, b) return a + b end @@ -5527,7 +6574,9 @@ end local x = foo(math.modf(1.5)) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADK R3 K1 [1.5] @@ -5536,10 +6585,13 @@ GETIMPORT R2 4 [math.modf] CALL R2 1 2 L0: ADD R1 R2 R3 RETURN R1 1 -)"); +)" + ); // caller might be using varargs for arguments - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a, b) return a + b end @@ -5547,16 +6599,21 @@ end local x = foo(...) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R2 2 ADD R1 R2 R3 RETURN R1 1 -)"); +)" + ); // caller might have too many arguments, but we still need to compute them for side effects - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -5564,17 +6621,22 @@ end local x = foo(42, print()) return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETIMPORT R2 2 [print] CALL R2 0 1 LOADN R1 42 RETURN R1 1 -)"); +)" + ); // caller might not have enough arguments, and the arg might be mutated so it needs a register - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) a = 42 return a @@ -5583,21 +6645,26 @@ end local x = foo() return x )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADNIL R2 LOADN R2 42 MOVE R1 R2 RETURN R1 1 -)"); +)" + ); } TEST_CASE("InlineMultiple") { // we call this with a different set of variable/constant args - CHECK_EQ("\n" + compileFunction(R"( -local function foo(a, b) + CHECK_EQ( + "\n" + compileFunction( + R"( +local function foo(a, b) return a + b end @@ -5608,7 +6675,9 @@ local c = foo(1, 2) local d = foo(x, y) return a, b, c, d )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 2 @@ -5618,13 +6687,16 @@ ADD R4 R5 R1 LOADN R5 3 ADD R6 R1 R2 RETURN R3 4 -)"); +)" + ); } TEST_CASE("InlineChain") { // inline a chain of functions - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a, b) return a + b end @@ -5639,7 +6711,9 @@ end return (baz()) )", - 3, 2), + 3, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] DUPCLOSURE R1 K1 ['bar'] @@ -5648,51 +6722,64 @@ LOADN R4 43 LOADN R5 41 MUL R3 R4 R5 RETURN R3 1 -)"); +)" + ); } TEST_CASE("InlineThresholds") { ScopedFastInt sfis[] = { - {"LuauCompileInlineThreshold", 25}, - {"LuauCompileInlineThresholdMaxBoost", 300}, - {"LuauCompileInlineDepth", 2}, + {FInt::LuauCompileInlineThreshold, 25}, + {FInt::LuauCompileInlineThresholdMaxBoost, 300}, + {FInt::LuauCompileInlineDepth, 2}, }; // this function has enormous register pressure (50 regs) so we choose not to inline it - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() return {{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}} end return (foo()) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 1 RETURN R1 1 -)"); +)" + ); // this function has less register pressure but a large cost - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo() return {},{},{},{},{} end return (foo()) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 CALL R1 0 1 RETURN R1 1 -)"); +)" + ); // this chain of function is of length 3 but our limit in this test is 2, so we call foo twice - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a, b) return a + b end @@ -5707,7 +6794,9 @@ end return (baz()) )", - 3, 2), + 3, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] DUPCLOSURE R1 K1 ['bar'] @@ -5722,18 +6811,23 @@ LOADN R7 -1 CALL R5 2 1 MUL R3 R4 R5 RETURN R3 1 -)"); +)" + ); } TEST_CASE("InlineIIFE") { // IIFE with arguments - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function choose(a, b, c) return ((function(a, b, c) if a then return b else return c end end)(a, b, c)) end )", - 1, 2), + 1, + 2 + ), R"( JUMPIFNOT R0 L0 MOVE R3 R1 @@ -5741,15 +6835,20 @@ RETURN R3 1 L0: MOVE R3 R2 RETURN R3 1 RETURN R3 1 -)"); +)" + ); // IIFE with upvalues - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( function choose(a, b, c) return ((function() if a then return b else return c end end)()) end )", - 1, 2), + 1, + 2 + ), R"( JUMPIFNOT R0 L0 MOVE R3 R1 @@ -5757,35 +6856,120 @@ RETURN R3 1 L0: MOVE R3 R2 RETURN R3 1 RETURN R3 1 -)"); +)" + ); } TEST_CASE("InlineRecurseArguments") { - // we can't inline a function if it's used to compute its own arguments - CHECK_EQ("\n" + compileFunction(R"( + // the example looks silly but we preserve it verbatim as it was found by fuzzer for a previous version of the compiler + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a, b) end foo(foo(foo,foo(foo,foo))[foo]) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] -MOVE R2 R0 -MOVE R3 R0 -MOVE R4 R0 -MOVE R5 R0 -MOVE R6 R0 -CALL R4 2 -1 -CALL R2 -1 1 +LOADNIL R3 +LOADNIL R2 GETTABLE R1 R2 R0 RETURN R0 0 -)"); +)" + ); + + // verify that invocations of the inlined function in any position for computing the arguments to itself compile + CHECK_EQ( + "\n" + compileFunction( + R"( +local function foo(a, b) + return a + b +end + +local x, y, z = ... + +return foo(foo(x, y), foo(z, 1)) +)", + 1, + 2 + ), + R"( +DUPCLOSURE R0 K0 ['foo'] +GETVARARGS R1 3 +ADD R5 R1 R2 +ADDK R6 R3 K1 [1] +ADD R4 R5 R6 +RETURN R4 1 +)" + ); + + // verify that invocations of the inlined function in any position for computing the arguments to itself compile, including constants and locals + // note that foo(k1, k2) doesn't get constant folded, so there's still actual math emitted for some of the calls below + CHECK_EQ( + "\n" + compileFunction( + R"( +local function foo(a, b) + return a + b +end + +local x, y, z = ... + +return + foo(foo(1, 2), 3), + foo(1, foo(2, 3)), + foo(x, foo(2, 3)), + foo(x, foo(y, 3)), + foo(x, foo(y, z)), + foo(x+0, foo(y, z)), + foo(x+0, foo(y+0, z)), + foo(x+0, foo(y, z+0)), + foo(1, foo(x, y)) +)", + 1, + 2 + ), + R"( +DUPCLOSURE R0 K0 ['foo'] +GETVARARGS R1 3 +LOADN R5 3 +ADDK R4 R5 K1 [3] +LOADN R6 5 +LOADN R7 1 +ADD R5 R7 R6 +LOADN R7 5 +ADD R6 R1 R7 +ADDK R8 R2 K1 [3] +ADD R7 R1 R8 +ADD R9 R2 R3 +ADD R8 R1 R9 +ADDK R10 R1 K2 [0] +ADD R11 R2 R3 +ADD R9 R10 R11 +ADDK R11 R1 K2 [0] +ADDK R13 R2 K2 [0] +ADD R12 R13 R3 +ADD R10 R11 R12 +ADDK R12 R1 K2 [0] +ADDK R14 R3 K2 [0] +ADD R13 R2 R14 +ADD R11 R12 R13 +ADD R13 R1 R2 +LOADN R14 1 +ADD R12 R14 R13 +RETURN R4 9 +)" + ); } TEST_CASE("InlineFastCallK") { - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function set(l0) rawset({}, l0) end @@ -5793,7 +6977,9 @@ end set(false) set({}) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['set'] NEWTABLE R2 0 0 @@ -5808,12 +6994,15 @@ MOVE R4 R1 GETIMPORT R2 3 [rawset] CALL R2 2 0 L1: RETURN R0 0 -)"); +)" + ); } TEST_CASE("InlineExprIndexK") { - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local _ = function(l0) local _ = nil while _(_)[_] do @@ -5834,7 +7023,9 @@ end end end )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 [] L0: LOADNIL R4 @@ -5869,13 +7060,16 @@ RETURN R2 1 LOADB R2 1 RETURN R2 1 L3: RETURN R0 0 -)"); +)" + ); } TEST_CASE("InlineHiddenMutation") { // when the argument is assigned inside the function, we can't reuse the local - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) a = 42 return a @@ -5885,7 +7079,9 @@ local x = ... local y = foo(x :: number) return y )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 @@ -5893,10 +7089,13 @@ MOVE R3 R1 LOADN R3 42 MOVE R2 R3 RETURN R2 1 -)"); +)" + ); // and neither can we do that when it's assigned outside the function - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) mutator() return a @@ -5908,7 +7107,9 @@ mutator = function() x = 42 end local y = foo(x :: number) return y )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] GETVARARGS R1 1 @@ -5921,47 +7122,58 @@ CALL R4 0 0 MOVE R2 R3 CLOSEUPVALS R1 RETURN R2 1 -)"); +)" + ); } TEST_CASE("InlineMultret") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a() end return foo(42) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 RETURN R1 -1 -)"); +)" + ); // however, if we can deduce statically that a function always returns a single value, the inlining will work - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end return foo(42) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 -)"); +)" + ); // this analysis will also propagate through other functions - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end @@ -5972,23 +7184,30 @@ end return bar(42) )", - 2, 2), + 2, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] DUPCLOSURE R1 K1 ['bar'] LOADN R2 42 RETURN R2 1 -)"); +)" + ); // we currently don't do this analysis fully for recursive functions since they can't be inlined anyway - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return foo(a) end return foo(42) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] CAPTURE VAL R0 @@ -5996,48 +7215,59 @@ MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 RETURN R1 -1 -)"); +)" + ); // we do this for builtins though as we assume getfenv is not used or is not changing arity - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return math.abs(a) end return foo(42) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 -)"); +)" + ); } TEST_CASE("ReturnConsecutive") { // we can return a single local directly - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local x = ... return x )"), R"( GETVARARGS R0 1 RETURN R0 1 -)"); +)" + ); // or multiple, when they are allocated in consecutive registers - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local x, y = ... return x, y )"), R"( GETVARARGS R0 2 RETURN R0 2 -)"); +)" + ); // but not if it's an expression - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local x, y = ... return x, y + 1 )"), @@ -6046,10 +7276,12 @@ GETVARARGS R0 2 MOVE R2 R0 ADDK R3 R1 K0 [1] RETURN R2 2 -)"); +)" + ); // or a local with wrong register number - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local x, y = ... return y, x )"), @@ -6058,48 +7290,60 @@ GETVARARGS R0 2 MOVE R2 R1 MOVE R3 R0 RETURN R2 2 -)"); +)" + ); // also double check the optimization doesn't trip on no-argument return (these are rare) - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( return )"), R"( RETURN R0 0 -)"); +)" + ); // this optimization also works in presence of group / type casts - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local x, y = ... return (x), y :: number )"), R"( GETVARARGS R0 2 RETURN R0 2 -)"); +)" + ); } TEST_CASE("OptimizationLevel") { // at optimization level 1, no inlining is performed - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end return foo(42) )", - 1, 1), + 1, + 1 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 RETURN R1 -1 -)"); +)" + ); // you can override the level from 1 to 2 to force it - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( --!optimize 2 local function foo(a) return a @@ -6107,30 +7351,40 @@ end return foo(42) )", - 1, 1), + 1, + 1 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 -)"); +)" + ); // you can also override it externally - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function foo(a) return a end return foo(42) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] LOADN R1 42 RETURN R1 1 -)"); +)" + ); // ... after which you can downgrade it back via hot comment - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( --!optimize 1 local function foo(a) return a @@ -6138,19 +7392,24 @@ end return foo(42) )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['foo'] MOVE R1 R0 LOADN R2 42 CALL R1 1 -1 RETURN R1 -1 -)"); +)" + ); } TEST_CASE("BuiltinFolding") { - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return math.abs(-42), math.acos(1), @@ -6168,7 +7427,7 @@ return math.log10(100), math.log(1), math.log(4, 2), - math.log(27, 3), + math.log(64, 4), math.max(1, 2, 3), math.min(1, 2, 3), math.pow(3, 3), @@ -6205,7 +7464,9 @@ return typeof(nil), (type("fin")) )", - 0, 2), + 0, + 2 + ), R"( LOADN R0 42 LOADN R1 0 @@ -6260,14 +7521,15 @@ LOADN R49 2 LOADK R50 K3 ['nil'] LOADK R51 K4 ['string'] RETURN R0 52 -)"); +)" + ); } TEST_CASE("BuiltinFoldingProhibited") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return math.abs(), math.max(1, true), @@ -6280,7 +7542,9 @@ return bit32.btest(1, true), math.min(1, true) )", - 0, 2), + 0, + 2 + ), R"( FASTCALL 2 L0 GETIMPORT R0 2 [math.abs] @@ -6332,55 +7596,18 @@ LOADK R11 K3 [true] GETIMPORT R9 26 [math.min] CALL R9 2 1 L9: RETURN R0 10 -)"); +)" + ); } TEST_CASE("BuiltinFoldingProhibitedCoverage") { const char* builtins[] = { - "math.abs", - "math.acos", - "math.asin", - "math.atan2", - "math.atan", - "math.ceil", - "math.cosh", - "math.cos", - "math.deg", - "math.exp", - "math.floor", - "math.fmod", - "math.ldexp", - "math.log10", - "math.log", - "math.max", - "math.min", - "math.pow", - "math.rad", - "math.sinh", - "math.sin", - "math.sqrt", - "math.tanh", - "math.tan", - "bit32.arshift", - "bit32.band", - "bit32.bnot", - "bit32.bor", - "bit32.bxor", - "bit32.btest", - "bit32.extract", - "bit32.lrotate", - "bit32.lshift", - "bit32.replace", - "bit32.rrotate", - "bit32.rshift", - "type", - "string.byte", - "string.len", - "typeof", - "math.clamp", - "math.sign", - "math.round", + "math.abs", "math.acos", "math.asin", "math.atan2", "math.atan", "math.ceil", "math.cosh", "math.cos", "math.deg", + "math.exp", "math.floor", "math.fmod", "math.ldexp", "math.log10", "math.log", "math.max", "math.min", "math.pow", + "math.rad", "math.sinh", "math.sin", "math.sqrt", "math.tanh", "math.tan", "bit32.arshift", "bit32.band", "bit32.bnot", + "bit32.bor", "bit32.bxor", "bit32.btest", "bit32.extract", "bit32.lrotate", "bit32.lshift", "bit32.replace", "bit32.rrotate", "bit32.rshift", + "type", "string.byte", "string.len", "typeof", "math.clamp", "math.sign", "math.round", }; for (const char* func : builtins) @@ -6397,7 +7624,9 @@ TEST_CASE("BuiltinFoldingProhibitedCoverage") TEST_CASE("BuiltinFoldingMultret") { - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local NoLanes: Lanes = --[[ ]] 0b0000000000000000000000000000000 local OffscreenLane: Lane = --[[ ]] 0b1000000000000000000000000000000 @@ -6412,7 +7641,9 @@ local function getLanesToRetrySynchronouslyOnError(root: FiberRoot): Lanes return NoLanes end )", - 0, 2), + 0, + 2 + ), R"( GETTABLEKS R2 R0 K0 ['pendingLanes'] FASTCALL2K 29 R2 K1 L0 [3221225471] @@ -6431,23 +7662,30 @@ LOADK R2 K6 [1073741824] RETURN R2 1 L3: LOADN R2 0 RETURN R2 1 -)"); +)" + ); // Note: similarly, here we should have folded the return value but haven't because it's the last call in the sequence - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return math.abs(-42) )", - 0, 2), + 0, + 2 + ), R"( LOADN R0 42 RETURN R0 1 -)"); +)" + ); } TEST_CASE("LocalReassign") { // locals can be re-assigned and the register gets reused - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local function test(a, b) local c = a return c + b @@ -6456,10 +7694,12 @@ end R"( ADD R2 R0 R1 RETURN R2 1 -)"); +)" + ); // this works if the expression is using type casts or grouping - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local function test(a, b) local c = (a :: number) return c + b @@ -6468,10 +7708,12 @@ end R"( ADD R2 R0 R1 RETURN R2 1 -)"); +)" + ); // the optimization requires that neither local is mutated - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local function test(a, b) local c = a c += 0 @@ -6487,10 +7729,12 @@ MOVE R3 R1 ADDK R1 R1 K0 [0] ADD R4 R2 R3 RETURN R4 1 -)"); +)" + ); // sanity check for two values - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local function test(a, b) local c = a local d = b @@ -6500,10 +7744,12 @@ end R"( ADD R2 R0 R1 RETURN R2 1 -)"); +)" + ); // note: we currently only support this for single assignments - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local function test(a, b) local c, d = a, b return c + d @@ -6514,29 +7760,35 @@ MOVE R2 R0 MOVE R3 R1 ADD R4 R2 R3 RETURN R4 1 -)"); +)" + ); // of course, captures capture the original register as well (by value since it's immutable) - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function test(a, b) local c = a local d = b return function() return c + d end end )", - 1), + 1 + ), R"( NEWCLOSURE R2 P0 CAPTURE VAL R0 CAPTURE VAL R1 RETURN R2 1 -)"); +)" + ); } TEST_CASE("MultipleAssignments") { // order of assignments is left to right - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b a, b = f(1), f(2) )"), @@ -6552,10 +7804,12 @@ LOADN R3 2 CALL R2 1 1 MOVE R1 R2 RETURN R0 0 -)"); +)" + ); // this includes table assignments - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t t[1], t[2] = 3, 4 )"), @@ -6567,10 +7821,12 @@ LOADN R3 4 SETTABLEN R2 R0 1 SETTABLEN R3 R1 2 RETURN R0 0 -)"); +)" + ); // semantically, we evaluate the right hand side first; this allows us to e.g swap elements in a table easily - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local t = ... t[1], t[2] = t[2], t[1] )"), @@ -6581,14 +7837,16 @@ GETTABLEN R2 R0 1 SETTABLEN R1 R0 1 SETTABLEN R2 R0 2 RETURN R0 0 -)"); +)" + ); // however, we need to optimize local assignments; to do this well, we need to handle assignment conflicts // let's first go through a few cases where there are no conflicts: // when multiple assignments have no conflicts (all local vars are read after being assigned), codegen is the same as a series of single // assignments - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local xm1, x, xp1, xi = ... xm1,x,xp1,xi = x,xp1,xp1+1,xi-1 @@ -6600,10 +7858,12 @@ MOVE R1 R2 ADDK R2 R2 K0 [1] SUBK R3 R3 K0 [1] RETURN R0 0 -)"); +)" + ); // similar example to above from a more complex case - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b, c, d, e, f, g, h, t1, t2 = ... h, g, f, e, d, c, b, a = g, f, e, d + t1, c, b, a, t1 + t2 @@ -6619,11 +7879,13 @@ MOVE R2 R1 MOVE R1 R0 ADD R0 R8 R9 RETURN R0 0 -)"); +)" + ); // when locals have a conflict, we assign temporaries instead of locals, and at the end copy the values back // the basic example of this is a swap/rotate - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = ... a, b = b, a )"), @@ -6633,9 +7895,11 @@ MOVE R2 R1 MOVE R1 R0 MOVE R0 R2 RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b, c = ... a, b, c = c, a, b )"), @@ -6647,9 +7911,11 @@ MOVE R2 R1 MOVE R0 R3 MOVE R1 R4 RETURN R0 0 -)"); +)" + ); - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b, c = ... a, b, c = b, c, a )"), @@ -6660,10 +7926,12 @@ MOVE R1 R2 MOVE R2 R0 MOVE R0 R3 RETURN R0 0 -)"); +)" + ); // multiple assignments with multcall handling - foo() evalutes to temporary registers and they are copied out to target - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b, c, d = ... a, b, c, d = 1, foo() )"), @@ -6676,10 +7944,12 @@ MOVE R1 R4 MOVE R2 R5 MOVE R3 R6 RETURN R0 0 -)"); +)" + ); // note that during this we still need to handle local reassignment, eg when table assignments are performed - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b, c, d = ... a, b[a], c[d], d = 1, foo() )"), @@ -6693,11 +7963,13 @@ SETTABLE R7 R2 R3 MOVE R0 R4 MOVE R3 R8 RETURN R0 0 -)"); +)" + ); // multiple assignments with multcall handling - foo evaluates to a single argument so all remaining locals are assigned to nil // note that here we don't assign the locals directly, as this case is very rare so we use the similar code path as above - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b, c, d = ... a, b, c, d = 1, foo )"), @@ -6711,10 +7983,12 @@ MOVE R1 R4 MOVE R2 R5 MOVE R3 R6 RETURN R0 0 -)"); +)" + ); // note that we also try to use locals as a source of assignment directly when assigning fields; this works using old local value when possible - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = ... a[1], a[2] = b, b + 1 )"), @@ -6724,10 +7998,12 @@ ADDK R2 R1 K0 [1] SETTABLEN R1 R0 1 SETTABLEN R2 R0 2 RETURN R0 0 -)"); +)" + ); // ... of course if the local is reassigned, we defer the assignment until later - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = ... b, a[1] = 42, b )"), @@ -6737,10 +8013,12 @@ LOADN R2 42 SETTABLEN R1 R0 1 MOVE R1 R2 RETURN R0 0 -)"); +)" + ); // when there are more expressions when values, we evalute them for side effects, but they also participate in conflict handling - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = ... a, b = 1, 2, a + b )"), @@ -6752,10 +8030,12 @@ ADD R4 R0 R1 MOVE R0 R2 MOVE R1 R3 RETURN R0 0 -)"); +)" + ); // because we perform assignments to complex l-values after assignments to locals, we make sure register conflicts are tracked accordingly - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local a, b = ... a[1], b = b, b + 1 )"), @@ -6765,14 +8045,16 @@ ADDK R2 R1 K0 [1] SETTABLEN R1 R0 1 MOVE R1 R2 RETURN R0 0 -)"); +)" + ); } TEST_CASE("BuiltinExtractK") { // below, K0 refers to a packed f+w constant for bit32.extractk builtin // K1 and K2 refer to 1 and 3 and are only used during fallback path - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local v = ... return bit32.extract(v, 1, 3) @@ -6786,7 +8068,8 @@ LOADK R4 K2 [3] GETIMPORT R1 5 [bit32.extract] CALL R1 3 -1 L0: RETURN R1 -1 -)"); +)" + ); } TEST_CASE("SkipSelfAssignment") @@ -6816,10 +8099,9 @@ RETURN R0 0 TEST_CASE("ElideJumpAfterIf") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // break refers to outer loop => we can elide unconditional branches - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local foo, bar = ... repeat if foo then break @@ -6839,10 +8121,12 @@ CALL R2 1 0 JUMPIFEQ R0 R1 L2 JUMPBACK L0 L2: RETURN R0 0 -)"); +)" + ); // break refers to inner loop => branches remain - CHECK_EQ("\n" + compileFunction0(R"( + CHECK_EQ( + "\n" + compileFunction0(R"( local foo, bar = ... repeat if foo then while true do break end @@ -6866,18 +8150,21 @@ CALL R2 1 0 JUMPIFEQ R0 R1 L3 JUMPBACK L0 L3: RETURN R0 0 -)"); +)" + ); } TEST_CASE("BuiltinArity") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - // by default we can't assume that we know parameter/result count for builtins as they can be overridden at runtime - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return math.abs(unknown()) )", - 0, 1), + 0, + 1 + ), R"( GETIMPORT R1 1 [unknown] CALL R1 0 -1 @@ -6885,14 +8172,19 @@ FASTCALL 2 L0 GETIMPORT R0 4 [math.abs] CALL R0 -1 -1 L0: RETURN R0 -1 -)"); +)" + ); // however, when using optimization level 2, we assume compile time knowledge about builtin behavior even if we can't deoptimize that with fenv // in the test case below, this allows us to synthesize a more efficient FASTCALL1 (and use a fixed-return call to unknown) - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return math.abs(unknown()) )", - 0, 2), + 0, + 2 + ), R"( GETIMPORT R1 1 [unknown] CALL R1 0 1 @@ -6900,13 +8192,18 @@ FASTCALL1 2 R1 L0 GETIMPORT R0 4 [math.abs] CALL R0 1 1 L0: RETURN R0 1 -)"); +)" + ); // some builtins are variadic, and as such they can't use fixed-length fastcall variants - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return math.max(0, unknown()) )", - 0, 2), + 0, + 2 + ), R"( LOADN R1 0 GETIMPORT R2 1 [unknown] @@ -6915,14 +8212,19 @@ FASTCALL 18 L0 GETIMPORT R0 4 [math.max] CALL R0 -1 1 L0: RETURN R0 1 -)"); +)" + ); // some builtins are not variadic but don't have a fixed number of arguments; we currently don't optimize this although we might start to in the // future - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( return bit32.extract(0, 1, unknown()) )", - 0, 2), + 0, + 2 + ), R"( LOADN R1 0 LOADN R2 1 @@ -6932,17 +8234,42 @@ FASTCALL 34 L0 GETIMPORT R0 4 [bit32.extract] CALL R0 -1 1 L0: RETURN R0 1 -)"); +)" + ); + + // some builtins are not variadic and have a fixed number of arguments but are not none-safe, meaning that we can't replace calls that may + // return none with calls that will return nil + CHECK_EQ( + "\n" + compileFunction( + R"( +return type(unknown()) +)", + 0, + 2 + ), + R"( +GETIMPORT R1 1 [unknown] +CALL R1 0 -1 +FASTCALL 40 L0 +GETIMPORT R0 3 [type] +CALL R0 -1 1 +L0: RETURN R0 1 +)" + ); // importantly, this optimization also helps us get around the multret inlining restriction for builtin wrappers - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local function new() return setmetatable({}, MT) end return new() )", - 1, 2), + 1, + 2 + ), R"( DUPCLOSURE R0 K0 ['new'] NEWTABLE R2 0 0 @@ -6951,16 +8278,21 @@ FASTCALL2 61 R2 R3 L0 GETIMPORT R1 4 [setmetatable] CALL R1 2 1 L0: RETURN R1 1 -)"); +)" + ); // note that the results of this optimization are benign in fixed-arg contexts which dampens the effect of fenv substitutions on correctness in // practice - CHECK_EQ("\n" + compileFunction(R"( + CHECK_EQ( + "\n" + compileFunction( + R"( local x = ... local y, z = type(x) return type(y, z) )", - 0, 2), + 0, + 2 + ), R"( GETVARARGS R0 1 FASTCALL1 40 R0 L0 @@ -6973,7 +8305,552 @@ MOVE R5 R2 GETIMPORT R3 1 [type] CALL R3 2 1 L1: RETURN R3 1 +)" + ); +} + +TEST_CASE("EncodedTypeTable") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: string, num: number) + print(test) +end + +function myfunc2(test: number?) +end + +function myfunc3(test: string, n: number) +end + +function myfunc4(test: string | number, n: number) +end + +-- Promoted to function(any, any) since general unions are not supported. +-- Functions with all `any` parameters will have omitted type info. +function myfunc5(test: string | number, n: number | boolean) +end + +function myfunc6(test: (number) -> string) +end + +myfunc('test') +)"), + R"( +0: function(string, number) +1: function(number?) +2: function(string, number) +3: function(any, number) +5: function(function) +)" + ); + + CHECK_EQ( + "\n" + compileTypeTable(R"( +local Str = { + a = 1 +} + +-- Implicit `self` parameter is automatically assumed to be table type. +function Str:test(n: number) + print(self.a, n) +end + +Str:test(234) +)"), + R"( +0: function(table, number) +)" + ); +} + +TEST_CASE("HostTypesAreUserdata") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: string, num: number) + print(test) +end + +function myfunc2(test: Instance, num: number) +end + +type Foo = string + +function myfunc3(test: string, n: Foo) +end + +function myfunc4(test: Bar, n: Part) +end +)"), + R"( +0: function(string, number) +1: function(userdata, number) +2: function(string, string) +3: function(any, userdata) +)" + ); +} + +TEST_CASE("HostTypesVector") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: Instance, pos: Vector3) +end + +function myfunc2(test: Instance, pos: Vector3) +end + +do + type Vector3 = number + + function myfunc3(test: Instance, pos: Vector3) + end +end +)"), + R"( +0: function(userdata, vector) +1: function(userdata, any) +2: function(userdata, number) +)" + ); +} + +TEST_CASE("TypeAliasScoping") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +do + type Part = number +end + +function myfunc1(test: Part, num: number) +end + +do + type Part = number + + function myfunc2(test: Part, num: number) + end +end + +repeat + type Part = number +until (function(test: Part, num: number) end)() + +function myfunc4(test: Instance, num: number) +end + +type Instance = string +)"), + R"( +0: function(userdata, number) +1: function(number, number) +2: function(number, number) +3: function(string, number) +)" + ); +} + +TEST_CASE("TypeAliasResolve") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +type Foo1 = number +type Foo2 = { number } +type Foo3 = Part +type Foo4 = Foo1 -- we do not resolve aliases within aliases +type Foo5 = X + +function myfunc(f1: Foo1, f2: Foo2, f3: Foo3, f4: Foo4, f5: Foo5) +end + +function myfuncerr(f1: Foo1, f2: Foo5) +end + +)"), + R"( +0: function(number, table, userdata, any, any) +1: function(number, any) +)" + ); +} + +TEST_CASE("TypeUnionIntersection") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: string | nil, foo: nil) +end + +function myfunc2(test: string & nil, foo: nil) +end + +function myfunc3(test: string | number, foo: nil) +end + +function myfunc4(test: string & number, foo: nil) +end +)"), + R"( +0: function(string?, nil) +1: function(any, nil) +2: function(any, nil) +3: function(any, nil) +)" + ); +} + +TEST_CASE("BuiltinFoldMathK") +{ + // we can fold math.pi at optimization level 2 + CHECK_EQ( + "\n" + compileFunction( + R"( +function test() + return math.pi * 2 +end +)", + 0, + 2 + ), + R"( +LOADK R0 K0 [6.2831853071795862] +RETURN R0 1 +)" + ); + + // we don't do this at optimization level 1 because it may interfere with environment substitution + CHECK_EQ( + "\n" + compileFunction( + R"( +function test() + return math.pi * 2 +end +)", + 0, + 1 + ), + R"( +GETIMPORT R1 3 [math.pi] +MULK R0 R1 K0 [2] +RETURN R0 1 +)" + ); + + // we also don't do it if math global is assigned to + CHECK_EQ( + "\n" + compileFunction( + R"( +function test() + return math.pi * 2 +end + +math = { pi = 4 } +)", + 0, + 2 + ), + R"( +GETGLOBAL R2 K1 ['math'] +GETTABLEKS R1 R2 K2 ['pi'] +MULK R0 R1 K0 [2] +RETURN R0 1 +)" + ); +} + +TEST_CASE("NoBuiltinFoldFenv") +{ + // builtin folding is disabled when getfenv/setfenv is used in the module + CHECK_EQ( + "\n" + compileFunction( + R"( +getfenv() + +function test() + return math.pi, math.sin(0) +end +)", + 0, + 2 + ), + R"( +GETIMPORT R0 2 [math.pi] +LOADN R2 0 +FASTCALL1 24 R2 L0 +GETIMPORT R1 4 [math.sin] +CALL R1 1 1 +L0: RETURN R0 2 +)" + ); +} + +TEST_CASE("IfThenElseAndOr") +{ + // if v then v else k can be optimized to ORK + CHECK_EQ( + "\n" + compileFunction0(R"( +local x = ... +return if x then x else 0 +)"), + R"( +GETVARARGS R0 1 +ORK R1 R0 K0 [0] +RETURN R1 1 +)" + ); + + // if v then v else l can be optimized to OR + CHECK_EQ( + "\n" + compileFunction0(R"( +local x, y = ... +return if x then x else y +)"), + R"( +GETVARARGS R0 2 +OR R2 R0 R1 +RETURN R2 1 +)" + ); + + // this also works in presence of type casts + CHECK_EQ( + "\n" + compileFunction0(R"( +local x, y = ... +return if x then x :: number else 0 +)"), + R"( +GETVARARGS R0 2 +ORK R2 R0 K0 [0] +RETURN R2 1 +)" + ); + + // if v then k else v can be optimized to ANDK + CHECK_EQ( + "\n" + compileFunction0(R"( +local x = ... +return if x then 0 else x +)"), + R"( +GETVARARGS R0 1 +ANDK R1 R0 K0 [0] +RETURN R1 1 +)" + ); + + // if v then l else v can be optimized to AND + CHECK_EQ( + "\n" + compileFunction0(R"( +local x, y = ... +return if x then y else x +)"), + R"( +GETVARARGS R0 2 +AND R2 R0 R1 +RETURN R2 1 +)" + ); + + // this also works in presence of type casts + CHECK_EQ( + "\n" + compileFunction0(R"( +local x, y = ... +return if x then y else x :: number +)"), + R"( +GETVARARGS R0 2 +AND R2 R0 R1 +RETURN R2 1 +)" + ); + + // all of the above work when the target is a temporary register, which is safe because the value is only mutated once + CHECK_EQ( + "\n" + compileFunction0(R"( +local x, y = ... +x = if x then x else y +x = if x then y else x +)"), + R"( +GETVARARGS R0 2 +OR R0 R0 R1 +AND R0 R0 R1 +RETURN R0 0 +)" + ); + + // note that we can't do this transformation if the expression has possible side effects + CHECK_EQ( + "\n" + compileFunction0(R"( +local x = ... +return if x.data then x.data else 0 +)"), + R"( +GETVARARGS R0 1 +GETTABLEKS R2 R0 K0 ['data'] +JUMPIFNOT R2 L0 +GETTABLEKS R1 R0 K0 ['data'] +RETURN R1 1 +L0: LOADN R1 0 +RETURN R1 1 +)" + ); +} + +TEST_CASE("SideEffects") +{ + // we do not evaluate expressions in some cases when we know they can't carry side effects + CHECK_EQ( + "\n" + compileFunction0(R"( +local x = 5, print +local y = 5, 42 +local z = 5, table.find -- considered side effecting because of metamethods +)"), + R"( +LOADN R0 5 +LOADN R1 5 +LOADN R2 5 +GETIMPORT R3 2 [table.find] +RETURN R0 0 +)" + ); + + // this also applies to returns in cases where a function gets inlined + CHECK_EQ( + "\n" + compileFunction( + R"( +local function test1() + return 42 +end + +local function test2() + return print +end + +local function test3() + return function() print(test3) end +end + +local function test4() + return table.find -- considered side effecting because of metamethods +end + +test1() +test2() +test3() +test4() +)", + 5, + 2 + ), + R"( +DUPCLOSURE R0 K0 ['test1'] +DUPCLOSURE R1 K1 ['test2'] +DUPCLOSURE R2 K2 ['test3'] +CAPTURE VAL R2 +DUPCLOSURE R3 K3 ['test4'] +GETIMPORT R4 6 [table.find] +RETURN R0 0 +)" + ); +} + +TEST_CASE("IfElimination") +{ + // if the left hand side of a condition is constant, it constant folds and we don't emit the branch + CHECK_EQ("\n" + compileFunction0("local a = false if a and b then b() end"), R"( +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() end"), R"( +GETIMPORT R0 1 [b] +CALL R0 0 0 +RETURN R0 0 +)"); + + // of course this keeps the other branch if present + CHECK_EQ("\n" + compileFunction0("local a = false if a and b then b() else return 42 end"), R"( +LOADN R0 42 +RETURN R0 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() else return 42 end"), R"( +GETIMPORT R0 1 [b] +CALL R0 0 0 +RETURN R0 0 +)"); + + // if the right hand side is constant, the condition doesn't constant fold but we still could eliminate one of the branches for 'a and K' + CHECK_EQ("\n" + compileFunction0("local a = false if b and a then return 1 end"), R"( +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = false if b and a then return 1 else return 2 end"), R"( +LOADN R0 2 +RETURN R0 1 )"); + + // of course if the right hand side of 'and' is 'true', we still need to actually evaluate the left hand side + CHECK_EQ("\n" + compileFunction0("local a = true if b and a then return 1 end"), R"( +GETIMPORT R0 1 [b] +JUMPIFNOT R0 L0 +LOADN R0 1 +RETURN R0 1 +L0: RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if b and a then return 1 else return 2 end"), R"( +GETIMPORT R0 1 [b] +JUMPIFNOT R0 L0 +LOADN R0 1 +RETURN R0 1 +L0: LOADN R0 2 +RETURN R0 1 +)"); + + // also even if we eliminate the branch, we still need to compute side effects + CHECK_EQ("\n" + compileFunction0("local a = false if b.test and a then return 1 end"), R"( +GETIMPORT R0 2 [b.test] +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = false if b.test and a then return 1 else return 2 end"), R"( +GETIMPORT R0 2 [b.test] +LOADN R0 2 +RETURN R0 1 +)"); +} + +TEST_CASE("ArithRevK") +{ + // - and / have special optimized form for reverse constants; in the future, + and * will likely get compiled to ADDK/MULK + // other operators are not important enough to optimize reverse constant forms for + CHECK_EQ( + "\n" + compileFunction0(R"( +local x: number = unknown +return 2 + x, 2 - x, 2 * x, 2 / x, 2 % x, 2 // x, 2 ^ x +)"), + R"( +GETIMPORT R0 1 [unknown] +LOADN R2 2 +ADD R1 R2 R0 +SUBRK R2 K2 [2] R0 +LOADN R4 2 +MUL R3 R4 R0 +DIVRK R4 K2 [2] R0 +LOADN R6 2 +MOD R5 R6 R0 +LOADN R7 2 +IDIV R6 R7 R0 +LOADN R8 2 +POW R7 R8 R0 +RETURN R1 7 +)" + ); } TEST_SUITE_END(); diff --git a/tests/Config.test.cpp b/tests/Config.test.cpp index e6a72672a..70d6d6d7d 100644 --- a/tests/Config.test.cpp +++ b/tests/Config.test.cpp @@ -25,12 +25,14 @@ TEST_CASE("language_mode") TEST_CASE("disable_a_lint_rule") { Config config; - auto err = parseConfig(R"( + auto err = parseConfig( + R"( {"lint": { "UnknownGlobal": false, }} )", - config); + config + ); REQUIRE(!err); CHECK(!config.enabledLint.isEnabled(LintWarning::Code_UnknownGlobal)); @@ -40,12 +42,14 @@ TEST_CASE("disable_a_lint_rule") TEST_CASE("report_a_syntax_error") { Config config; - auto err = parseConfig(R"( + auto err = parseConfig( + R"( {"lint": { "UnknownGlobal": "oops" }} )", - config); + config + ); REQUIRE(err); CHECK_EQ("In key UnknownGlobal: Bad setting 'oops'. Valid options are true and false", *err); @@ -79,7 +83,8 @@ TEST_CASE("lint_warnings_are_ordered") TEST_CASE("comments") { Config config; - auto err = parseConfig(R"( + auto err = parseConfig( + R"( { "lint": { "*": false, @@ -92,7 +97,8 @@ TEST_CASE("comments") } } )", - config); + config + ); REQUIRE(!err); CHECK(!config.enabledLint.isEnabled(LintWarning::Code_LocalShadow)); @@ -105,13 +111,15 @@ TEST_CASE("issue_severity") CHECK(!config.lintErrors); CHECK(config.typeErrors); - auto err = parseConfig(R"( + auto err = parseConfig( + R"( { "lintErrors": true, "typeErrors": false, } )", - config); + config + ); REQUIRE(!err); CHECK(config.lintErrors); @@ -121,12 +129,14 @@ TEST_CASE("issue_severity") TEST_CASE("extra_globals") { Config config; - auto err = parseConfig(R"( + auto err = parseConfig( + R"( { "globals": ["it", "__DEV__"], } )", - config); + config + ); REQUIRE(!err); REQUIRE(config.globals.size() == 2); @@ -137,14 +147,17 @@ TEST_CASE("extra_globals") TEST_CASE("lint_rules_compat") { Config config; - auto err = parseConfig(R"( + auto err = parseConfig( + R"( {"lint": { "SameLineStatement": "enabled", "FunctionUnused": "disabled", "ImportUnused": "fatal", }} )", - config, true); + config, + true + ); REQUIRE(!err); CHECK(config.enabledLint.isEnabled(LintWarning::Code_SameLineStatement)); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 077310ac8..0a88444a8 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -2,18 +2,24 @@ #include "lua.h" #include "lualib.h" #include "luacode.h" +#include "luacodegen.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/DenseHash.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" -#include "Luau/StringUtils.h" #include "Luau/BytecodeBuilder.h" +#include "Luau/Frontend.h" +#include "Luau/Compiler.h" #include "Luau/CodeGen.h" +#include "Luau/BytecodeSummary.h" #include "doctest.h" #include "ScopedFlags.h" +#include "ConformanceIrHooks.h" #include +#include #include #include @@ -21,20 +27,42 @@ extern bool verbose; extern bool codegen; extern int optimizationLevel; +// internal functions, declared in lgc.h - not exposed via lua.h +void luaC_fullgc(lua_State* L); +void luaC_validate(lua_State* L); + +LUAU_FASTFLAG(LuauMathMap) +LUAU_FASTFLAG(DebugLuauAbortingChecks) +LUAU_FASTINT(CodegenHeuristicsInstructionLimit) +LUAU_FASTFLAG(LuauNativeAttribute) +LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) + static lua_CompileOptions defaultOptions() { lua_CompileOptions copts = {}; copts.optimizationLevel = optimizationLevel; copts.debugLevel = 1; + copts.typeInfoLevel = 1; + + copts.vectorCtor = "vector"; + copts.vectorType = "vector"; return copts; } +static Luau::CodeGen::CompilationOptions defaultCodegenOptions() +{ + Luau::CodeGen::CompilationOptions opts = {}; + opts.flags = Luau::CodeGen::CodeGen_ColdFunctions; + return opts; +} + static int lua_collectgarbage(lua_State* L) { static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; static const int optsnum[] = { - LUA_GCSTOP, LUA_GCRESTART, LUA_GCCOLLECT, LUA_GCCOUNT, LUA_GCISRUNNING, LUA_GCSTEP, LUA_GCSETGOAL, LUA_GCSETSTEPMUL, LUA_GCSETSTEPSIZE}; + LUA_GCSTOP, LUA_GCRESTART, LUA_GCCOLLECT, LUA_GCCOUNT, LUA_GCISRUNNING, LUA_GCSTEP, LUA_GCSETGOAL, LUA_GCSETSTEPMUL, LUA_GCSETSTEPSIZE + }; int o = luaL_checkoption(L, 1, "collect", opts); int ex = luaL_optinteger(L, 2, 0); @@ -100,6 +128,20 @@ static int lua_vector_dot(lua_State* L) return 1; } +static int lua_vector_cross(lua_State* L) +{ + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0], 0.0f); +#else + lua_pushvector(L, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]); +#endif + + return 1; +} + static int lua_vector_index(lua_State* L) { const float* v = luaL_checkvector(L, 1); @@ -107,7 +149,25 @@ static int lua_vector_index(lua_State* L) if (strcmp(name, "Magnitude") == 0) { +#if LUA_VECTOR_SIZE == 4 + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3])); +#else lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); +#endif + return 1; + } + + if (strcmp(name, "Unit") == 0) + { +#if LUA_VECTOR_SIZE == 4 + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3]); + + lua_pushvector(L, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt, v[3] * invSqrt); +#else + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + + lua_pushvector(L, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt); +#endif return 1; } @@ -126,6 +186,9 @@ static int lua_vector_namecall(lua_State* L) { if (strcmp(str, "Dot") == 0) return lua_vector_dot(L); + + if (strcmp(str, "Cross") == 0) + return lua_vector_cross(L); } luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); @@ -138,15 +201,29 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; -static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, - lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false) -{ +static StateRef runConformance( + const char* name, + void (*setup)(lua_State* L) = nullptr, + void (*yield)(lua_State* L) = nullptr, + lua_State* initialLuaState = nullptr, + lua_CompileOptions* options = nullptr, + bool skipCodegen = false, + Luau::CodeGen::CompilationOptions* codegenOptions = nullptr +) +{ +#ifdef LUAU_CONFORMANCE_SOURCE_DIR + std::string path = LUAU_CONFORMANCE_SOURCE_DIR; + path += "/"; + path += name; +#else std::string path = __FILE__; path.erase(path.find_last_of("\\/")); path += "/conformance/"; path += name; +#endif std::fstream stream(path, std::ios::in | std::ios::binary); + INFO(path); REQUIRE(stream); std::string source(std::istreambuf_iterator(stream), {}); @@ -158,8 +235,8 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n StateRef globalState(initialLuaState, lua_close); lua_State* L = globalState.get(); - if (codegen && !skipCodegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::create(L); + if (codegen && !skipCodegen && luau_codegen_supported()) + luau_codegen_create(L); luaL_openlibs(L); @@ -212,8 +289,12 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); free(bytecode); - if (result == 0 && codegen && !skipCodegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::compile(L, -1); + if (result == 0 && codegen && !skipCodegen && luau_codegen_supported()) + { + Luau::CodeGen::CompilationOptions nativeOpts = codegenOptions ? *codegenOptions : defaultCodegenOptions(); + + Luau::CodeGen::compile(L, -1, nativeOpts); + } int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; @@ -223,7 +304,6 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n status = lua_resume(L, nullptr, 0); } - extern void luaC_validate(lua_State * L); // internal function, declared in lgc.h - not exposed via lua.h luaC_validate(L); if (status == 0) @@ -243,8 +323,320 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n return globalState; } +static void* limitedRealloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + free(ptr); + return nullptr; + } + else if (nsize > 8 * 1024 * 1024) + { + // For testing purposes return null for large allocations so we can generate errors related to memory allocation failures + return nullptr; + } + else + { + return realloc(ptr, nsize); + } +} + +void setupVectorHelpers(lua_State* L) +{ + lua_pushcfunction(L, lua_vector, "vector"); + lua_setglobal(L, "vector"); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); +#else + lua_pushvector(L, 0.0f, 0.0f, 0.0f); +#endif + luaL_newmetatable(L, "vector"); + + lua_pushstring(L, "__index"); + lua_pushcfunction(L, lua_vector_index, nullptr); + lua_settable(L, -3); + + lua_pushstring(L, "__namecall"); + lua_pushcfunction(L, lua_vector_namecall, nullptr); + lua_settable(L, -3); + + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); +} + +Vec2* lua_vec2_push(lua_State* L) +{ + Vec2* data = (Vec2*)lua_newuserdatatagged(L, sizeof(Vec2), kTagVec2); + + lua_getuserdatametatable(L, kTagVec2); + lua_setmetatable(L, -2); + + return data; +} + +Vec2* lua_vec2_get(lua_State* L, int idx) +{ + Vec2* a = (Vec2*)lua_touserdatatagged(L, idx, kTagVec2); + + if (a) + return a; + + luaL_typeerror(L, idx, "vec2"); +} + +static int lua_vec2(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double y = luaL_checknumber(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = float(x); + data->y = float(y); + + return 1; +} + +static int lua_vec2_dot(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + lua_pushnumber(L, a->x * b->x + a->y * b->y); + return 1; +} + +static int lua_vec2_min(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = a->x < b->x ? a->x : b->x; + data->y = a->y < b->y ? a->y : b->y; + + return 1; +} + +static int lua_vec2_index(lua_State* L) +{ + Vec2* v = lua_vec2_get(L, 1); + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "X") == 0) + { + lua_pushnumber(L, v->x); + return 1; + } + + if (strcmp(name, "Y") == 0) + { + lua_pushnumber(L, v->y); + return 1; + } + + if (strcmp(name, "Magnitude") == 0) + { + lua_pushnumber(L, sqrtf(v->x * v->x + v->y * v->y)); + return 1; + } + + if (strcmp(name, "Unit") == 0) + { + float invSqrt = 1.0f / sqrtf(v->x * v->x + v->y * v->y); + + Vec2* data = lua_vec2_push(L); + + data->x = v->x * invSqrt; + data->y = v->y * invSqrt; + return 1; + } + + luaL_error(L, "%s is not a valid member of vector", name); +} + +static int lua_vec2_namecall(lua_State* L) +{ + if (const char* str = lua_namecallatom(L, nullptr)) + { + if (strcmp(str, "Dot") == 0) + return lua_vec2_dot(L); + + if (strcmp(str, "Min") == 0) + return lua_vec2_min(L); + } + + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); +} + +void setupUserdataHelpers(lua_State* L) +{ + // create metatable with all the metamethods + luaL_newmetatable(L, "vec2"); + luaL_getmetatable(L, "vec2"); + lua_pushvalue(L, -1); + lua_setuserdatametatable(L, kTagVec2, -1); + + lua_pushcfunction(L, lua_vec2_index, nullptr); + lua_setfield(L, -2, "__index"); + + lua_pushcfunction(L, lua_vec2_namecall, nullptr); + lua_setfield(L, -2, "__namecall"); + + lua_pushcclosurek( + L, + [](lua_State* L) + { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x + b->x; + data->y = a->y + b->y; + + return 1; + }, + nullptr, + 0, + nullptr + ); + lua_setfield(L, -2, "__add"); + + lua_pushcclosurek( + L, + [](lua_State* L) + { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x - b->x; + data->y = a->y - b->y; + + return 1; + }, + nullptr, + 0, + nullptr + ); + lua_setfield(L, -2, "__sub"); + + lua_pushcclosurek( + L, + [](lua_State* L) + { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x * b->x; + data->y = a->y * b->y; + + return 1; + }, + nullptr, + 0, + nullptr + ); + lua_setfield(L, -2, "__mul"); + + lua_pushcclosurek( + L, + [](lua_State* L) + { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x / b->x; + data->y = a->y / b->y; + + return 1; + }, + nullptr, + 0, + nullptr + ); + lua_setfield(L, -2, "__div"); + + lua_pushcclosurek( + L, + [](lua_State* L) + { + Vec2* a = lua_vec2_get(L, 1); + Vec2* data = lua_vec2_push(L); + + data->x = -a->x; + data->y = -a->y; + + return 1; + }, + nullptr, + 0, + nullptr + ); + lua_setfield(L, -2, "__unm"); + + lua_setreadonly(L, -1, true); + + // ctor + lua_pushcfunction(L, lua_vec2, "vec2"); + lua_setglobal(L, "vec2"); + + lua_pop(L, 1); +} + +static void setupNativeHelpers(lua_State* L) +{ + lua_pushcclosurek( + L, + [](lua_State* L) -> int + { + extern int luaG_isnative(lua_State * L, int level); + + lua_pushboolean(L, luaG_isnative(L, 1)); + return 1; + }, + "is_native", + 0, + nullptr + ); + lua_setglobal(L, "is_native"); +} + +static std::vector analyzeFile(const char* source, const unsigned nestingLimit) +{ + Luau::BytecodeBuilder bcb; + + Luau::CompileOptions options; + options.optimizationLevel = optimizationLevel; + options.debugLevel = 1; + options.typeInfoLevel = 1; + + compileOrThrow(bcb, source, options); + + const std::string& bytecode = bcb.getBytecode(); + + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + int result = luau_load(L, "source", bytecode.data(), bytecode.size(), 0); + REQUIRE(result == 0); + + return Luau::CodeGen::summarizeBytecode(L, -1, nestingLimit); +} + TEST_SUITE_BEGIN("Conformance"); +TEST_CASE("CodegenSupported") +{ + if (codegen && !luau_codegen_supported()) + MESSAGE("Native code generation is not supported by the current configuration and will be disabled"); +} + TEST_CASE("Assert") { runConformance("assert.lua"); @@ -255,24 +647,46 @@ TEST_CASE("Basic") runConformance("basic.lua"); } +TEST_CASE("Buffers") +{ + runConformance("buffers.lua"); +} + TEST_CASE("Math") { + ScopedFastFlag LuauMathMap{FFlag::LuauMathMap, true}; + runConformance("math.lua"); } TEST_CASE("Tables") { - runConformance("tables.lua", [](lua_State* L) { - lua_pushcfunction( - L, - [](lua_State* L) { - unsigned v = luaL_checkunsigned(L, 1); - lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); - return 1; - }, - "makelud"); - lua_setglobal(L, "makelud"); - }); + runConformance( + "tables.lua", + [](lua_State* L) + { + lua_pushcfunction( + L, + [](lua_State* L) + { + if (lua_type(L, 1) == LUA_TNUMBER) + { + unsigned v = luaL_checkunsigned(L, 1); + lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); + } + else + { + const void* p = lua_topointer(L, 1); + LUAU_ASSERT(p); // we expect the test call to only pass GC values here + lua_pushlightuserdata(L, const_cast(p)); + } + return 1; + }, + "makelud" + ); + lua_setglobal(L, "makelud"); + } + ); } TEST_CASE("PatternMatch") @@ -342,6 +756,8 @@ TEST_CASE("Closure") TEST_CASE("Calls") { + ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; + runConformance("calls.lua"); } @@ -381,21 +797,31 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { - runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, cxxthrow, "cxxthrow"); - lua_setglobal(L, "cxxthrow"); - - lua_pushcfunction( - L, - [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }, - "resumeerror"); - lua_setglobal(L, "resumeerror"); - }); + ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; + + runConformance( + "pcall.lua", + [](lua_State* L) + { + lua_pushcfunction(L, cxxthrow, "cxxthrow"); + lua_setglobal(L, "cxxthrow"); + + lua_pushcfunction( + L, + [](lua_State* L) -> int + { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror" + ); + lua_setglobal(L, "resumeerror"); + }, + nullptr, + lua_newstate(limitedRealloc, nullptr) + ); } TEST_CASE("Pack") @@ -406,34 +832,56 @@ TEST_CASE("Pack") TEST_CASE("Vector") { lua_CompileOptions copts = defaultOptions(); - copts.vectorCtor = "vector"; - - runConformance( - "vector.lua", - [](lua_State* L) { - lua_pushcfunction(L, lua_vector, "vector"); - lua_setglobal(L, "vector"); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); -#if LUA_VECTOR_SIZE == 4 - lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); -#else - lua_pushvector(L, 0.0f, 0.0f, 0.0f); -#endif - luaL_newmetatable(L, "vector"); - - lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index, nullptr); - lua_settable(L, -3); + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; - lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall, nullptr); - lua_settable(L, -3); + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } - lua_setreadonly(L, -1, true); - lua_setmetatable(L, -2); - lua_pop(L, 1); + runConformance( + "vector.lua", + [](lua_State* L) + { + setupVectorHelpers(L); }, - nullptr, nullptr, &copts); + nullptr, + nullptr, + &copts, + false, + &nativeOpts + ); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -462,6 +910,10 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) lua_pushstring(L, "thread"); break; + case Luau::PrimitiveType::Buffer: + lua_pushstring(L, "buffer"); + break; + default: LUAU_ASSERT(!"Unknown primitive type"); } @@ -472,7 +924,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) for (const auto& [name, prop] : t->props) { - populateRTTI(L, prop.type); + populateRTTI(L, prop.type()); lua_setfield(L, -2, name.c_str()); } } @@ -499,25 +951,28 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { - runConformance("types.lua", [](lua_State* L) { - Luau::NullModuleResolver moduleResolver; - Luau::InternalErrorReporter iceHandler; - Luau::BuiltinTypes builtinTypes; - Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); + runConformance( + "types.lua", + [](lua_State* L) + { + Luau::NullModuleResolver moduleResolver; + Luau::NullFileResolver fileResolver; + Luau::NullConfigResolver configResolver; + Luau::Frontend frontend{&fileResolver, &configResolver}; + Luau::registerBuiltinGlobals(frontend, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); - Luau::registerBuiltinGlobals(env); - Luau::freeze(env.globalTypes); + lua_newtable(L); - lua_newtable(L); + for (const auto& [name, binding] : frontend.globals.globalScope->bindings) + { + populateRTTI(L, binding.typeId); + lua_setfield(L, -2, toString(name).c_str()); + } - for (const auto& [name, binding] : env.globalScope->bindings) - { - populateRTTI(L, binding.typeId); - lua_setfield(L, -2, toString(name).c_str()); + lua_setglobal(L, "RTTI"); } - - lua_setglobal(L, "RTTI"); - }); + ); } TEST_CASE("DateTime") @@ -555,18 +1010,21 @@ TEST_CASE("Debugger") runConformance( "debugger.lua", - [](lua_State* L) { + [](lua_State* L) + { lua_Callbacks* cb = lua_callbacks(L); lua_singlestep(L, singlestep); // this will only be called in single-step mode - cb->debugstep = [](lua_State* L, lua_Debug* ar) { + cb->debugstep = [](lua_State* L, lua_Debug* ar) + { stephits++; }; // for breakpoints to work we should make sure debugbreak is installed - cb->debugbreak = [](lua_State* L, lua_Debug* ar) { + cb->debugbreak = [](lua_State* L, lua_Debug* ar) + { breakhits++; // make sure we can trace the stack for every breakpoint we hit @@ -580,7 +1038,8 @@ TEST_CASE("Debugger") }; // for resuming off a breakpoint inside a coroutine we need to resume the interrupted coroutine - cb->debuginterrupt = [](lua_State* L, lua_Debug* ar) { + cb->debuginterrupt = [](lua_State* L, lua_Debug* ar) + { CHECK(interruptedthread == nullptr); CHECK(ar->userdata); // userdata contains the interrupted thread @@ -590,7 +1049,8 @@ TEST_CASE("Debugger") // add breakpoint() function lua_pushcclosurek( L, - [](lua_State* L) -> int { + [](lua_State* L) -> int + { int line = luaL_checkinteger(L, 1); bool enabled = luaL_optboolean(L, 2, true); @@ -600,10 +1060,14 @@ TEST_CASE("Debugger") lua_breakpoint(L, -1, line, enabled); return 0; }, - "breakpoint", 0, nullptr); + "breakpoint", + 0, + nullptr + ); lua_setglobal(L, "breakpoint"); }, - [](lua_State* L) { + [](lua_State* L) + { CHECK(breakhits % 2 == 1); lua_checkstack(L, LUA_MINSTACK); @@ -674,6 +1138,26 @@ TEST_CASE("Debugger") CHECK(lua_tointeger(L, -1) == 9); lua_pop(L, 1); } + else if (breakhits == 13) + { + // validate assignment via lua_getlocal + const char* l = lua_getlocal(L, 0, 1); + REQUIRE(l); + CHECK(strcmp(l, "a") == 0); + CHECK(lua_isnil(L, -1)); + lua_pop(L, 1); + } + else if (breakhits == 15) + { + // test lua_getlocal + const char* x = lua_getlocal(L, 2, 1); + REQUIRE(x); + CHECK(strcmp(x, "x") == 0); + lua_pop(L, 1); + + const char* a1 = lua_getlocal(L, 2, 2); + REQUIRE(!a1); + } if (interruptedthread) { @@ -681,9 +1165,12 @@ TEST_CASE("Debugger") interruptedthread = nullptr; } }, - nullptr, &copts, /* skipCodegen */ true); // Native code doesn't support debugging yet + nullptr, + &copts, + /* skipCodegen */ true + ); // Native code doesn't support debugging yet - CHECK(breakhits == 12); // 2 hits per breakpoint + CHECK(breakhits == 16); // 2 hits per breakpoint if (singlestep) CHECK(stephits > 100); // note; this will depend on number of instructions which can vary, so we just make sure the callback gets hit often @@ -697,8 +1184,10 @@ TEST_CASE("NDebugGetUpValue") copts.optimizationLevel = 0; runConformance( - "ndebug_upvalues.lua", nullptr, - [](lua_State* L) { + "ndebug_upvalues.lua", + nullptr, + [](lua_State* L) + { lua_checkstack(L, LUA_MINSTACK); // push the second frame's closure to the stack @@ -713,7 +1202,10 @@ TEST_CASE("NDebugGetUpValue") CHECK(lua_tointeger(L, -1) == 5); lua_pop(L, 2); }, - nullptr, &copts, /* skipCodegen */ false); + nullptr, + &copts, + /* skipCodegen */ false + ); } TEST_CASE("SameHash") @@ -744,12 +1236,22 @@ TEST_CASE("Reference") lua_State* L = globalState.get(); // note, we push two userdata objects but only pin one of them (the first one) - lua_newuserdatadtor(L, 0, [](void*) { - dtorhits++; - }); - lua_newuserdatadtor(L, 0, [](void*) { - dtorhits++; - }); + lua_newuserdatadtor( + L, + 0, + [](void*) + { + dtorhits++; + } + ); + lua_newuserdatadtor( + L, + 0, + [](void*) + { + dtorhits++; + } + ); lua_gc(L, LUA_GCCOLLECT, 0); CHECK(dtorhits == 0); @@ -780,19 +1282,32 @@ TEST_CASE("NewUserdataOverflow") lua_pushcfunction( L, - [](lua_State* L1) { + [](lua_State* L1) + { // The following userdata request might cause an overflow. lua_newuserdatadtor(L1, SIZE_MAX, [](void* d) {}); // The overflow might segfault in the following call. lua_getmetatable(L1, -1); return 0; }, - nullptr); + nullptr + ); CHECK(lua_pcall(L, 0, 0, 0) == LUA_ERRRUN); CHECK(strcmp(lua_tostring(L, -1), "memory allocation error: block too big") == 0); } +TEST_CASE("SandboxWithoutLibs") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + luaopen_base(L); // Load only base library + luaL_sandbox(L); + + CHECK(lua_getreadonly(L, LUA_GLOBALSINDEX)); +} + TEST_CASE("ApiTables") { StateRef globalState(luaL_newstate(), lua_close); @@ -881,7 +1396,7 @@ TEST_CASE("ApiIter") TEST_CASE("ApiCalls") { - StateRef globalState = runConformance("apicalls.lua"); + StateRef globalState = runConformance("apicalls.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); // lua_call @@ -980,6 +1495,53 @@ TEST_CASE("ApiCalls") CHECK(lua_tonumber(L, -1) == 4); lua_pop(L, 1); } + + // lua_pcall on OOM + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 0, 0); + CHECK(res == LUA_ERRMEM); + } + + // lua_pcall on OOM with an error handler + { + lua_getfield(L, LUA_GLOBALSINDEX, "oops"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "oops") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that errors + { + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "not enough memory") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on error with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } } TEST_CASE("ApiAtoms") @@ -987,7 +1549,8 @@ TEST_CASE("ApiAtoms") StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); - lua_callbacks(L)->useratom = [](const char* s, size_t l) -> int16_t { + lua_callbacks(L)->useratom = [](const char* s, size_t l) -> int16_t + { if (strcmp(s, "string") == 0) return 0; if (strcmp(s, "important") == 0) @@ -1025,6 +1588,91 @@ static bool endsWith(const std::string& str, const std::string& suffix) return suffix == std::string_view(str.c_str() + str.length() - suffix.length(), suffix.length()); } +TEST_CASE("ApiType") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_pushnumber(L, 2); + CHECK(strcmp(luaL_typename(L, -1), "number") == 0); + CHECK(strcmp(luaL_typename(L, 1), "number") == 0); + CHECK(lua_type(L, -1) == LUA_TNUMBER); + CHECK(lua_type(L, 1) == LUA_TNUMBER); + + CHECK(strcmp(luaL_typename(L, 2), "no value") == 0); + CHECK(lua_type(L, 2) == LUA_TNONE); + CHECK(strcmp(lua_typename(L, lua_type(L, 2)), "no value") == 0); + + lua_newuserdata(L, 0); + CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); + CHECK(lua_type(L, -1) == LUA_TUSERDATA); + + lua_newtable(L); + lua_pushstring(L, "hello"); + lua_setfield(L, -2, "__type"); + lua_setmetatable(L, -2); + + CHECK(strcmp(luaL_typename(L, -1), "hello") == 0); + CHECK(lua_type(L, -1) == LUA_TUSERDATA); +} + +TEST_CASE("ApiBuffer") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_newbuffer(L, 1000); + + REQUIRE(lua_type(L, -1) == LUA_TBUFFER); + + CHECK(lua_isbuffer(L, -1)); + CHECK(lua_objlen(L, -1) == 1000); + + CHECK(strcmp(lua_typename(L, LUA_TBUFFER), "buffer") == 0); + + CHECK(strcmp(luaL_typename(L, -1), "buffer") == 0); + + void* p1 = lua_tobuffer(L, -1, nullptr); + + size_t len = 0; + void* p2 = lua_tobuffer(L, -1, &len); + CHECK(len == 1000); + CHECK(p1 == p2); + + void* p3 = luaL_checkbuffer(L, -1, nullptr); + CHECK(p1 == p3); + + len = 0; + void* p4 = luaL_checkbuffer(L, -1, &len); + CHECK(len == 1000); + CHECK(p1 == p4); + + memset(p1, 0xab, 1000); + + CHECK(lua_topointer(L, -1) != nullptr); + + lua_newbuffer(L, 0); + + lua_pushvalue(L, -2); + + CHECK(lua_equal(L, -3, -1)); + CHECK(!lua_equal(L, -2, -1)); + + lua_pop(L, 1); +} + +TEST_CASE("AllocApi") +{ + int ud = 0; + StateRef globalState(lua_newstate(limitedRealloc, &ud), lua_close); + lua_State* L = globalState.get(); + + void* udCheck = nullptr; + bool allocfIsSet = lua_getallocf(L, &udCheck) == limitedRealloc; + CHECK(allocfIsSet); + CHECK(udCheck == &ud); +} + #if !LUA_USE_LONGJMP TEST_CASE("ExceptionObject") { @@ -1034,7 +1682,8 @@ TEST_CASE("ExceptionObject") std::string description; }; - auto captureException = [](lua_State* L, const char* functionToRun) { + auto captureException = [](lua_State* L, const char* functionToRun) + { try { lua_State* threadState = lua_newthread(L); @@ -1050,26 +1699,7 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - auto reallocFunc = [](void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { - if (nsize == 0) - { - free(ptr); - return nullptr; - } - else if (nsize > 512 * 1024) - { - // For testing purposes return null for large allocations - // so we can generate exceptions related to memory allocation - // failures. - return nullptr; - } - else - { - return realloc(ptr, nsize); - } - }; - - StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(reallocFunc, nullptr)); + StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); { @@ -1111,15 +1741,70 @@ TEST_CASE("IfElseExpression") runConformance("ifelseexpr.lua"); } +// Optionally returns debug info for the first Luau stack frame that is encountered on the callstack. +static std::optional getFirstLuauFrameDebugInfo(lua_State* L) +{ + static std::string_view kLua = "Lua"; + lua_Debug ar; + for (int i = 0; lua_getinfo(L, i, "sl", &ar); i++) + { + if (kLua == ar.what) + return ar; + } + return std::nullopt; +} + TEST_CASE("TagMethodError") { - runConformance("tmerror.lua", [](lua_State* L) { - auto* cb = lua_callbacks(L); + static std::vector expectedHits; + + // Loop over two modes: + // when doLuaBreak is false the test only verifies that callbacks occur on the expected lines in the Luau source + // when doLuaBreak is true the test additionally calls lua_break to ensure breaking the debugger doesn't cause the VM to crash + for (bool doLuaBreak : {false, true}) + { + expectedHits = {22, 32}; + + static int index; + static bool luaBreak; + index = 0; + luaBreak = doLuaBreak; + + // 'yieldCallback' doesn't do anything, but providing the callback to runConformance + // ensures that the call to lua_break doesn't cause an error to be generated because + // runConformance doesn't expect the VM to be in the state LUA_BREAK. + auto yieldCallback = [](lua_State* L) {}; + + runConformance( + "tmerror.lua", + [](lua_State* L) + { + auto* cb = lua_callbacks(L); + + cb->debugprotectederror = [](lua_State* L) + { + std::optional ar = getFirstLuauFrameDebugInfo(L); + + CHECK(lua_isyieldable(L)); + REQUIRE(ar.has_value()); + REQUIRE(index < int(std::size(expectedHits))); + CHECK(ar->currentline == expectedHits[index++]); + + if (luaBreak) + { + // Cause luau execution to break when 'error' is called via 'pcall' + // This call to lua_break is a regression test for an issue where debugprotectederror + // was called on a thread that couldn't be yielded even though lua_isyieldable was true. + lua_break(L); + } + }; + }, + yieldCallback + ); - cb->debugprotectederror = [](lua_State* L) { - CHECK(lua_isyieldable(L)); - }; - }); + // Make sure the number of break points hit was the expected number + CHECK(index == std::size(expectedHits)); + } } TEST_CASE("Coverage") @@ -1130,43 +1815,55 @@ TEST_CASE("Coverage") runConformance( "coverage.lua", - [](lua_State* L) { + [](lua_State* L) + { lua_pushcfunction( L, - [](lua_State* L) -> int { + [](lua_State* L) -> int + { luaL_argexpected(L, lua_isLfunction(L, 1), 1, "function"); lua_newtable(L); - lua_getcoverage(L, 1, L, [](void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) { - lua_State* L = static_cast(context); + lua_getcoverage( + L, + 1, + L, + [](void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) + { + lua_State* L = static_cast(context); - lua_newtable(L); + lua_newtable(L); - lua_pushstring(L, function); - lua_setfield(L, -2, "name"); + lua_pushstring(L, function); + lua_setfield(L, -2, "name"); - lua_pushinteger(L, linedefined); - lua_setfield(L, -2, "linedefined"); + lua_pushinteger(L, linedefined); + lua_setfield(L, -2, "linedefined"); - lua_pushinteger(L, depth); - lua_setfield(L, -2, "depth"); + lua_pushinteger(L, depth); + lua_setfield(L, -2, "depth"); - for (size_t i = 0; i < size; ++i) - if (hits[i] != -1) - { - lua_pushinteger(L, hits[i]); - lua_rawseti(L, -2, int(i)); - } + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + lua_pushinteger(L, hits[i]); + lua_rawseti(L, -2, int(i)); + } - lua_rawseti(L, -2, lua_objlen(L, -2) + 1); - }); + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + } + ); return 1; }, - "getcoverage"); + "getcoverage" + ); lua_setglobal(L, "getcoverage"); }, - nullptr, nullptr, &copts); + nullptr, + nullptr, + &copts + ); } TEST_CASE("StringConversion") @@ -1177,7 +1874,13 @@ TEST_CASE("StringConversion") TEST_CASE("GCDump") { // internal function, declared in lgc.h - not exposed via lua.h - extern void luaC_dump(lua_State * L, void* file, const char* (*categoryName)(lua_State * L, uint8_t memcat)); + extern void luaC_dump(lua_State * L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); + extern void luaC_enumheap( + lua_State * L, + void* context, + void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, size_t size, const char* name), + void (*edge)(void* context, void* from, void* to, const char* name) + ); StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -1187,6 +1890,9 @@ TEST_CASE("GCDump") lua_pushstring(L, "value"); lua_setfield(L, -2, "key"); + lua_pushstring(L, "u42"); + lua_setfield(L, -2, "__type"); + lua_pushinteger(L, 42); lua_rawseti(L, -2, 1000); @@ -1203,6 +1909,8 @@ TEST_CASE("GCDump") lua_pushinteger(L, 1); lua_pushcclosure(L, lua_silence, "test", 1); + lua_newbuffer(L, 100); + lua_State* CL = lua_newthread(L); lua_pushstring(CL, "local x x = {} local function f() x[1] = math.abs(42) end function foo() coroutine.yield() end foo() return f"); @@ -1221,6 +1929,49 @@ TEST_CASE("GCDump") luaC_dump(L, f, nullptr); fclose(f); + + struct Node + { + void* ptr; + uint8_t tag; + uint8_t memcat; + size_t size; + std::string name; + }; + + struct EnumContext + { + EnumContext() + : nodes{nullptr} + , edges{nullptr} + { + } + + Luau::DenseHashMap nodes; + Luau::DenseHashMap edges; + } ctx; + + luaC_enumheap( + L, + &ctx, + [](void* ctx, void* gco, uint8_t tt, uint8_t memcat, size_t size, const char* name) + { + EnumContext& context = *(EnumContext*)ctx; + + if (tt == LUA_TUSERDATA) + CHECK(strcmp(name, "u42") == 0); + + context.nodes[gco] = {gco, tt, memcat, size, name ? name : ""}; + }, + [](void* ctx, void* s, void* t, const char*) + { + EnumContext& context = *(EnumContext*)ctx; + context.edges[s] = t; + } + ); + + CHECK(!ctx.nodes.empty()); + CHECK(!ctx.edges.empty()); } TEST_CASE("Interrupt") @@ -1228,66 +1979,123 @@ TEST_CASE("Interrupt") lua_CompileOptions copts = defaultOptions(); copts.optimizationLevel = 1; // disable loop unrolling to get fixed expected hit results - static const int expectedhits[] = { - 2, - 9, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 6, - 18, - 13, - 13, - 13, - 13, - 16, - 20, - }; static int index; - index = 0; + StateRef globalState = runConformance("interrupt.lua", nullptr, nullptr, nullptr, &copts); - runConformance( - "interrupt.lua", - [](lua_State* L) { - auto* cb = lua_callbacks(L); + lua_State* L = globalState.get(); - // note: for simplicity here we setup the interrupt callback once - // however, this carries a noticeable performance cost. in a real application, - // it's advised to set interrupt callback on a timer from a different thread, - // and set it back to nullptr once the interrupt triggered. - cb->interrupt = [](lua_State* L, int gc) { - if (gc >= 0) - return; + // note: for simplicity here we setup the interrupt callback when the test starts + // however, this carries a noticeable performance cost. in a real application, + // it's advised to set interrupt callback on a timer from a different thread, + // and set it back to nullptr once the interrupt triggered. - CHECK(index < int(std::size(expectedhits))); + // define the interrupt to check the expected hits + static const int expectedhits[] = {11, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 20, 15, 15, 15, 15, 18, 25, 23, 26}; - lua_Debug ar = {}; - lua_getinfo(L, 0, "l", &ar); + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) + { + if (gc >= 0) + return; - CHECK(ar.currentline == expectedhits[index]); + CHECK(index < int(std::size(expectedhits))); - index++; + lua_Debug ar = {}; + lua_getinfo(L, 0, "l", &ar); - // check that we can yield inside an interrupt - if (index == 5) - lua_yield(L, 0); - }; - }, - [](lua_State* L) { - CHECK(index == 5); // a single yield point - }, - nullptr, &copts); + CHECK(ar.currentline == expectedhits[index]); + + index++; + + // check that we can yield inside an interrupt + if (index == 4) + lua_yield(L, 0); + }; + + { + lua_State* T = lua_newthread(L); + + lua_getglobal(T, "test"); + + index = 0; + int status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_YIELD); + CHECK(index == 4); + + status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_OK); + CHECK(index == int(std::size(expectedhits))); + + lua_pop(L, 1); + } + + // redefine the interrupt to break after 10 iterations of a loop that would otherwise be infinite + // the test exposes a few global functions that we will call; the interrupt will force a yield + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) + { + if (gc >= 0) + return; + + CHECK(index < 11); + if (++index == 11) + lua_yield(L, 0); + }; + + for (int test = 1; test <= 10; ++test) + { + lua_State* T = lua_newthread(L); + + std::string name = "infloop" + std::to_string(test); + lua_getglobal(T, name.c_str()); + + index = 0; + int status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_YIELD); + CHECK(index == 11); + + // abandon the thread + lua_pop(L, 1); + } + + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) + { + if (gc >= 0) + return; + + index++; + + if (index == 1'000) + { + index = 0; + luaL_error(L, "timeout"); + } + }; + + for (int test = 1; test <= 5; ++test) + { + lua_State* T = lua_newthread(L); + + std::string name = "strhang" + std::to_string(test); + lua_getglobal(T, name.c_str()); + + index = 0; + int status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_ERRRUN); + + lua_pop(L, 1); + } + + { + lua_State* T = lua_newthread(L); + + lua_getglobal(T, "strhangpcall"); + + index = 0; + int status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_OK); - CHECK(index == int(std::size(expectedhits))); + lua_pop(L, 1); + } } TEST_CASE("UserdataApi") @@ -1300,9 +2108,15 @@ TEST_CASE("UserdataApi") lua_State* L = globalState.get(); // setup dtor for tag 42 (created later) - lua_setuserdatadtor(L, 42, [](lua_State* l, void* data) { + auto dtor = [](lua_State* l, void* data) + { dtorhits += *(int*)data; - }); + }; + bool dtorIsNull = lua_getuserdatadtor(L, 42) == nullptr; + CHECK(dtorIsNull); + lua_setuserdatadtor(L, 42, dtor); + bool dtorIsSet = lua_getuserdatadtor(L, 42) == dtor; + CHECK(dtorIsSet); // light user data int lud; @@ -1333,13 +2147,23 @@ TEST_CASE("UserdataApi") lua_setuserdatatag(L, -1, 42); // user data with inline dtor - void* ud3 = lua_newuserdatadtor(L, 4, [](void* data) { - dtorhits += *(int*)data; - }); + void* ud3 = lua_newuserdatadtor( + L, + 4, + [](void* data) + { + dtorhits += *(int*)data; + } + ); - void* ud4 = lua_newuserdatadtor(L, 1, [](void* data) { - dtorhits += *(char*)data; - }); + void* ud4 = lua_newuserdatadtor( + L, + 1, + [](void* data) + { + dtorhits += *(char*)data; + } + ); *(int*)ud3 = 43; *(char*)ud4 = 3; @@ -1349,28 +2173,121 @@ TEST_CASE("UserdataApi") luaL_newmetatable(L, "udata2"); void* ud5 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata1"); + luaL_getmetatable(L, "udata1"); lua_setmetatable(L, -2); void* ud6 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata2"); + luaL_getmetatable(L, "udata2"); lua_setmetatable(L, -2); CHECK(luaL_checkudata(L, -2, "udata1") == ud5); CHECK(luaL_checkudata(L, -1, "udata2") == ud6); + // tagged user data with fast metatable access + luaL_newmetatable(L, "udata3"); + luaL_getmetatable(L, "udata3"); + lua_setuserdatametatable(L, 50, -1); + + luaL_newmetatable(L, "udata4"); + luaL_getmetatable(L, "udata4"); + lua_setuserdatametatable(L, 51, -1); + + void* ud7 = lua_newuserdatatagged(L, 16, 50); + lua_getuserdatametatable(L, 50); + lua_setmetatable(L, -2); + + void* ud8 = lua_newuserdatatagged(L, 16, 51); + lua_getuserdatametatable(L, 51); + lua_setmetatable(L, -2); + + CHECK(luaL_checkudata(L, -2, "udata3") == ud7); + CHECK(luaL_checkudata(L, -1, "udata4") == ud8); + globalState.reset(); CHECK(dtorhits == 42); } +TEST_CASE("LightuserdataApi") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + void* value = (void*)0x12345678; + + lua_pushlightuserdatatagged(L, value, 1); + CHECK(lua_lightuserdatatag(L, -1) == 1); + CHECK(lua_tolightuserdatatagged(L, -1, 0) == nullptr); + CHECK(lua_tolightuserdatatagged(L, -1, 1) == value); + + lua_setlightuserdataname(L, 1, "id"); + CHECK(!lua_getlightuserdataname(L, 0)); + CHECK(strcmp(lua_getlightuserdataname(L, 1), "id") == 0); + CHECK(strcmp(luaL_typename(L, -1), "id") == 0); + lua_pop(L, 1); + + lua_pushlightuserdatatagged(L, value, 0); + lua_pushlightuserdatatagged(L, value, 1); + CHECK(lua_rawequal(L, -1, -2) == 0); + lua_pop(L, 2); + + // Check lightuserdata table key uniqueness + lua_newtable(L); + + lua_pushlightuserdatatagged(L, value, 2); + lua_pushinteger(L, 20); + lua_settable(L, -3); + lua_pushlightuserdatatagged(L, value, 3); + lua_pushinteger(L, 30); + lua_settable(L, -3); + + lua_pushlightuserdatatagged(L, value, 2); + lua_gettable(L, -2); + lua_pushinteger(L, 20); + CHECK(lua_rawequal(L, -1, -2) == 1); + lua_pop(L, 2); + + lua_pushlightuserdatatagged(L, value, 3); + lua_gettable(L, -2); + lua_pushinteger(L, 30); + CHECK(lua_rawequal(L, -1, -2) == 1); + lua_pop(L, 2); + + lua_pop(L, 1); + + // Still possible to rename the global lightuserdata name using a metatable + lua_pushlightuserdata(L, value); + CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); + + lua_createtable(L, 0, 1); + lua_pushstring(L, "luserdata"); + lua_setfield(L, -2, "__type"); + lua_setmetatable(L, -2); + + CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); + lua_pop(L, 1); + + globalState.reset(); +} + +TEST_CASE("DebugApi") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_pushnumber(L, 10); + + lua_Debug ar; + CHECK(lua_getinfo(L, -1, "f", &ar) == 0); // number is not a function + CHECK(lua_getinfo(L, -10, "f", &ar) == 0); // not on stack +} + TEST_CASE("Iter") { runConformance("iter.lua"); } const int kInt64Tag = 1; -static int gInt64MT = -1; static int64_t getInt64(lua_State* L, int idx) { @@ -1387,7 +2304,7 @@ static void pushInt64(lua_State* L, int64_t value) { void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); - lua_getref(L, gInt64MT); + luaL_getmetatable(L, "int64"); lua_setmetatable(L, -2); *static_cast(p) = value; @@ -1395,179 +2312,225 @@ static void pushInt64(lua_State* L, int64_t value) TEST_CASE("Userdata") { - runConformance("userdata.lua", [](lua_State* L) { - // create metatable with all the metamethods - lua_newtable(L); - gInt64MT = lua_ref(L, -1); - - // __index - lua_pushcfunction( - L, - [](lua_State* L) { - void* p = lua_touserdatatagged(L, 1, kInt64Tag); - if (!p) - luaL_typeerror(L, 1, "int64"); - - const char* name = luaL_checkstring(L, 2); + runConformance( + "userdata.lua", + [](lua_State* L) + { + // create metatable with all the metamethods + luaL_newmetatable(L, "int64"); - if (strcmp(name, "value") == 0) + // __index + lua_pushcfunction( + L, + [](lua_State* L) { - lua_pushnumber(L, double(*static_cast(p))); - return 1; - } + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); - luaL_error(L, "unknown field %s", name); - }, - nullptr); - lua_setfield(L, -2, "__index"); + const char* name = luaL_checkstring(L, 2); - // __newindex - lua_pushcfunction( - L, - [](lua_State* L) { - void* p = lua_touserdatatagged(L, 1, kInt64Tag); - if (!p) - luaL_typeerror(L, 1, "int64"); + if (strcmp(name, "value") == 0) + { + lua_pushnumber(L, double(*static_cast(p))); + return 1; + } - const char* name = luaL_checkstring(L, 2); + luaL_error(L, "unknown field %s", name); + }, + nullptr + ); + lua_setfield(L, -2, "__index"); - if (strcmp(name, "value") == 0) + // __newindex + lua_pushcfunction( + L, + [](lua_State* L) { - double value = luaL_checknumber(L, 3); - *static_cast(p) = int64_t(value); - return 0; - } + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); - luaL_error(L, "unknown field %s", name); - }, - nullptr); - lua_setfield(L, -2, "__newindex"); - - // __eq - lua_pushcfunction( - L, - [](lua_State* L) { - lua_pushboolean(L, getInt64(L, 1) == getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__eq"); - - // __lt - lua_pushcfunction( - L, - [](lua_State* L) { - lua_pushboolean(L, getInt64(L, 1) < getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__lt"); - - // __le - lua_pushcfunction( - L, - [](lua_State* L) { - lua_pushboolean(L, getInt64(L, 1) <= getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__le"); - - // __add - lua_pushcfunction( - L, - [](lua_State* L) { - pushInt64(L, getInt64(L, 1) + getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__add"); - - // __sub - lua_pushcfunction( - L, - [](lua_State* L) { - pushInt64(L, getInt64(L, 1) - getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__sub"); - - // __mul - lua_pushcfunction( - L, - [](lua_State* L) { - pushInt64(L, getInt64(L, 1) * getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__mul"); - - // __div - lua_pushcfunction( - L, - [](lua_State* L) { - // ideally we'd guard against 0 but it's a test so eh - pushInt64(L, getInt64(L, 1) / getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__div"); - - // __mod - lua_pushcfunction( - L, - [](lua_State* L) { - // ideally we'd guard against 0 and INT64_MIN but it's a test so eh - pushInt64(L, getInt64(L, 1) % getInt64(L, 2)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__mod"); - - // __pow - lua_pushcfunction( - L, - [](lua_State* L) { - pushInt64(L, int64_t(pow(double(getInt64(L, 1)), double(getInt64(L, 2))))); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__pow"); - - // __unm - lua_pushcfunction( - L, - [](lua_State* L) { - pushInt64(L, -getInt64(L, 1)); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__unm"); - - // __tostring - lua_pushcfunction( - L, - [](lua_State* L) { - int64_t value = getInt64(L, 1); - std::string str = std::to_string(value); - lua_pushlstring(L, str.c_str(), str.length()); - return 1; - }, - nullptr); - lua_setfield(L, -2, "__tostring"); - - // ctor - lua_pushcfunction( - L, - [](lua_State* L) { - double v = luaL_checknumber(L, 1); - pushInt64(L, int64_t(v)); - return 1; - }, - "int64"); - lua_setglobal(L, "int64"); - }); + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "value") == 0) + { + double value = luaL_checknumber(L, 3); + *static_cast(p) = int64_t(value); + return 0; + } + + luaL_error(L, "unknown field %s", name); + }, + nullptr + ); + lua_setfield(L, -2, "__newindex"); + + // __eq + lua_pushcfunction( + L, + [](lua_State* L) + { + lua_pushboolean(L, getInt64(L, 1) == getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__eq"); + + // __lt + lua_pushcfunction( + L, + [](lua_State* L) + { + lua_pushboolean(L, getInt64(L, 1) < getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__lt"); + + // __le + lua_pushcfunction( + L, + [](lua_State* L) + { + lua_pushboolean(L, getInt64(L, 1) <= getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__le"); + + // __add + lua_pushcfunction( + L, + [](lua_State* L) + { + pushInt64(L, getInt64(L, 1) + getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__add"); + + // __sub + lua_pushcfunction( + L, + [](lua_State* L) + { + pushInt64(L, getInt64(L, 1) - getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__sub"); + + // __mul + lua_pushcfunction( + L, + [](lua_State* L) + { + pushInt64(L, getInt64(L, 1) * getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__mul"); + + // __div + lua_pushcfunction( + L, + [](lua_State* L) + { + // ideally we'd guard against 0 but it's a test so eh + pushInt64(L, getInt64(L, 1) / getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__div"); + + // __idiv + lua_pushcfunction( + L, + [](lua_State* L) + { + // for testing we use different semantics here compared to __div: __idiv rounds to negative inf, __div truncates (rounds to zero) + // additionally, division loses precision here outside of 2^53 range + // we do not necessarily recommend this behavior in production code! + pushInt64(L, int64_t(floor(double(getInt64(L, 1)) / double(getInt64(L, 2))))); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__idiv"); + + // __mod + lua_pushcfunction( + L, + [](lua_State* L) + { + // ideally we'd guard against 0 and INT64_MIN but it's a test so eh + pushInt64(L, getInt64(L, 1) % getInt64(L, 2)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__mod"); + + // __pow + lua_pushcfunction( + L, + [](lua_State* L) + { + pushInt64(L, int64_t(pow(double(getInt64(L, 1)), double(getInt64(L, 2))))); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__pow"); + + // __unm + lua_pushcfunction( + L, + [](lua_State* L) + { + pushInt64(L, -getInt64(L, 1)); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__unm"); + + // __tostring + lua_pushcfunction( + L, + [](lua_State* L) + { + int64_t value = getInt64(L, 1); + std::string str = std::to_string(value); + lua_pushlstring(L, str.c_str(), str.length()); + return 1; + }, + nullptr + ); + lua_setfield(L, -2, "__tostring"); + + // ctor + lua_pushcfunction( + L, + [](lua_State* L) + { + double v = luaL_checknumber(L, 1); + pushInt64(L, int64_t(v)); + return 1; + }, + "int64" + ); + lua_setglobal(L, "int64"); + } + ); } TEST_CASE("SafeEnv") @@ -1575,7 +2538,139 @@ TEST_CASE("SafeEnv") runConformance("safeenv.lua"); } -TEST_CASE("HugeFunction") +TEST_CASE("Native") +{ + // This tests requires code to run natively, otherwise all 'is_native' checks will fail + if (!codegen || !luau_codegen_supported()) + return; + + SUBCASE("Checked") + { + FFlag::DebugLuauAbortingChecks.value = true; + } + + SUBCASE("Regular") + { + FFlag::DebugLuauAbortingChecks.value = false; + } + + runConformance( + "native.lua", + [](lua_State* L) + { + setupNativeHelpers(L); + } + ); +} + +TEST_CASE("NativeTypeAnnotations") +{ + // This tests requires code to run natively, otherwise all 'is_native' checks will fail + if (!codegen || !luau_codegen_supported()) + return; + + runConformance( + "native_types.lua", + [](lua_State* L) + { + setupNativeHelpers(L); + setupVectorHelpers(L); + } + ); +} + +TEST_CASE("NativeUserdata") +{ + lua_CompileOptions copts = defaultOptions(); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; + + nativeOpts.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + nativeOpts.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + nativeOpts.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + nativeOpts.hooks.userdataAccess = userdataAccess; + nativeOpts.hooks.userdataMetamethod = userdataMetamethod; + nativeOpts.hooks.userdataNamecall = userdataNamecall; + + nativeOpts.userdataTypes = kUserdataRunTypes; + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + + runConformance( + "native_userdata.lua", + [](lua_State* L) + { + Luau::CodeGen::setUserdataRemapper( + L, + kUserdataRunTypes, + [](void* context, const char* str, size_t len) -> uint8_t + { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + } + ); + + setupVectorHelpers(L); + setupUserdataHelpers(L); + }, + nullptr, + nullptr, + &copts, + false, + &nativeOpts + ); +} + +[[nodiscard]] static std::string makeHugeFunctionSource() { std::string source; @@ -1596,11 +2691,18 @@ TEST_CASE("HugeFunction") // use failed fast-calls with imports and constants to exercise all of the more complex fallback sequences source += "return bit32.lshift('84', -1)"; + return source; +} + +TEST_CASE("HugeFunction") +{ + std::string source = makeHugeFunctionSource(); + StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); - if (codegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::create(L); + if (codegen && luau_codegen_supported()) + luau_codegen_create(L); luaL_openlibs(L); luaL_sandbox(L); @@ -1613,8 +2715,11 @@ TEST_CASE("HugeFunction") REQUIRE(result == 0); - if (codegen && Luau::CodeGen::isSupported()) - Luau::CodeGen::compile(L, -1); + if (codegen && luau_codegen_supported()) + { + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::compile(L, -1, nativeOptions); + } int status = lua_resume(L, nullptr, 0); REQUIRE(status == 0); @@ -1622,4 +2727,243 @@ TEST_CASE("HugeFunction") CHECK(lua_tonumber(L, -1) == 42); } +TEST_CASE("HugeFunctionLoadFailure") +{ + // This test case verifies that if an out-of-memory error occurs inside of + // luau_load, we are not left with any GC objects in inconsistent states + // that would cause issues during garbage collection. + // + // We create a script with a huge function in it, then pass this to + // luau_load. This should require two "large" allocations: One for the + // code array and one for the constants array (k). We run this test twice + // and fail each of these two allocations. + std::string source = makeHugeFunctionSource(); + + static const size_t expectedTotalLargeAllocations = 2; + + static size_t largeAllocationToFail = 0; + static size_t largeAllocationCount = 0; + + const auto testAllocate = [](void* ud, void* ptr, size_t osize, size_t nsize) -> void* + { + if (nsize == 0) + { + free(ptr); + return nullptr; + } + else if (nsize > 32768) + { + if (largeAllocationCount == largeAllocationToFail) + return nullptr; + + ++largeAllocationCount; + return realloc(ptr, nsize); + } + else + { + return realloc(ptr, nsize); + } + }; + + size_t bytecodeSize = 0; + char* const bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + + for (largeAllocationToFail = 0; largeAllocationToFail != expectedTotalLargeAllocations; ++largeAllocationToFail) + { + largeAllocationCount = 0; + + StateRef globalState(lua_newstate(testAllocate, nullptr), lua_close); + lua_State* L = globalState.get(); + + luaL_openlibs(L); + luaL_sandbox(L); + luaL_sandboxthread(L); + + try + { + luau_load(L, "=HugeFunction", bytecode, bytecodeSize, 0); + REQUIRE(false); // The luau_load should fail with an exception + } + catch (const std::exception& ex) + { + REQUIRE(strcmp(ex.what(), "lua_exception: not enough memory") == 0); + } + + luaC_fullgc(L); + } + + free(bytecode); + + REQUIRE_EQ(largeAllocationToFail, expectedTotalLargeAllocations); +} + +TEST_CASE("IrInstructionLimit") +{ + if (!codegen || !luau_codegen_supported()) + return; + + ScopedFastInt codegenHeuristicsInstructionLimit{FInt::CodegenHeuristicsInstructionLimit, 50'000}; + + std::string source; + + // Generate a hundred fat functions + for (int fn = 0; fn < 100; fn++) + { + source += "local function fn" + std::to_string(fn) + "(...)\n"; + source += "if ... then\n"; + source += "local p1, p2 = ...\n"; + source += "local _ = {\n"; + + for (int i = 0; i < 100; ++i) + { + source += "p1*0." + std::to_string(i) + ","; + source += "p2+0." + std::to_string(i) + ","; + } + + source += "}\n"; + source += "return _\n"; + source += "end\n"; + source += "end\n"; + } + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + luau_codegen_create(L); + + luaL_openlibs(L); + luaL_sandbox(L); + luaL_sandboxthread(L); + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + int result = luau_load(L, "=HugeFunction", bytecode, bytecodeSize, 0); + free(bytecode); + + REQUIRE(result == 0); + + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::CompilationStats nativeStats = {}; + Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats); + + // Limit is not hit immediately, so with some functions compiled it should be a success + CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success); + + // But it has some failed functions + CHECK(nativeResult.hasErrors()); + REQUIRE(!nativeResult.protoFailures.empty()); + + CHECK(nativeResult.protoFailures.front().result == Luau::CodeGen::CodeGenCompilationResult::CodeGenOverflowInstructionLimit); + CHECK(nativeResult.protoFailures.front().line != -1); + CHECK(nativeResult.protoFailures.front().debugname != ""); + + // We should be able to compile at least one of our functions + CHECK(nativeStats.functionsCompiled > 0); + + // But because of the limit, not all of them (101 because there's an extra global function) + CHECK(nativeStats.functionsCompiled < 101); +} + +TEST_CASE("BytecodeDistributionPerFunctionTest") +{ + const char* source = R"( +local function first(n, p) + local t = {} + for i=1,p do t[i] = i*10 end + + local function inner(_,n) + if n > 0 then + n = n-1 + return n, unpack(t) + end + end + return inner, nil, n +end + +local function second(x) + return x[1] +end +)"; + + std::vector summaries(analyzeFile(source, 0)); + + CHECK_EQ(summaries[0].getName(), "inner"); + CHECK_EQ(summaries[0].getLine(), 6); + CHECK_EQ(summaries[0].getCounts(0), std::vector({0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + + CHECK_EQ(summaries[1].getName(), "first"); + CHECK_EQ(summaries[1].getLine(), 2); + CHECK_EQ(summaries[1].getCounts(0), std::vector({0, 0, 1, 0, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + + + CHECK_EQ(summaries[2].getName(), "second"); + CHECK_EQ(summaries[2].getLine(), 15); + CHECK_EQ(summaries[2].getCounts(0), std::vector({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + + CHECK_EQ(summaries[3].getName(), ""); + CHECK_EQ(summaries[3].getLine(), 1); + CHECK_EQ(summaries[3].getCounts(0), std::vector({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); +} + +TEST_CASE("NativeAttribute") +{ + if (!codegen || !luau_codegen_supported()) + return; + + ScopedFastFlag sffs[] = {{FFlag::LuauNativeAttribute, true}}; + + std::string source = R"R( + @native + local function sum(x, y) + local function sumHelper(z) + return (x+y+z) + end + return sumHelper + end + + local function sub(x, y) + @native + local function subHelper(z) + return (x+y-z) + end + return subHelper + end)R"; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + luau_codegen_create(L); + + luaL_openlibs(L); + luaL_sandbox(L); + luaL_sandboxthread(L); + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + int result = luau_load(L, "=Code", bytecode, bytecodeSize, 0); + free(bytecode); + + REQUIRE(result == 0); + + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::CompilationStats nativeStats = {}; + Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats); + + CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success); + + CHECK(!nativeResult.hasErrors()); + REQUIRE(nativeResult.protoFailures.empty()); + + // We should be able to compile at least one of our functions + CHECK_EQ(nativeStats.functionsCompiled, 2); +} + TEST_SUITE_END(); diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h new file mode 100644 index 000000000..07c721baf --- /dev/null +++ b/tests/ConformanceIrHooks.h @@ -0,0 +1,574 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrBuilder.h" + +static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; + +constexpr uint8_t kUserdataExtra = 0; +constexpr uint8_t kUserdataColor = 1; +constexpr uint8_t kUserdataVec2 = 2; +constexpr uint8_t kUserdataMat3 = 3; + +// Userdata tags can be different from userdata bytecode type indices +constexpr uint8_t kTagVec2 = 12; + +struct Vec2 +{ + float x; + float y; +}; + +inline bool compareMemberName(const char* member, size_t memberLength, const char* str) +{ + return memberLength == strlen(str) && strcmp(member, str) == 0; +} + +inline uint8_t typeToUserdataIndex(uint8_t type) +{ + // Underflow will push the type into a value that is not comparable to any kUserdata* constants + return type - LBC_TYPE_TAGGED_USERDATA_BASE; +} + +inline uint8_t userdataIndexToType(uint8_t userdataIndex) +{ + return LBC_TYPE_TAGGED_USERDATA_BASE + userdataIndex; +} + +inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) +{ + using namespace Luau::CodeGen; + + if (compareMemberName(member, memberLength, "Magnitude")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Unit")) + return LBC_TYPE_VECTOR; + + return LBC_TYPE_ANY; +} + +inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos) +{ + using namespace Luau::CodeGen; + + if (compareMemberName(member, memberLength, "Magnitude")) + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + + return true; + } + + if (compareMemberName(member, memberLength, "Unit")) + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(resultReg), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TVECTOR)); + + return true; + } + + return false; +} + +inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength) +{ + if (compareMemberName(member, memberLength, "Dot")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Cross")) + return LBC_TYPE_VECTOR; + + return LBC_TYPE_ANY; +} + +inline bool vectorNamecall( + Luau::CodeGen::IrBuilder& build, + const char* member, + size_t memberLength, + int argResReg, + int sourceReg, + int params, + int results, + int pcpos +) +{ + using namespace Luau::CodeGen; + + if (compareMemberName(member, memberLength, "Dot") && params == 2 && results <= 1) + { + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8)); + IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (compareMemberName(member, memberLength, "Cross") && params == 2 && results <= 1) + { + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0)); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4)); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8)); + + IrOp y1z2 = build.inst(IrCmd::MUL_NUM, y1, z2); + IrOp z1y2 = build.inst(IrCmd::MUL_NUM, z1, y2); + IrOp xr = build.inst(IrCmd::SUB_NUM, y1z2, z1y2); + + IrOp z1x2 = build.inst(IrCmd::MUL_NUM, z1, x2); + IrOp x1z2 = build.inst(IrCmd::MUL_NUM, x1, z2); + IrOp yr = build.inst(IrCmd::SUB_NUM, z1x2, x1z2); + + IrOp x1y2 = build.inst(IrCmd::MUL_NUM, x1, y2); + IrOp y1x2 = build.inst(IrCmd::MUL_NUM, y1, x2); + IrOp zr = build.inst(IrCmd::SUB_NUM, x1y2, y1x2); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(argResReg), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TVECTOR)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + return false; +} + +inline uint8_t userdataAccessBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + if (compareMemberName(member, memberLength, "R")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "G")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "B")) + return LBC_TYPE_NUMBER; + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Y")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Magnitude")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Unit")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + if (compareMemberName(member, memberLength, "Row1")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row2")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row3")) + return LBC_TYPE_VECTOR; + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataAccess( + Luau::CodeGen::IrBuilder& build, + uint8_t type, + const char* member, + size_t memberLength, + int resultReg, + int sourceReg, + int pcpos +) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Y")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Magnitude")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Unit")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), xr, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), yr, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} + +inline uint8_t userdataMetamethodBytecodeType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method) +{ + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + case Luau::CodeGen::HostMetamethod::Sub: + case Luau::CodeGen::HostMetamethod::Mul: + case Luau::CodeGen::HostMetamethod::Div: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 || typeToUserdataIndex(rhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + default: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataMetamethod( + Luau::CodeGen::IrBuilder& build, + uint8_t lhsTy, + uint8_t rhsTy, + int resultReg, + Luau::CodeGen::IrOp lhs, + Luau::CodeGen::IrOp rhs, + Luau::CodeGen::HostMetamethod method, + int pcpos +) +{ + using namespace Luau::CodeGen; + + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::ADD_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::ADD_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Mul: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MUL_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::UNM_NUM, x); + IrOp my = build.inst(IrCmd::UNM_NUM, y); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + default: + break; + } + + return false; +} + +inline uint8_t userdataNamecallBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Min")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataNamecall( + Luau::CodeGen::IrBuilder& build, + uint8_t type, + const char* member, + size_t memberLength, + int argResReg, + int sourceReg, + int params, + int results, + int pcpos +) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, xx, yy); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (compareMemberName(member, memberLength, "Min")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MIN_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MIN_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(argResReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TUSERDATA)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp new file mode 100644 index 000000000..1b84d4c90 --- /dev/null +++ b/tests/ConstraintGeneratorFixture.cpp @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstraintGeneratorFixture.h" +#include "ScopedFlags.h" + +LUAU_FASTFLAG(LuauSolverV2); + +namespace Luau +{ + +ConstraintGeneratorFixture::ConstraintGeneratorFixture() + : Fixture() + , mainModule(new Module) + , forceTheFlag{FFlag::LuauSolverV2, true} +{ + mainModule->name = "MainModule"; + mainModule->humanReadableName = "MainModule"; + + BlockedTypePack::nextIndex = 0; +} + +void ConstraintGeneratorFixture::generateConstraints(const std::string& code) +{ + AstStatBlock* root = parse(code); + dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); + cg = std::make_unique( + mainModule, + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull(&moduleResolver), + builtinTypes, + NotNull(&ice), + frontend.globals.globalScope, + /*prepareModuleScope*/ nullptr, + &logger, + NotNull{dfg.get()}, + std::vector() + ); + cg->visitModuleRoot(root); + rootScope = cg->rootScope; + constraints = Luau::borrowConstraints(cg->constraints); +} + +void ConstraintGeneratorFixture::solve(const std::string& code) +{ + generateConstraints(code); + ConstraintSolver cs{ + NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {} + }; + cs.run(); +} + +} // namespace Luau diff --git a/tests/ConstraintGraphBuilderFixture.h b/tests/ConstraintGeneratorFixture.h similarity index 74% rename from tests/ConstraintGraphBuilderFixture.h rename to tests/ConstraintGeneratorFixture.h index 5e7fedab5..782747c70 100644 --- a/tests/ConstraintGraphBuilderFixture.h +++ b/tests/ConstraintGeneratorFixture.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" #include "Luau/TypeArena.h" @@ -13,23 +13,25 @@ namespace Luau { -struct ConstraintGraphBuilderFixture : Fixture +struct ConstraintGeneratorFixture : Fixture { TypeArena arena; ModulePtr mainModule; DcrLogger logger; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}}; std::unique_ptr dfg; - std::unique_ptr cgb; + std::unique_ptr cg; Scope* rootScope = nullptr; std::vector> constraints; ScopedFastFlag forceTheFlag; - ConstraintGraphBuilderFixture(); + ConstraintGeneratorFixture(); void generateConstraints(const std::string& code); void solve(const std::string& code); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp deleted file mode 100644 index a9a43f0b6..000000000 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "ConstraintGraphBuilderFixture.h" - -#include "Luau/TypeReduction.h" - -namespace Luau -{ - -ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() - : Fixture() - , mainModule(new Module) - , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} -{ - mainModule->reduction = std::make_unique(NotNull{&mainModule->internalTypes}, builtinTypes, NotNull{&ice}); - - BlockedType::nextIndex = 0; - BlockedTypePack::nextIndex = 0; -} - -void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) -{ - AstStatBlock* root = parse(code); - dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); - cgb = std::make_unique("MainModule", mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), - frontend.getGlobalScope(), &logger, NotNull{dfg.get()}); - cgb->visit(root); - rootScope = cgb->rootScope; - constraints = Luau::borrowConstraints(cgb->constraints); -} - -void ConstraintGraphBuilderFixture::solve(const std::string& code) -{ - generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull{mainModule->reduction.get()}, - NotNull(&moduleResolver), {}, &logger}; - cs.run(); -} - -} // namespace Luau diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index eaa0d41a2..b83fb3456 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -1,9 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "ConstraintGraphBuilderFixture.h" +#include "ConstraintGeneratorFixture.h" #include "Fixture.h" #include "doctest.h" +LUAU_FASTFLAG(LuauSolverV2); + using namespace Luau; static TypeId requireBinding(Scope* scope, const char* name) @@ -15,7 +17,7 @@ static TypeId requireBinding(Scope* scope, const char* name) TEST_SUITE_BEGIN("ConstraintSolver"); -TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") +TEST_CASE_FIXTURE(ConstraintGeneratorFixture, "constraint_basics") { solve(R"( local a = 55 @@ -27,7 +29,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") CHECK("number" == toString(bType)); } -TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") +TEST_CASE_FIXTURE(ConstraintGeneratorFixture, "generic_function") { solve(R"( local function id(a) @@ -40,7 +42,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") CHECK("(a) -> a" == toString(idType)); } -TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") +TEST_CASE_FIXTURE(ConstraintGeneratorFixture, "proper_let_generalization") { solve(R"( local function a(c) @@ -56,8 +58,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") TypeId idType = requireBinding(rootScope, "b"); - ToStringOptions opts; - CHECK("(a) -> number" == toString(idType, opts)); + CHECK("(unknown) -> number" == toString(idType)); } TEST_SUITE_END(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index 018fa87cf..29fffb4f1 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -31,7 +31,7 @@ static uint64_t modelFunction(const char* source) AstStatFunction* func = result.root->body.data[0]->as(); REQUIRE(func); - return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, {nullptr}); + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, DenseHashMap{nullptr}); } TEST_CASE("Expression") @@ -156,8 +156,8 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - CHECK_EQ(82, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(79, Luau::Compile::computeCost(model, args2, 1)); + CHECK_EQ(76, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(73, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("Conditional") @@ -240,4 +240,23 @@ end CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); } +TEST_CASE("MultipleAssignments") +{ + uint64_t model = modelFunction(R"( +function test(a) + local x = 0 + x = a + x = a + 1 + x, x, x = a + x = a, a, a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(8, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(7, Luau::Compile::computeCost(model, args2, 1)); +} + TEST_SUITE_END(); diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index bd5fe5628..4ea656eef 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" +#include "Fixture.h" #include "Luau/Error.h" #include "Luau/Parser.h" @@ -10,10 +11,12 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2); + struct DataFlowGraphFixture { // Only needed to fix the operator== reflexivity of an empty Symbol. - ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; InternalErrorReporter handle; @@ -33,19 +36,11 @@ struct DataFlowGraphFixture } template - NullableBreadcrumbId getBreadcrumb(const std::vector& nths = {nth(N)}) + DefId getDef(const std::vector& nths = {nth(N)}) { T* node = query(module, nths); REQUIRE(node); - return graph->getBreadcrumb(node); - } - - template - BreadcrumbId requireBreadcrumb(const std::vector& nths = {nth(N)}) - { - auto bc = getBreadcrumb(nths); - REQUIRE(bc); - return NotNull{bc}; + return graph->getDef(node); } }; @@ -58,7 +53,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_locals_in_local_stat") local y = x )"); - REQUIRE(getBreadcrumb()); + (void)getDef(); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") @@ -69,7 +64,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") end )"); - REQUIRE(getBreadcrumb()); + (void)getDef(); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") @@ -80,8 +75,8 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") local z = y )"); - BreadcrumbId x = requireBreadcrumb(); - BreadcrumbId y = requireBreadcrumb(); + DefId x = getDef(); + DefId y = getDef(); REQUIRE(x != y); } @@ -95,9 +90,618 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") local b = y )"); - BreadcrumbId x = requireBreadcrumb(); - BreadcrumbId y = requireBreadcrumb(); + DefId x = getDef(); + DefId y = getDef(); REQUIRE(x != y); } +TEST_CASE_FIXTURE(DataFlowGraphFixture, "phi") +{ + dfg(R"( + local x + + if a then + x = true + end + + local y = x + )"); + + DefId y = getDef(); + + const Phi* phi = get(y); + CHECK(phi); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_not_owned_by_while") +{ + dfg(R"( + local x + + while cond() do + x = true + end + + local y = x + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // local y = x + + CHECK(x0 == x1); + CHECK(x1 == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_owned_by_while") +{ + dfg(R"( + while cond() do + local x + x = true + x = 5 + end + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // x = 5 + + CHECK(x0 != x1); + CHECK(x1 != x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_not_owned_by_repeat") +{ + dfg(R"( + local x + + repeat + x = true + until cond() + + local y = x + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // local y = x + + CHECK(x0 == x1); + CHECK(x1 == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_owned_by_repeat") +{ + dfg(R"( + repeat + local x + x = true + x = 5 + until cond() + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // x = 5 + + CHECK(x0 != x1); + CHECK(x1 != x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_not_owned_by_for") +{ + dfg(R"( + local x + + for i = 0, 5 do + x = true + end + + local y = x + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // local y = x + + CHECK(x0 == x1); + CHECK(x1 == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_owned_by_for") +{ + dfg(R"( + for i = 0, 5 do + local x + x = true + x = 5 + end + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // x = 5 + + CHECK(x0 != x1); + CHECK(x1 != x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_not_owned_by_for_in") +{ + dfg(R"( + local x + + for i, v in t do + x = true + end + + local y = x + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // local y = x + + CHECK(x0 == x1); + CHECK(x1 == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_local_owned_by_for_in") +{ + dfg(R"( + for i, v in t do + local x + x = true + x = 5 + end + )"); + + DefId x0 = graph->getDef(query(module)->vars.data[0]); + DefId x1 = getDef(); // x = true + DefId x2 = getDef(); // x = 5 + + CHECK(x0 != x1); + CHECK(x1 != x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_preexisting_property_not_owned_by_while") +{ + dfg(R"( + local t = {} + t.x = 5 + + while cond() do + t.x = true + end + + local y = t.x + )"); + + DefId x1 = getDef(); // t.x = 5 + DefId x2 = getDef(); // t.x = true + DefId x3 = getDef(); // local y = t.x + + CHECK(x1 == x2); + CHECK(x2 == x3); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_non_preexisting_property_not_owned_by_while") +{ + dfg(R"( + local t = {} + + while cond() do + t.x = true + end + + local y = t.x + )"); + + DefId x1 = getDef(); // t.x = true + DefId x2 = getDef(); // local y = t.x + + CHECK(x1 == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "mutate_property_of_table_owned_by_while") +{ + dfg(R"( + while cond() do + local t = {} + t.x = true + t.x = 5 + end + )"); + + DefId x1 = getDef(); // t.x = true + DefId x2 = getDef(); // t.x = 5 + + CHECK(x1 != x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "property_lookup_on_a_phi_node") +{ + dfg(R"( + local t = {} + t.x = 5 + + if cond() then + t.x = 7 + end + + print(t.x) + )"); + + DefId x1 = getDef(); // t.x = 5 + DefId x2 = getDef(); // t.x = 7 + DefId x3 = getDef(); // print(t.x) + + CHECK(x1 != x2); + CHECK(x2 != x3); + + const Phi* phi = get(x3); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); + CHECK(phi->operands.at(0) == x1); + CHECK(phi->operands.at(1) == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "property_lookup_on_a_phi_node_2") +{ + dfg(R"( + local t = {} + + if cond() then + t.x = 5 + else + t.x = 7 + end + + print(t.x) + )"); + + DefId x1 = getDef(); // t.x = 5 + DefId x2 = getDef(); // t.x = 7 + DefId x3 = getDef(); // print(t.x) + + CHECK(x1 != x2); + CHECK(x2 != x3); + + const Phi* phi = get(x3); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); + CHECK(phi->operands.at(0) == x2); + CHECK(phi->operands.at(1) == x1); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "property_lookup_on_a_phi_node_3") +{ + dfg(R"( + local t = {} + t.x = 3 + + if cond() then + t.x = 5 + t.y = 7 + else + t.z = 42 + end + + print(t.x) + print(t.y) + print(t.z) + )"); + + DefId x1 = getDef(); // t.x = 3 + DefId x2 = getDef(); // t.x = 5 + + DefId y1 = getDef(); // t.y = 7 + + DefId z1 = getDef(); // t.z = 42 + + DefId x3 = getDef(); // print(t.x) + DefId y2 = getDef(); // print(t.y) + DefId z2 = getDef(); // print(t.z) + + CHECK(x1 != x2); + CHECK(x2 != x3); + CHECK(y1 == y2); + CHECK(z1 == z2); + + const Phi* phi = get(x3); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); + CHECK(phi->operands.at(0) == x1); + CHECK(phi->operands.at(1) == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "function_captures_are_phi_nodes_of_all_versions") +{ + dfg(R"( + local x = 5 + + function f() + print(x) + x = nil + end + + f() + x = "five" + )"); + + DefId x1 = graph->getDef(query(module)->vars.data[0]); + DefId x2 = getDef(); // print(x) + DefId x3 = getDef(); // x = nil + DefId x4 = getDef(); // x = "five" + + CHECK(x1 != x2); + CHECK(x2 != x3); + CHECK(x3 != x4); + + const Phi* phi = get(x2); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 3); + CHECK(phi->operands.at(0) == x1); + CHECK(phi->operands.at(1) == x3); + CHECK(phi->operands.at(2) == x4); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "function_captures_are_phi_nodes_of_all_versions_properties") +{ + dfg(R"( + local t = {} + t.x = 5 + + function f() + print(t.x) + t.x = nil + end + + f() + t.x = "five" + )"); + + DefId x1 = getDef(); // t.x = 5 + DefId x2 = getDef(); // print(t.x) + DefId x3 = getDef(); // t.x = nil + DefId x4 = getDef(); // t.x = "five" + + CHECK(x1 != x2); + CHECK(x2 != x3); + CHECK(x3 != x4); + + // When a local is referenced within a function, it is not pointer identical. + // Instead, it's a phi node of all possible versions, including just one version. + DefId t1 = graph->getDef(query(module)->vars.data[0]); + DefId t2 = getDef(); // print(t.x) + + const Phi* phi = get(t2); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 1); + CHECK(phi->operands.at(0) == t1); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "local_f_which_is_prototyped_enclosed_by_function") +{ + dfg(R"( + local f + function f() + if cond() then + f() + end + end + )"); + + DefId f1 = graph->getDef(query(module)->vars.data[0]); + DefId f2 = getDef(); // function f() + DefId f3 = getDef(); // f() + + CHECK(f1 != f2); + CHECK(f2 != f3); + + const Phi* phi = get(f3); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 1); + CHECK(phi->operands.at(0) == f2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "local_f_which_is_prototyped_enclosed_by_function_has_some_prior_versions") +{ + dfg(R"( + local f + f = 5 + function f() + if cond() then + f() + end + end + )"); + + DefId f1 = graph->getDef(query(module)->vars.data[0]); + DefId f2 = getDef(); // f = 5 + DefId f3 = getDef(); // function f() + DefId f4 = getDef(); // f() + + CHECK(f1 != f2); + CHECK(f2 != f3); + CHECK(f3 != f4); + + const Phi* phi = get(f4); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 1); + CHECK(phi->operands.at(0) == f3); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "local_f_which_is_prototyped_enclosed_by_function_has_some_future_versions") +{ + dfg(R"( + local f + function f() + if cond() then + f() + end + end + f = 5 + )"); + + DefId f1 = graph->getDef(query(module)->vars.data[0]); + DefId f2 = getDef(); // function f() + DefId f3 = getDef(); // f() + DefId f4 = getDef(); // f = 5 + + CHECK(f1 != f2); + CHECK(f2 != f3); + CHECK(f3 != f4); + + const Phi* phi = get(f3); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); + CHECK(phi->operands.at(0) == f2); + CHECK(phi->operands.at(1) == f4); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "phi_node_if_case_binding") +{ + dfg(R"( +local x = nil +if true then + if true then + x = 5 + end + print(x) +else + print(x) +end +)"); + DefId x1 = graph->getDef(query(module)->vars.data[0]); + DefId x2 = getDef(); // x = 5 + DefId x3 = getDef(); // print(x) + + const Phi* phi = get(x3); + REQUIRE(phi); + CHECK(phi->operands.at(0) == x2); + CHECK(phi->operands.at(1) == x1); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "phi_node_if_case_table_prop") +{ + dfg(R"( +local t = {} +t.x = true +if true then + if true then + t.x = 5 + end + print(t.x) +else + print(t.x) +end +)"); + + DefId x1 = getDef(); // t.x = true + DefId x2 = getDef(); // t.x = 5 + + DefId x3 = getDef(); // print(t.x) + const Phi* phi = get(x3); + REQUIRE(phi); + CHECK(phi->operands.size() == 2); + CHECK(phi->operands.at(0) == x1); + CHECK(phi->operands.at(1) == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "phi_node_if_case_table_prop_literal") +{ + dfg(R"( +local t = { x = true } +if true then + t.x = 5 +end +print(t.x) + +)"); + + DefId x1 = getDef(); // {x = true <- } + DefId x2 = getDef(); // t.x = 5 + DefId x3 = getDef(); // print(t.x) + const Phi* phi = get(x3); + REQUIRE(phi); + CHECK(phi->operands.size() == 2); + CHECK(phi->operands.at(0) == x1); + CHECK(phi->operands.at(1) == x2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "insert_trivial_phi_nodes_inside_of_phi_nodes") +{ + dfg(R"( + local t = {} + + local function f(k: string) + if t[k] ~= nil then + return + end + + t[k] = 5 + end + )"); + + DefId t1 = graph->getDef(query(module)->vars.data[0]); // local t = {} + DefId t2 = getDef(); // t[k] ~= nil + DefId t3 = getDef(); // t[k] = 5 + + CHECK(t1 != t2); + CHECK(t2 == t3); + + const Phi* t2phi = get(t2); + REQUIRE(t2phi); + CHECK(t2phi->operands.size() == 1); + CHECK(t2phi->operands.at(0) == t1); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "dfg_function_definition_in_a_do_block") +{ + dfg(R"( + local f + do + function f() + end + end + f() + )"); + + DefId x1 = graph->getDef(query(module)->vars.data[0]); + DefId x2 = getDef(); // x = 5 + DefId x3 = getDef(); // print(x) + + CHECK(x1 != x2); + CHECK(x1 != x3); + CHECK(x2 == x3); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "dfg_captured_local_is_assigned_a_function") +{ + dfg(R"( + local f + + local function g() + f() + end + + function f() + end + )"); + + DefId f1 = graph->getDef(query(module)->vars.data[0]); + DefId f2 = getDef(); + DefId f3 = getDef(); + + CHECK(f1 != f2); + CHECK(f2 != f3); + + const Phi* f2phi = get(f2); + REQUIRE(f2phi); + CHECK(f2phi->operands.size() == 1); + CHECK(f2phi->operands.at(0) == f3); +} + TEST_SUITE_END(); diff --git a/tests/DiffAsserts.cpp b/tests/DiffAsserts.cpp new file mode 100644 index 000000000..f343367f9 --- /dev/null +++ b/tests/DiffAsserts.cpp @@ -0,0 +1,31 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "DiffAsserts.h" + +#include + +namespace Luau +{ + + +std::string toString(const DifferResult& result) +{ + if (result.diffError) + return result.diffError->toString(); + else + return ""; +} + +template<> +std::string diff(TypeId l, TypeId r) +{ + return toString(diff(l, r)); +} + +template<> +std::string diff(const Type& l, const Type& r) +{ + return toString(diff(&l, &r)); +} + +} // namespace Luau diff --git a/tests/DiffAsserts.h b/tests/DiffAsserts.h new file mode 100644 index 000000000..b80ea3121 --- /dev/null +++ b/tests/DiffAsserts.h @@ -0,0 +1,46 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Differ.h" +#include "Luau/TypeFwd.h" + +#include "doctest.h" + +#include +#include + +namespace Luau +{ + +std::string toString(const DifferResult& result); + +template +std::string diff(L, R) +{ + return ""; +} + +template<> +std::string diff(TypeId l, TypeId r); + +template<> +std::string diff(const Type& l, const Type& r); + +} // namespace Luau + +// Note: the do-while blocks in the macros below is to scope the INFO block to +// only that assertion. + +#define CHECK_EQ_DIFF(l, r) \ + do \ + { \ + INFO("Left and right values were not equal: ", diff(l, r)); \ + CHECK_EQ(l, r); \ + } while (false); + +#define REQUIRE_EQ_DIFF(l, r) \ + do \ + { \ + INFO("Left and right values were not equal: ", diff(l, r)); \ + REQUIRE_EQ(l, r); \ + } while (false); diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp new file mode 100644 index 000000000..a2b2280bd --- /dev/null +++ b/tests/Differ.test.cpp @@ -0,0 +1,1763 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Differ.h" +#include "Luau/Common.h" +#include "Luau/Error.h" +#include "Luau/Frontend.h" + +#include "Fixture.h" +#include "ClassFixture.h" + +#include "Luau/Symbol.h" +#include "Luau/Type.h" +#include "ScopedFlags.h" +#include "doctest.h" +#include + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) + +TEST_SUITE_BEGIN("Differ"); + +TEST_CASE_FIXTURE(DifferFixture, "equal_numbers") +{ + CheckResult result = check(R"( + local foo = 5 + local almostFoo = 78 + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_strings") +{ + CheckResult result = check(R"( + local foo = "hello" + local almostFoo = "world" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_tables") +{ + CheckResult result = check(R"( + local foo = { x = 1, y = "where" } + local almostFoo = { x = 5, y = "when" } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "a_table_missing_property") +{ + CheckResult result = check(R"( + local foo = { x = 1, y = 2 } + local almostFoo = { x = 1, z = 3 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + "DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo is missing " + "the property y" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "left_table_missing_property") +{ + CheckResult result = check(R"( + local foo = { x = 1 } + local almostFoo = { x = 1, z = 3 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + "DiffError: these two types are not equal because the left type at foo is missing the property z, while the right type at almostFoo.z " + "has type number" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "a_table_wrong_type") +{ + CheckResult result = check(R"( + local foo = { x = 1, y = 2 } + local almostFoo = { x = 1, y = "two" } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + "DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo.y has type " + "string" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "a_table_wrong_type") +{ + CheckResult result = check(R"( + local foo: string + local almostFoo: number + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + "DiffError: these two types are not equal because the left type at has type string, while the right type at " + " has type number" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "a_nested_table_wrong_type") +{ + CheckResult result = check(R"( + local foo = { x = 1, inner = { table = { has = { wrong = { value = 5 } } } } } + local almostFoo = { x = 1, inner = { table = { has = { wrong = { value = "five" } } } } } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + "DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.value has type number, while the right " + "type at almostFoo.inner.table.has.wrong.value has type string" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "a_nested_table_wrong_match") +{ + CheckResult result = check(R"( + local foo = { x = 1, inner = { table = { has = { wrong = { variant = { because = { it = { goes = { on = "five" } } } } } } } } } + local almostFoo = { x = 1, inner = { table = { has = { wrong = { variant = "five" } } } } } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + "DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.variant has type { because: { it: { goes: " + "{ on: string } } } }, while the right type at almostFoo.inner.table.has.wrong.variant has type string" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "left_cyclic_table_right_table_missing_property") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { x = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .foo has type t1 where t1 = { foo: t1 }, while the right type at almostFoo is missing the property foo)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "left_cyclic_table_right_table_property_wrong") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { foo = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .foo has type t1 where t1 = { foo: t1 }, while the right type at almostFoo.foo has type number)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "right_cyclic_table_left_table_missing_property") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { x = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "almostFoo", + "foo", + R"(DiffError: these two types are not equal because the left type at almostFoo.x has type number, while the right type at is missing the property x)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "right_cyclic_table_left_table_property_wrong") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { foo = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "almostFoo", + "foo", + R"(DiffError: these two types are not equal because the left type at almostFoo.foo has type number, while the right type at .foo has type t1 where t1 = { foo: t1 })" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_two_cyclic_tables_are_not_different") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = id({}) + almostFoo.foo = almostFoo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_two_shifted_circles_are_not_different") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo + + local builder = id({}) + builder.foo = id({}) + builder.foo.foo = id({}) + builder.foo.foo.foo = id({}) + builder.foo.foo.foo.foo = builder + -- Shift + local almostFoo = builder.foo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "table_left_circle_right_measuring_tape") +{ + // Left is a circle, right is a measuring tape + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.bar = id({}) -- anchor to pin shape + foo.foo.foo.foo.foo = foo + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.bar = id({}) -- anchor to pin shape + almostFoo.foo.foo.foo.foo = almostFoo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .foo.foo.foo.foo.foo is missing the property bar, while the right type at .foo.foo.foo.foo.foo.bar has type { })" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_measuring_tapes") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo.foo + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.foo.foo = almostFoo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_A_B_C") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo.foo + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.foo.foo = almostFoo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_A") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo.right + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo.right + + -- Bindings for requireType + local fooLeft = foo.left + local fooRight = foo.left.right + local fooLeftLeft = foo.left.left + local fooLeftRight = foo.left.right + local fooRightLeft = foo.right.left + local fooRightRight = foo.right.right + local fooRightLeftLeft = foo.right.left.left + local fooRightLeftRight = foo.right.left.right + + local almostFooLeft = almostFoo.left + local almostFooRight = almostFoo.left.right + local almostFooLeftLeft = almostFoo.left.left + local almostFooLeftRight = almostFoo.left.right + local almostFooRightLeft = almostFoo.right.left + local almostFooRightRight = almostFoo.right.right + local almostFooRightLeftLeft = almostFoo.right.left.left + local almostFooRightLeftRight = almostFoo.right.left.right + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_B") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo.left + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo.left + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_C") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_D") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo.right.left.left + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo.right.left.left + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_cyclic_diamonds_unraveled") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + -- Pattern 1 + local foo = id({}) + foo.child = id({}) + foo.child.left = id({}) + foo.child.right = id({}) + + foo.child.left.child = foo + foo.child.right.child = foo + + -- Pattern 2 + local almostFoo = id({}) + almostFoo.child = id({}) + almostFoo.child.left = id({}) + almostFoo.child.right = id({}) + + almostFoo.child.left.child = id({}) -- Use a new table + almostFoo.child.right.child = almostFoo.child.left.child -- Refer to the same new table + + almostFoo.child.left.child.child = id({}) + almostFoo.child.left.child.child.left = id({}) + almostFoo.child.left.child.child.right = id({}) + + almostFoo.child.left.child.child.left.child = almostFoo.child.left.child + almostFoo.child.left.child.child.right.child = almostFoo.child.left.child + + -- Pattern 3 + local anotherFoo = id({}) + anotherFoo.child = id({}) + anotherFoo.child.left = id({}) + anotherFoo.child.right = id({}) + + anotherFoo.child.left.child = id({}) -- Use a new table + anotherFoo.child.right.child = id({}) -- Use another new table + + anotherFoo.child.left.child.child = id({}) + anotherFoo.child.left.child.child.left = id({}) + anotherFoo.child.left.child.child.right = id({}) + anotherFoo.child.right.child.child = id({}) + anotherFoo.child.right.child.child.left = id({}) + anotherFoo.child.right.child.child.right = id({}) + + anotherFoo.child.left.child.child.left.child = anotherFoo.child.left.child + anotherFoo.child.left.child.child.right.child = anotherFoo.child.left.child + anotherFoo.child.right.child.child.left.child = anotherFoo.child.right.child + anotherFoo.child.right.child.child.right.child = anotherFoo.child.right.child + + -- Pattern 4 + local cleverFoo = id({}) + cleverFoo.child = id({}) + cleverFoo.child.left = id({}) + cleverFoo.child.right = id({}) + + cleverFoo.child.left.child = id({}) -- Use a new table + cleverFoo.child.right.child = id({}) -- Use another new table + + cleverFoo.child.left.child.child = id({}) + cleverFoo.child.left.child.child.left = id({}) + cleverFoo.child.left.child.child.right = id({}) + cleverFoo.child.right.child.child = id({}) + cleverFoo.child.right.child.child.left = id({}) + cleverFoo.child.right.child.child.right = id({}) + -- Same as pattern 3, but swapped here + cleverFoo.child.left.child.child.left.child = cleverFoo.child.right.child -- Swap + cleverFoo.child.left.child.child.right.child = cleverFoo.child.right.child + cleverFoo.child.right.child.child.left.child = cleverFoo.child.left.child + cleverFoo.child.right.child.child.right.child = cleverFoo.child.left.child + + -- Pattern 5 + local cheekyFoo = id({}) + cheekyFoo.child = id({}) + cheekyFoo.child.left = id({}) + cheekyFoo.child.right = id({}) + + cheekyFoo.child.left.child = foo -- Use existing pattern + cheekyFoo.child.right.child = foo -- Use existing pattern + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::vector symbols{"foo", "almostFoo", "anotherFoo", "cleverFoo", "cheekyFoo"}; + + for (auto left : symbols) + { + for (auto right : symbols) + { + compareTypesEq(left, right); + } + } +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function_cyclic") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo() + return foo + end + function almostFoo() + function bar() + return bar + end + return bar + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function_table_cyclic") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo() + return { + bar = foo + } + end + function almostFoo() + function bar() + return { + bar = bar + } + end + return { + bar = bar + } + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_table_self_referential_cyclic") +{ + // Old solver does not correctly infer function typepacks + // ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo() + return { + bar = foo + } + end + function almostFoo() + function bar() + return bar + end + return { + bar = bar + } + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Ret[1].bar.Ret[1] has type t1 where t1 = { bar: () -> t1 }, while the right type at .Ret[1].bar.Ret[1] has type t1 where t1 = () -> t1)" + ); + else + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Ret[1].bar.Ret[1] has type t1 where t1 = {| bar: () -> t1 |}, while the right type at .Ret[1].bar.Ret[1] has type t1 where t1 = () -> t1)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_union_cyclic") +{ + TypeArena arena; + TypeId number = arena.addType(PrimitiveType{PrimitiveType::Number}); + TypeId string = arena.addType(PrimitiveType{PrimitiveType::String}); + + TypeId foo = arena.addType(UnionType{std::vector{number, string}}); + UnionType* unionFoo = getMutable(foo); + unionFoo->options.push_back(foo); + + TypeId almostFoo = arena.addType(UnionType{std::vector{number, string}}); + UnionType* unionAlmostFoo = getMutable(almostFoo); + unionAlmostFoo->options.push_back(almostFoo); + + compareEq(foo, almostFoo); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_intersection_cyclic") +{ + // Old solver does not correctly refine test types + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo1(x: number) + return x + end + function foo2(x: string) + return 0 + end + function bar1(x: number) + return x + end + function bar2(x: string) + return 0 + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId foo1 = requireType("foo1"); + TypeId foo2 = requireType("foo2"); + TypeId bar1 = requireType("bar1"); + TypeId bar2 = requireType("bar2"); + + TypeArena arena; + + TypeId foo = arena.addType(IntersectionType{std::vector{foo1, foo2}}); + IntersectionType* intersectionFoo = getMutable(foo); + intersectionFoo->parts.push_back(foo); + + TypeId almostFoo = arena.addType(IntersectionType{std::vector{bar1, bar2}}); + IntersectionType* intersectionAlmostFoo = getMutable(almostFoo); + intersectionAlmostFoo->parts.push_back(almostFoo); + + compareEq(foo, almostFoo); +} + +TEST_CASE_FIXTURE(DifferFixture, "singleton") +{ + CheckResult result = check(R"( + local foo: "hello" = "hello" + local almostFoo: true = true + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type true)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_singleton") +{ + CheckResult result = check(R"( + local foo: "hello" = "hello" + local almostFoo: "hello" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "singleton_string") +{ + CheckResult result = check(R"( + local foo: "hello" = "hello" + local almostFoo: "world" = "world" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type "hello", while the right type at has type "world")" + ); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "negation") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local bar: { x: { y: unknown }} + local almostBar: { x: { y: unknown }} + + local foo + local almostFoo + + if typeof(bar.x.y) ~= "string" then + foo = bar + end + + if typeof(almostBar.x.y) ~= "number" then + almostFoo = almostBar + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union containing type { x: { y: ~string } }, while the right type at is a union missing type { x: { y: ~string } })" + ); + + // TODO: a more desirable expected error here is as below, but `Differ` requires improvements to + // dealing with unions to get something like this (recognizing that the union is identical + // except in one component where they differ). + // + // compareTypesNe("foo", "almostFoo", + // R"(DiffError: these two types are not equal because the left type at .x.y.Negation has type string, while the right type + // at .x.y.Negation has type number)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "union_missing_right") +{ + CheckResult result = check(R"( + local foo: string | number + local almostFoo: boolean | string + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union containing type number, while the right type at is a union missing type number)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "union_missing_left") +{ + CheckResult result = check(R"( + local foo: string | number + local almostFoo: boolean | string | number + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union missing type boolean, while the right type at is a union containing type boolean)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "union_missing") +{ + // TODO: this test case produces an error message that is not the most UX-friendly + + CheckResult result = check(R"( + local foo: { bar: number, pan: string } | { baz: boolean, rot: "singleton" } + local almostFoo: { bar: number, pan: string } | { baz: string, rot: "singleton" } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union containing type { baz: boolean, rot: "singleton" }, while the right type at is a union missing type { baz: boolean, rot: "singleton" })" + ); + else + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is a union containing type {| baz: boolean, rot: "singleton" |}, while the right type at is a union missing type {| baz: boolean, rot: "singleton" |})" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "intersection_missing_right") +{ + CheckResult result = check(R"( + local foo: (number) -> () & (string) -> () + local almostFoo: (string) -> () & (boolean) -> () + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection containing type (number) -> (), while the right type at is an intersection missing type (number) -> ())" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "intersection_missing_left") +{ + CheckResult result = check(R"( + local foo: (number) -> () & (string) -> () + local almostFoo: (string) -> () & (boolean) -> () & (number) -> () + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection missing type (boolean) -> (), while the right type at is an intersection containing type (boolean) -> ())" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "intersection_tables_missing_right") +{ + CheckResult result = check(R"( + local foo: { x: number } & { y: string } + local almostFoo: { y: string } & { z: boolean } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection containing type { x: number }, while the right type at is an intersection missing type { x: number })" + ); + else + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection containing type {| x: number |}, while the right type at is an intersection missing type {| x: number |})" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "intersection_tables_missing_left") +{ + CheckResult result = check(R"( + local foo: { x: number } & { y: string } + local almostFoo: { y: string } & { z: boolean } & { x: number } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection missing type { z: boolean }, while the right type at is an intersection containing type { z: boolean })" + ); + else + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at is an intersection missing type {| z: boolean |}, while the right type at is an intersection containing type {| z: boolean |})" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number) + return x + end + function almostFoo(y: number) + return y + 10 + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function_inferred_ret_length") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function bar(x: number, y: string) + return x, y + end + function almostBar(a: number, b: string) + return a, b + end + function foo(x: number, y: string, z: boolean) + return z, bar(x, y) + end + function almostFoo(a: number, b: string, c: boolean) + return c, almostBar(a, b) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function_inferred_ret_length_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function bar(x: number, y: string) + return x, y + end + function foo(x: number, y: string, z: boolean) + return bar(x, y), z + end + function almostFoo(a: number, b: string, c: boolean) + return a, c + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_arg_normal") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: number, z: number) + return x * y * z + end + function almostFoo(a: number, b: number, msg: string) + return a + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[3] has type number, while the right type at .Arg[3] has type string)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_arg_normal_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: number, z: string) + return x * y + end + function almostFoo(a: number, y: string, msg: string) + return a + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[2] has type number, while the right type at .Arg[2] has type string)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_ret_normal") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: number, z: string) + return x + end + function almostFoo(a: number, b: number, msg: string) + return msg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Ret[1] has type number, while the right type at .Ret[1] has type string)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: number) + return x + end + function almostFoo(x: number, y: number, c: number) + return x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 3 or more arguments)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, z: number) + return z + end + function almostFoo(x: number, y: string) + return x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 3 or more arguments, while the right type at takes 2 or more arguments)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length_none") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo() + return 5 + end + function almostFoo(x: number, y: string) + return x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 0 or more arguments, while the right type at takes 2 or more arguments)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_arg_length_none_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number) + return x + end + function almostFoo() + return 5 + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 1 or more arguments, while the right type at takes 0 or more arguments)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: number) + return x + end + function almostFoo(x: number, y: number) + return x, y + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 2 values)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, z: number) + return y, x, z + end + function almostFoo(x: number, y: string, z: number) + return y, x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 3 values, while the right type at returns 2 values)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length_none") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: string) + return + end + function almostFoo(x: number, y: string) + return x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 0 values, while the right type at returns 1 values)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_ret_length_none_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo() + return 5 + end + function almostFoo() + return + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at returns 1 values, while the right type at returns 0 values)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_arg_normal") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, ...: number) + return x, y + end + function almostFoo(a: number, b: string, ...: string) + return a, b + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type string)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_arg_missing") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: string, ...: number) + return x, y + end + function almostFoo(a: number, b: string) + return a, b + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type number, while the right type at .Arg[Variadic] has type any)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_arg_missing_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: number, y: string) + return x, y + end + function almostFoo(a: number, b: string, ...: string) + return a, b + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Arg[Variadic] has type any, while the right type at .Arg[Variadic] has type string)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_oversaturation") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + -- allowed to be oversaturated + function foo(x: number, y: string) + return x, y + end + -- must not be oversaturated + local almostFoo: (number, string) -> (number, string) = foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 2 or more arguments, while the right type at takes 2 arguments)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_variadic_oversaturation_2") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + -- must not be oversaturated + local foo: (number, string) -> (number, string) + -- allowed to be oversaturated + function almostFoo(x: number, y: string) + return x, y + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at takes 2 arguments, while the right type at takes 2 or more arguments)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "generic") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x, y) + return x, y + end + function almostFoo(x, y) + return y, x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Ret[1] cannot be the same type parameter as the right generic at .Ret[1])" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "generic_one_vs_two") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: X, y: X) + return + end + function almostFoo(x: T, y: U) + return + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[2] cannot be the same type parameter as the right generic at .Arg[2])" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "generic_three_or_three") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function foo(x: X, y: X, z: Y) + return + end + function almostFoo(x: T, y: U, z: U) + return + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[2] cannot be the same type parameter as the right generic at .Arg[2])" + ); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "equal_metatable") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = 5 + } + local metaAlmostFoo = { + metaBar = 1 + } + local foo = { + bar = 3 + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = 4 + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_normal") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + local metaFoo = { + metaBar = 5 + } + local metaAlmostFoo = { + metaBar = 1 + } + local foo = { + bar = 3 + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .bar has type number, while the right type at .bar has type string)" + ); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_metanormal") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = "world" + } + local metaAlmostFoo = { + metaBar = 1 + } + local foo = { + bar = "amazing" + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .__metatable.metaBar has type string, while the right type at .__metatable.metaBar has type number)" + ); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_metamissing_left") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = "world" + } + local metaAlmostFoo = { + metaBar = 1, + thisIsOnlyInRight = 2, + } + local foo = { + bar = "amazing" + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .__metatable is missing the property thisIsOnlyInRight, while the right type at .__metatable.thisIsOnlyInRight has type number)" + ); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_metamissing_right") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = "world", + thisIsOnlyInLeft = 2, + } + local metaAlmostFoo = { + metaBar = 1, + } + local foo = { + bar = "amazing" + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at .__metatable.thisIsOnlyInLeft has type number, while the right type at .__metatable is missing the property thisIsOnlyInLeft)" + ); +} + +TEST_CASE_FIXTURE(DifferFixtureGeneric, "equal_class") +{ + CheckResult result = check(R"( + local foo = BaseClass + local almostFoo = BaseClass + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixtureGeneric, "class_normal") +{ + CheckResult result = check(R"( + local foo = BaseClass + local almostFoo = ChildClass + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type BaseClass, while the right type at has type ChildClass)" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_generictp") +{ + CheckResult result = check(R"( + local foo: () -> T... + local almostFoo: () -> U... + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "generictp_ne_fn") +{ + CheckResult result = check(R"( + local foo: (...T) -> U... + local almostFoo: (U...) -> U... + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type (...T) -> (U...), while the right type at has type (U...) -> (U...))" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "generictp_normal") +{ + CheckResult result = check(R"( + -- trN should be X... -> Y... + -- s should be X -> Y... + -- x should be X + -- bij should be X... -> X... + + -- Intended signature: (X... -> Y..., Z -> X..., X... -> Y..., Z, Y... -> Y...) -> () + function foo(tr, s, tr2, x, bij) + bij(bij(tr(s(x)))) + bij(bij(tr2(s(x)))) + end + -- Intended signature: (X... -> X..., Z -> X..., X... -> Y..., Z, Y... -> Y...) -> () + function almostFoo(bij, s, tr, x, bij2) + bij(bij(s(x))) + bij2(bij2(tr(s(x)))) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[1].Ret[Variadic] cannot be the same type parameter as the right generic at .Arg[1].Ret[Variadic])" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "generictp_normal_2") +{ + CheckResult result = check(R"( + -- trN should be X... -> Y... + -- s should be X -> Y... + -- x should be X + -- bij should be X... -> X... + + function foo(s, tr, tr2, x, bij) + bij(bij(tr(s(x)))) + bij(bij(tr2(s(x)))) + end + function almostFoo(s, bij, tr, x, bij2) + bij2(bij2(bij(bij(tr(s(x)))))) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[2].Arg[Variadic] cannot be the same type parameter as the right generic at .Arg[2].Arg[Variadic])" + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_generictp_cyclic") +{ + CheckResult result = check(R"( + function foo(f, g, s, x) + f(f(g(g(s(x))))) + return foo + end + function almostFoo(f, g, s, x) + g(g(f(f(s(x))))) + return almostFoo + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "symbol_forward") +{ + CheckResult result = check(R"( + local foo = 5 + local almostFoo = "five" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at foo has type number, while the right type at almostFoo has type string)", + true + ); +} + +TEST_CASE_FIXTURE(DifferFixture, "newlines") +{ + CheckResult result = check(R"( + local foo = 5 + local almostFoo = "five" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe( + "foo", + "almostFoo", + R"(DiffError: these two types are not equal because the left type at + foo +has type + number, +while the right type at + almostFoo +has type + string)", + true, + true + ); +} + +TEST_SUITE_END(); diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp new file mode 100644 index 000000000..282d4ad2a --- /dev/null +++ b/tests/EqSat.language.test.cpp @@ -0,0 +1,144 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/Id.h" +#include "Luau/Language.h" + +#include +#include + +LUAU_EQSAT_ATOM(I32, int); +LUAU_EQSAT_ATOM(Bool, bool); +LUAU_EQSAT_ATOM(Str, std::string); + +LUAU_EQSAT_FIELD(Left); +LUAU_EQSAT_FIELD(Right); +LUAU_EQSAT_NODE_FIELDS(Add, Left, Right); + +using namespace Luau; + +using Value = EqSat::Language; + +TEST_SUITE_BEGIN("EqSatLanguage"); + +TEST_CASE("atom_equality") +{ + CHECK(I32{0} == I32{0}); + CHECK(I32{0} != I32{1}); +} + +TEST_CASE("node_equality") +{ + CHECK(Add{EqSat::Id{0}, EqSat::Id{0}} == Add{EqSat::Id{0}, EqSat::Id{0}}); + CHECK(Add{EqSat::Id{1}, EqSat::Id{0}} != Add{EqSat::Id{0}, EqSat::Id{0}}); +} + +TEST_CASE("language_get") +{ + Value v{I32{5}}; + + auto i = v.get(); + REQUIRE(i); + CHECK(i->value()); + + CHECK(!v.get()); +} + +TEST_CASE("language_copy_ctor") +{ + Value v1{I32{5}}; + Value v2 = v1; + + auto i1 = v1.get(); + auto i2 = v2.get(); + REQUIRE(i1); + REQUIRE(i2); + CHECK(i1->value() == i2->value()); +} + +TEST_CASE("language_move_ctor") +{ + Value v1{Str{"hello"}}; + { + auto s1 = v1.get(); + REQUIRE(s1); + CHECK(s1->value() == "hello"); + } + + Value v2 = std::move(v1); + + auto s1 = v1.get(); + REQUIRE(s1); + CHECK(s1->value() == ""); // this also tests the dtor. + + auto s2 = v2.get(); + REQUIRE(s2); + CHECK(s2->value() == "hello"); +} + +TEST_CASE("language_equality") +{ + Value v1{I32{0}}; + Value v2{I32{0}}; + Value v3{I32{1}}; + Value v4{Bool{true}}; + Value v5{Add{EqSat::Id{0}, EqSat::Id{1}}}; + + CHECK(v1 == v2); + CHECK(v2 != v3); + CHECK(v3 != v4); + CHECK(v4 != v5); +} + +TEST_CASE("language_is_mappable") +{ + std::unordered_map map; + + Value v1{I32{5}}; + Value v2{I32{5}}; + Value v3{Bool{true}}; + Value v4{Add{EqSat::Id{0}, EqSat::Id{1}}}; + + map[v1] = 1; + map[v2] = 2; + map[v3] = 42; + map[v4] = 37; + + CHECK(map[v1] == 2); + CHECK(map[v2] == 2); + CHECK(map[v3] == 42); + CHECK(map[v4] == 37); +} + +TEST_CASE("node_field") +{ + EqSat::Id left{0}; + EqSat::Id right{1}; + + Add add{left, right}; + + EqSat::Id left2 = add.field(); + EqSat::Id right2 = add.field(); + + CHECK(left == left2); + CHECK(left != right2); + CHECK(right == right2); + CHECK(right != left2); +} + +TEST_CASE("language_operands") +{ + Value v1{I32{0}}; + CHECK(v1.operands().empty()); + + Value v2{Add{EqSat::Id{0}, EqSat::Id{1}}}; + const Add* add = v2.get(); + REQUIRE(add); + + EqSat::Slice actual = v2.operands(); + CHECK(actual.size() == 2); + CHECK(actual[0] == add->field()); + CHECK(actual[1] == add->field()); +} + +TEST_SUITE_END(); diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp new file mode 100644 index 000000000..679477ced --- /dev/null +++ b/tests/EqSat.propositional.test.cpp @@ -0,0 +1,198 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" + +#include +#include + +LUAU_EQSAT_ATOM(Var, std::string); +LUAU_EQSAT_ATOM(Bool, bool); +LUAU_EQSAT_NODE_ARRAY(Not, 1); +LUAU_EQSAT_NODE_ARRAY(And, 2); +LUAU_EQSAT_NODE_ARRAY(Or, 2); +LUAU_EQSAT_NODE_ARRAY(Implies, 2); + +using namespace Luau; + +using PropositionalLogic = EqSat::Language; + +using EGraph = EqSat::EGraph; + +struct ConstantFold +{ + using Data = std::optional; + + Data make(const EGraph& egraph, const Var& var) const + { + return std::nullopt; + } + + Data make(const EGraph& egraph, const Bool& b) const + { + return b.value(); + } + + Data make(const EGraph& egraph, const Not& n) const + { + Data data = egraph[n[0]].data; + if (data) + return !*data; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const And& a) const + { + Data l = egraph[a[0]].data; + Data r = egraph[a[1]].data; + if (l && r) + return *l && *r; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const Or& o) const + { + Data l = egraph[o[0]].data; + Data r = egraph[o[1]].data; + if (l && r) + return *l || *r; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const Implies& i) const + { + Data antecedent = egraph[i[0]].data; + Data consequent = egraph[i[1]].data; + if (antecedent && consequent) + return !*antecedent || *consequent; + + return std::nullopt; + } + + void join(Data& a, const Data& b) const + { + if (!a && b) + a = b; + } +}; + +TEST_SUITE_BEGIN("EqSatPropositionalLogic"); + +TEST_CASE("egraph_hashconsing") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{true}); + EqSat::Id id3 = egraph.add(Bool{false}); + + CHECK(id1 == id2); + CHECK(id2 != id3); +} + +TEST_CASE("egraph_data") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{false}); + + CHECK(egraph[id1].data == true); + CHECK(egraph[id2].data == false); +} + +TEST_CASE("egraph_merge") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Var{"a"}); + EqSat::Id id2 = egraph.add(Bool{true}); + egraph.merge(id1, id2); + + CHECK(egraph[id1].data == true); + CHECK(egraph[id2].data == true); +} + +TEST_CASE("const_fold_true_and_true") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{true}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == true); +} + +TEST_CASE("const_fold_true_and_false") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{false}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == false); +} + +TEST_CASE("const_fold_false_and_false") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{false}); + EqSat::Id id2 = egraph.add(Bool{false}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == false); +} + +TEST_CASE("implications") +{ + EGraph egraph; + + EqSat::Id t = egraph.add(Bool{true}); + EqSat::Id f = egraph.add(Bool{false}); + + EqSat::Id a = egraph.add(Implies{t, t}); // true + EqSat::Id b = egraph.add(Implies{t, f}); // false + EqSat::Id c = egraph.add(Implies{f, t}); // true + EqSat::Id d = egraph.add(Implies{f, f}); // true + + CHECK(egraph[a].data == true); + CHECK(egraph[b].data == false); + CHECK(egraph[c].data == true); + CHECK(egraph[d].data == true); +} + +TEST_CASE("merge_x_and_y") +{ + EGraph egraph; + + EqSat::Id x = egraph.add(Var{"x"}); + EqSat::Id y = egraph.add(Var{"y"}); + + EqSat::Id a = egraph.add(Var{"a"}); + EqSat::Id ax = egraph.add(And{a, x}); + EqSat::Id ay = egraph.add(And{a, y}); + + egraph.merge(x, y); // [x y] [ax] [ay] [a] + CHECK_EQ(egraph.size(), 4); + CHECK_EQ(egraph.find(x), egraph.find(y)); + CHECK_NE(egraph.find(ax), egraph.find(ay)); + CHECK_NE(egraph.find(a), egraph.find(x)); + CHECK_NE(egraph.find(a), egraph.find(y)); + + egraph.rebuild(); // [x y] [ax ay] [a] + CHECK_EQ(egraph.size(), 3); + CHECK_EQ(egraph.find(x), egraph.find(y)); + CHECK_EQ(egraph.find(ax), egraph.find(ay)); + CHECK_NE(egraph.find(a), egraph.find(x)); + CHECK_NE(egraph.find(a), egraph.find(y)); +} + +TEST_SUITE_END(); diff --git a/tests/EqSat.slice.test.cpp b/tests/EqSat.slice.test.cpp new file mode 100644 index 000000000..26ca3bfd3 --- /dev/null +++ b/tests/EqSat.slice.test.cpp @@ -0,0 +1,58 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/Slice.h" + +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("EqSatSlice"); + +TEST_CASE("slice_is_a_view_over_array") +{ + std::array a{1, 2, 3, 4, 5, 6, 7, 8}; + + EqSat::Slice slice{a}; + + CHECK(slice.data() == a.data()); + CHECK(slice.size() == a.size()); + + for (size_t i = 0; i < a.size(); ++i) + { + CHECK(slice[i] == a[i]); + CHECK(&slice[i] == &a[i]); + } +} + +TEST_CASE("slice_is_a_view_over_vector") +{ + std::vector vector{1, 2, 3, 4, 5, 6, 7, 8}; + + EqSat::Slice slice{vector.data(), vector.size()}; + + CHECK(slice.data() == vector.data()); + CHECK(slice.size() == vector.size()); + + for (size_t i = 0; i < vector.size(); ++i) + { + CHECK(slice[i] == vector[i]); + CHECK(&slice[i] == &vector[i]); + } +} + +TEST_CASE("mutate_via_slice") +{ + std::array a{1, 2}; + CHECK(a[0] == 1); + CHECK(a[1] == 2); + + EqSat::Slice slice{a}; + slice[0] = 42; + slice[1] = 37; + + CHECK(a[0] == 42); + CHECK(a[1] == 37); +} + +TEST_SUITE_END(); diff --git a/tests/Error.test.cpp b/tests/Error.test.cpp index 5ba5c112e..15c317fca 100644 --- a/tests/Error.test.cpp +++ b/tests/Error.test.cpp @@ -1,10 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Error.h" +#include "Fixture.h" #include "doctest.h" using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2) + TEST_SUITE_BEGIN("ErrorTests"); TEST_CASE("TypeError_code_should_return_nonzero_code") @@ -13,4 +16,69 @@ TEST_CASE("TypeError_code_should_return_nonzero_code") CHECK_GE(e.code(), 1000); } +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_names_show_instead_of_tables") +{ + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( +--!strict +local Account = {} +Account.__index = Account +function Account.deposit(self: Account, x: number) + self.balance += x +end +type Account = typeof(setmetatable({} :: { balance: number }, Account)) +local x: Account = 5 +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'number' could not be converted into 'Account'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "binary_op_type_function_errors") +{ + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( + --!strict + local x = 1 + "foo" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauSolverV2) + CHECK_EQ( + "Operator '+' could not be applied to operands of types number and string; there is no corresponding overload for __add", + toString(result.errors[0]) + ); + else + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "unary_op_type_function_errors") +{ + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( + --!strict + local x = -"foo" + )"); + + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ( + "Operator '-' could not be applied to operand of type string; there is no corresponding overload for __unm", toString(result.errors[0]) + ); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[1])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + } +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index cbceabbdc..bf254f80c 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -3,24 +3,30 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" #include "Luau/Constraint.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" #include "Luau/Parser.h" #include "Luau/Type.h" #include "Luau/TypeAttach.h" +#include "Luau/TypeInfer.h" #include "Luau/Transpiler.h" #include "doctest.h" #include +#include #include #include #include +#include static const char* mainModuleName = "MainModule"; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) extern std::optional randomSeed; // tests/main.cpp @@ -113,7 +119,17 @@ std::optional TestFileResolver::resolveModule(const ModuleInfo* cont std::string TestFileResolver::getHumanReadableModuleName(const ModuleName& name) const { - return name; + // We have a handful of tests that need to distinguish between a canonical + // ModuleName and the human-readable version so we apply a simple transform + // here: We replace all slashes with dots. + std::string result = name; + for (size_t i = 0; i < result.size(); ++i) + { + if (result[i] == '/') + result[i] = '.'; + } + + return result; } std::optional TestFileResolver::getEnvironmentForModule(const ModuleName& name) const @@ -134,23 +150,38 @@ const Config& TestConfigResolver::getConfig(const ModuleName& name) const return defaultConfig; } -Fixture::Fixture(bool freeze, bool prepareAutocomplete) - : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) - , frontend(&fileResolver, &configResolver, - {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) - , typeChecker(frontend.typeChecker) +Fixture::Fixture(bool prepareAutocomplete) + : frontend( + &fileResolver, + &configResolver, + {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* runLintChecks */ false, /* randomConstraintResolutionSeed */ randomSeed} + ) , builtinTypes(frontend.builtinTypes) { configResolver.defaultConfig.mode = Mode::Strict; configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend); - - Luau::freeze(frontend.typeChecker.globalTypes); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globals.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); Luau::setPrintLine([](auto s) {}); + + if (FFlag::DebugLuauLogSolverToJsonFile) + { + frontend.writeJsonLog = [&](const Luau::ModuleName& moduleName, std::string log) + { + std::string path = moduleName + ".log.json"; + size_t pos = moduleName.find_last_of('/'); + if (pos != std::string::npos) + path = moduleName.substr(pos + 1); + + std::ofstream os(path); + + os << log << std::endl; + MESSAGE("Wrote JSON log to ", path); + }; + } } Fixture::~Fixture() @@ -174,15 +205,34 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars // if AST is available, check how lint and typecheck handle error nodes if (result.root) { - frontend.lint(*sourceModule); - - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - Luau::check(*sourceModule, {}, frontend.builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, - typeChecker.globalScope, frontend.options); + Mode mode = sourceModule->mode ? *sourceModule->mode : Mode::Strict; + ModulePtr module = Luau::check( + *sourceModule, + mode, + {}, + builtinTypes, + NotNull{&ice}, + NotNull{&moduleResolver}, + NotNull{&fileResolver}, + frontend.globals.globalScope, + /*prepareModuleScope*/ nullptr, + frontend.options, + {}, + false, + {} + ); + + Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } else - typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + { + TypeChecker typeChecker(frontend.globals.globalScope, &moduleResolver, builtinTypes, &frontend.iceHandler); + ModulePtr module = typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict), std::nullopt); + + Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); + } } throw ParseErrors(result.errors); @@ -191,7 +241,7 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars return result.root; } -CheckResult Fixture::check(Mode mode, std::string source) +CheckResult Fixture::check(Mode mode, const std::string& source) { ModuleName mm = fromString(mainModuleName); configResolver.defaultConfig.mode = mode; @@ -210,20 +260,23 @@ CheckResult Fixture::check(const std::string& source) LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) { - ParseOptions parseOptions; - parseOptions.captureComments = true; - configResolver.defaultConfig.mode = Mode::Nonstrict; - parse(source, parseOptions); + ModuleName mm = fromString(mainModuleName); + configResolver.defaultConfig.mode = Mode::Strict; + fileResolver.source[mm] = std::move(source); + frontend.markDirty(mm); - return frontend.lint(*sourceModule, lintOptions); + return lintModule(mm); } -LintResult Fixture::lintTyped(const std::string& source, const std::optional& lintOptions) +LintResult Fixture::lintModule(const ModuleName& moduleName, const std::optional& lintOptions) { - check(source); - ModuleName mm = fromString(mainModuleName); + FrontendOptions options = frontend.options; + options.runLintChecks = true; + options.enabledLintWarnings = lintOptions; - return frontend.lint(mm, lintOptions); + CheckResult result = frontend.check(moduleName, options); + + return result.lintResult; } ParseResult Fixture::parseEx(const std::string& source, const ParseOptions& options) @@ -319,7 +372,7 @@ std::optional Fixture::getType(const std::string& name) if (!module->hasModuleScope()) return std::nullopt; - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) return linearSearchForBinding(module->getModuleScope().get(), name.c_str()); else return lookupName(module->getModuleScope(), name); @@ -405,7 +458,18 @@ TypeId Fixture::requireTypeAlias(const std::string& name) { std::optional ty = lookupType(name); REQUIRE(ty); - return *ty; + return follow(*ty); +} + +TypeId Fixture::requireExportedType(const ModuleName& moduleName, const std::string& name) +{ + ModulePtr module = frontend.moduleResolver.getModule(moduleName); + REQUIRE(module); + + auto it = module->exportedTypeBindings.find(name); + REQUIRE(it != module->exportedTypeBindings.end()); + + return it->second.type; } std::string Fixture::decorateWithTypes(const std::string& code) @@ -447,9 +511,9 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) void Fixture::registerTestTypes() { - addGlobalBinding(frontend, "game", typeChecker.anyType, "@luau"); - addGlobalBinding(frontend, "workspace", typeChecker.anyType, "@luau"); - addGlobalBinding(frontend, "script", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@luau"); + addGlobalBinding(frontend.globals, "workspace", builtinTypes->anyType, "@luau"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@luau"); } void Fixture::dumpErrors(const CheckResult& cr) @@ -499,9 +563,10 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test"); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = + frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, source, "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); if (result.module) dumpErrors(result.module); @@ -509,19 +574,85 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) return result; } -BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) - : Fixture(freeze, prepareAutocomplete) +BuiltinsFixture::BuiltinsFixture(bool prepareAutocomplete) + : Fixture(prepareAutocomplete) { - Luau::unfreeze(frontend.typeChecker.globalTypes); - Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::unfreeze(frontend.globals.globalTypes); + Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); - registerBuiltinGlobals(frontend); + registerBuiltinGlobals(frontend, frontend.globals); if (prepareAutocomplete) - registerBuiltinGlobals(frontend.typeCheckerForAutocomplete); + registerBuiltinGlobals(frontend, frontend.globalsForAutocomplete, /*typeCheckForAutocomplete*/ true); registerTestTypes(); - Luau::freeze(frontend.typeChecker.globalTypes); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globals.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); +} + +static std::vector parsePathExpr(const AstExpr& pathExpr) +{ + const AstExprIndexName* indexName = pathExpr.as(); + if (!indexName) + return {}; + + std::vector segments{indexName->index.value}; + + while (true) + { + if (AstExprIndexName* in = indexName->expr->as()) + { + segments.push_back(in->index.value); + indexName = in; + continue; + } + else if (AstExprGlobal* indexNameAsGlobal = indexName->expr->as()) + { + segments.push_back(indexNameAsGlobal->name.value); + break; + } + else if (AstExprLocal* indexNameAsLocal = indexName->expr->as()) + { + segments.push_back(indexNameAsLocal->local->name.value); + break; + } + else + return {}; + } + + std::reverse(segments.begin(), segments.end()); + return segments; +} + +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments) +{ + if (segments.empty()) + return std::nullopt; + + std::vector result; + + auto it = segments.begin(); + + if (*it == "script" && !currentModuleName.empty()) + { + result = split(currentModuleName, '/'); + ++it; + } + + for (; it != segments.end(); ++it) + { + if (result.size() > 1 && *it == "Parent") + result.pop_back(); + else + result.push_back(*it); + } + + return join(result, "/"); +} + +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr) +{ + std::vector segments = parsePathExpr(pathExpr); + return pathExprToModuleName(currentModuleName, segments); } ModuleName fromString(std::string_view name) @@ -581,47 +712,59 @@ std::optional linearSearchForBinding(Scope* scope, const char* name) void registerHiddenTypes(Frontend* frontend) { - TypeId t = frontend->globalTypes.addType(GenericType{"T"}); + GlobalTypes& globals = frontend->globals; + + unfreeze(globals.globalTypes); + + TypeId t = globals.globalTypes.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; - ScopePtr globalScope = frontend->getGlobalScope(); - globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, frontend->globalTypes.addType(NegationType{t})}; + TypeId u = globals.globalTypes.addType(GenericType{"U"}); + GenericTypeDefinition genericU{u}; + + ScopePtr globalScope = globals.globalScope; + globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, globals.globalTypes.addType(NegationType{t})}; + globalScope->exportedTypeBindings["Mt"] = TypeFun{{genericT, genericU}, globals.globalTypes.addType(MetatableType{t, u})}; globalScope->exportedTypeBindings["fun"] = TypeFun{{}, frontend->builtinTypes->functionType}; globalScope->exportedTypeBindings["cls"] = TypeFun{{}, frontend->builtinTypes->classType}; globalScope->exportedTypeBindings["err"] = TypeFun{{}, frontend->builtinTypes->errorType}; globalScope->exportedTypeBindings["tbl"] = TypeFun{{}, frontend->builtinTypes->tableType}; + + freeze(globals.globalTypes); } void createSomeClasses(Frontend* frontend) { - TypeArena& arena = frontend->globalTypes; + GlobalTypes& globals = frontend->globals; + + TypeArena& arena = globals.globalTypes; unfreeze(arena); - ScopePtr moduleScope = frontend->getGlobalScope(); + ScopePtr moduleScope = globals.globalScope; - TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); + TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test", {}}); ClassType* parentClass = getMutable(parentType); parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; - addGlobalBinding(*frontend, "Parent", {parentType}); + addGlobalBinding(globals, "Parent", {parentType}); moduleScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; - TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test", {}}); - addGlobalBinding(*frontend, "Child", {childType}); + addGlobalBinding(globals, "Child", {childType}); moduleScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; - TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, parentType, std::nullopt, {}, nullptr, "Test", {}}); - addGlobalBinding(*frontend, "AnotherChild", {anotherChildType}); + addGlobalBinding(globals, "AnotherChild", {anotherChildType}); moduleScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildType}; - TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); + TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test", {}}); - addGlobalBinding(*frontend, "Unrelated", {unrelatedType}); + addGlobalBinding(globals, "Unrelated", {unrelatedType}); moduleScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; for (const auto& [name, ty] : moduleScope->exportedTypeBindings) diff --git a/tests/Fixture.h b/tests/Fixture.h index a81a5e783..4f2050c11 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -2,6 +2,8 @@ #pragma once #include "Luau/Config.h" +#include "Luau/Differ.h" +#include "Luau/Error.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" @@ -11,13 +13,19 @@ #include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/Type.h" +#include "Luau/TypeFunction.h" #include "IostreamOptional.h" #include "ScopedFlags.h" +#include "doctest.h" #include +#include #include #include +#include + +LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -57,16 +65,16 @@ struct TestConfigResolver : ConfigResolver struct Fixture { - explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); + explicit Fixture(bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {}); - CheckResult check(Mode mode, std::string source); + CheckResult check(Mode mode, const std::string& source); CheckResult check(const std::string& source); LintResult lint(const std::string& source, const std::optional& lintOptions = {}); - LintResult lintTyped(const std::string& source, const std::optional& lintOptions = {}); + LintResult lintModule(const ModuleName& moduleName, const std::optional& lintOptions = {}); /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); @@ -92,8 +100,16 @@ struct Fixture std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); TypeId requireTypeAlias(const std::string& name); + TypeId requireExportedType(const ModuleName& moduleName, const std::string& name); + + // While most flags can be flipped inside the unit test, some code changes affect the state that is part of Fixture initialization + // Most often those are changes related to builtin type definitions. + // In that case, flag can be forced to 'true' using the example below: + // ScopedFastFlag sff_LuauExampleFlagDefinition{FFlag::LuauExampleFlagDefinition, true}; - ScopedFastFlag sff_DebugLuauFreezeArena; + // Arena freezing marks the `TypeArena`'s underlying memory as read-only, raising an access violation whenever you mutate it. + // This is useful for tracking down violations of Luau's memory model. + ScopedFastFlag sff_DebugLuauFreezeArena{FFlag::DebugLuauFreezeArena, true}; TestFileResolver fileResolver; TestConfigResolver configResolver; @@ -101,7 +117,6 @@ struct Fixture std::unique_ptr sourceModule; Frontend frontend; InternalErrorReporter ice; - TypeChecker& typeChecker; NotNull builtinTypes; std::string decorateWithTypes(const std::string& code); @@ -123,9 +138,12 @@ struct Fixture struct BuiltinsFixture : Fixture { - BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); + BuiltinsFixture(bool prepareAutocomplete = false); }; +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments); +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr); + ModuleName fromString(std::string_view name); template @@ -154,6 +172,84 @@ std::optional linearSearchForBinding(Scope* scope, const char* name); void registerHiddenTypes(Frontend* frontend); void createSomeClasses(Frontend* frontend); +template +struct DifferFixtureGeneric : BaseFixture +{ + std::string normalizeWhitespace(std::string msg) + { + std::string normalizedMsg = ""; + bool wasWhitespace = true; + for (char c : msg) + { + bool isWhitespace = c == ' ' || c == '\n'; + if (wasWhitespace && isWhitespace) + continue; + normalizedMsg += isWhitespace ? ' ' : c; + wasWhitespace = isWhitespace; + } + if (wasWhitespace) + normalizedMsg.pop_back(); + return normalizedMsg; + } + + void compareNe(TypeId left, TypeId right, const std::string& expectedMessage, bool multiLine) + { + compareNe(left, std::nullopt, right, std::nullopt, expectedMessage, multiLine); + } + + void compareNe( + TypeId left, + std::optional symbolLeft, + TypeId right, + std::optional symbolRight, + const std::string& expectedMessage, + bool multiLine + ) + { + DifferResult diffRes = diffWithSymbols(left, right, symbolLeft, symbolRight); + REQUIRE_MESSAGE(diffRes.diffError.has_value(), "Differ did not report type error, even though types are unequal"); + std::string diffMessage = diffRes.diffError->toString(multiLine); + CHECK_EQ(expectedMessage, diffMessage); + } + + void compareTypesNe( + const std::string& leftSymbol, + const std::string& rightSymbol, + const std::string& expectedMessage, + bool forwardSymbol = false, + bool multiLine = false + ) + { + if (forwardSymbol) + { + compareNe( + BaseFixture::requireType(leftSymbol), leftSymbol, BaseFixture::requireType(rightSymbol), rightSymbol, expectedMessage, multiLine + ); + } + else + { + compareNe( + BaseFixture::requireType(leftSymbol), std::nullopt, BaseFixture::requireType(rightSymbol), std::nullopt, expectedMessage, multiLine + ); + } + } + + void compareEq(TypeId left, TypeId right) + { + DifferResult diffRes = diff(left, right); + CHECK(!diffRes.diffError); + if (diffRes.diffError) + INFO(diffRes.diffError->toString()); + } + + void compareTypesEq(const std::string& leftSymbol, const std::string& rightSymbol) + { + compareEq(BaseFixture::requireType(leftSymbol), BaseFixture::requireType(rightSymbol)); + } +}; +using DifferFixture = DifferFixtureGeneric; +using DifferFixtureWithBuiltins = DifferFixtureGeneric; + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ @@ -173,3 +269,21 @@ void createSomeClasses(Frontend* frontend); } while (false) #define LUAU_REQUIRE_NO_ERRORS(result) LUAU_REQUIRE_ERROR_COUNT(0, result) + +#define LUAU_CHECK_ERRORS(result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + CHECK(!r.errors.empty()); \ + } while (false) + +#define LUAU_CHECK_ERROR_COUNT(count, result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + CHECK_MESSAGE(count == r.errors.size(), getErrors(r)); \ + } while (false) + +#define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp new file mode 100644 index 000000000..de2e98322 --- /dev/null +++ b/tests/FragmentAutocomplete.test.cpp @@ -0,0 +1,334 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/FragmentAutocomplete.h" +#include "Fixture.h" +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Common.h" +#include "Luau/Frontend.h" + + +using namespace Luau; + +LUAU_FASTFLAG(LuauAllowFragmentParsing); +LUAU_FASTFLAG(LuauStoreDFGOnModule2); + +struct FragmentAutocompleteFixture : Fixture +{ + ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}}; + + FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) + { + ParseResult p = tryParse(source); // We don't care about parsing incomplete asts + REQUIRE(p.root); + return findAncestryForFragmentParse(p.root, cursorPos); + } + + CheckResult checkBase(const std::string& document) + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + FrontendOptions opts; + opts.retainFullTypeGraphs = true; + return this->frontend.check("MainModule", opts); + } + + FragmentParseResult parseFragment(const std::string& document, const Position& cursorPos) + { + SourceModule* srcModule = this->getMainSourceModule(); + std::string_view srcString = document; + return Luau::parseFragment(*srcModule, srcString, cursorPos); + } + + FragmentTypeCheckResult checkFragment(const std::string& document, const Position& cursorPos) + { + FrontendOptions options; + options.retainFullTypeGraphs = true; + // Don't strictly need this in the new solver + options.forAutocomplete = true; + options.runLintChecks = false; + return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document); + } +}; + +TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +)", + {2, 11} + ); + + CHECK_EQ(3, result.ancestry.size()); + CHECK_EQ(1, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("y", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_within_scope_tracks_locals_from_previous_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +)", + {4, 15} + ); + + CHECK_EQ(5, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("y", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("e", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_that_comes_later_shouldnt_capture_locals_in_unavailable_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +local z = x + x +if y == 5 then + local q = x + y + z +end +)", + {8, 23} + ); + + CHECK_EQ(6, result.ancestry.size()); + CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("z", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("q", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nearest_enclosing_statement_can_be_non_local") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then +)", + {3, 4} + ); + + CHECK_EQ(4, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("y", std::string(result.localStack.back()->name.value)); + + AstStatIf* ifS = result.nearestStatement->as(); + CHECK(ifS != nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_funcs_show_up_in_local_stack") +{ + auto result = runAutocompleteVisitor( + R"( +local function foo() return 4 end +local x = foo() +local function bar() return x + foo() end +)", + {3, 32} + ); + + CHECK_EQ(8, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + CHECK_EQ("x", std::string(result.localStack.back()->name.value)); + auto returnSt = result.nearestStatement->as(); + CHECK(returnSt != nullptr); +} + +TEST_SUITE_END(); + + +TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") +{ + auto res = check(R"( + +)"); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( + +)", + Position(1, 0) + ); + CHECK_EQ("\n", fragment.fragmentToParse); + CHECK_EQ(2, fragment.ancestry.size()); + REQUIRE(fragment.root); + CHECK_EQ(0, fragment.root->body.size); + auto statBody = fragment.root->as(); + CHECK(statBody != nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_complete_fragments") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( +local x = 4 +local y = 5 +local z = x + y +)", + Position{3, 15} + ); + + CHECK_EQ("\nlocal z = x + y", fragment.fragmentToParse); + CHECK_EQ(5, fragment.ancestry.size()); + REQUIRE(fragment.root); + CHECK_EQ(1, fragment.root->body.size); + auto stat = fragment.root->body.data[0]->as(); + REQUIRE(stat); + CHECK_EQ(1, stat->vars.size); + CHECK_EQ(1, stat->values.size); + CHECK_EQ("z", std::string(stat->vars.data[0]->name.value)); + + auto bin = stat->values.data[0]->as(); + REQUIRE(bin); + CHECK_EQ(AstExprBinary::Op::Add, bin->op); + + auto lhs = bin->left->as(); + auto rhs = bin->right->as(); + REQUIRE(lhs); + REQUIRE(rhs); + CHECK_EQ("x", std::string(lhs->local->name.value)); + CHECK_EQ("y", std::string(rhs->local->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_fragments_in_line") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( +local x = 4 +local z = x + y +local y = 5 +)", + Position{2, 15} + ); + + CHECK_EQ("local z = x + y", fragment.fragmentToParse); + CHECK_EQ(5, fragment.ancestry.size()); + REQUIRE(fragment.root); + CHECK_EQ(1, fragment.root->body.size); + auto stat = fragment.root->body.data[0]->as(); + REQUIRE(stat); + CHECK_EQ(1, stat->vars.size); + CHECK_EQ(1, stat->values.size); + CHECK_EQ("z", std::string(stat->vars.data[0]->name.value)); + + auto bin = stat->values.data[0]->as(); + REQUIRE(bin); + CHECK_EQ(AstExprBinary::Op::Add, bin->op); + + auto lhs = bin->left->as(); + auto rhs = bin->right->as(); + REQUIRE(lhs); + REQUIRE(rhs); + CHECK_EQ("x", std::string(lhs->local->name.value)); + CHECK_EQ("y", std::string(rhs->name.value)); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_simple_fragment") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = checkFragment( + R"( +local x = 4 +local y = 5 +local z = x + y +)", + Position{3, 15} + ); + + auto opt = linearSearchForBinding(fragment.freshScope, "z"); + REQUIRE(opt); + CHECK_EQ("number", toString(*opt)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_fragment_inserted_inline") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + auto fragment = checkFragment( + R"( +local x = 4 +local z = x +local y = 5 +)", + Position{2, 11} + ); + + auto correct = linearSearchForBinding(fragment.freshScope, "z"); + REQUIRE(correct); + CHECK_EQ("number", toString(*correct)); +} + +TEST_SUITE_END(); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 1d31b2813..88f91708b 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -12,6 +12,10 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(DebugLuauFreezeArena); +LUAU_FASTFLAG(DebugLuauMagicTypes); + namespace { @@ -81,8 +85,8 @@ struct FrontendFixture : BuiltinsFixture { FrontendFixture() { - addGlobalBinding(frontend, "game", frontend.typeChecker.anyType, "@test"); - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); } }; @@ -152,7 +156,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") frontend.check("game/Gui/Modules/B"); - ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + ModulePtr bModule = frontend.moduleResolver.getModule("game/Gui/Modules/B"); REQUIRE(bModule != nullptr); CHECK(bModule->errors.empty()); Luau::dumpErrors(bModule); @@ -160,7 +164,10 @@ TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") auto bExports = first(bModule->returnType); REQUIRE(!!bExports); - CHECK_EQ("{| b_value: number |}", toString(*bExports)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ b_value: number }", toString(*bExports)); + else + CHECK_EQ("{| b_value: number |}", toString(*bExports)); } TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_cyclically_dependent_scripts") @@ -240,13 +247,13 @@ TEST_CASE_FIXTURE(FrontendFixture, "nocheck_modules_are_typed") CheckResult result = frontend.check("game/Gui/Modules/C"); LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr aModule = frontend.moduleResolver.modules["game/Gui/Modules/A"]; + ModulePtr aModule = frontend.moduleResolver.getModule("game/Gui/Modules/A"); REQUIRE(bool(aModule)); std::optional aExports = first(aModule->returnType); REQUIRE(bool(aExports)); - ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + ModulePtr bModule = frontend.moduleResolver.getModule("game/Gui/Modules/B"); REQUIRE(bool(bModule)); std::optional bExports = first(bModule->returnType); @@ -288,6 +295,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") return {hello = A.hello} )"; fileResolver.source["game/Gui/Modules/C"] = R"( + --!strict local Modules = game:GetService('Gui').Modules local A = require(Modules.A) local B = require(Modules.B) @@ -297,12 +305,16 @@ TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") CheckResult result = frontend.check("game/Gui/Modules/C"); LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr cModule = frontend.moduleResolver.modules["game/Gui/Modules/C"]; + ModulePtr cModule = frontend.moduleResolver.getModule("game/Gui/Modules/C"); REQUIRE(bool(cModule)); std::optional cExports = first(cModule->returnType); REQUIRE(bool(cExports)); - CHECK_EQ("{| a: any, b: any |}", toString(*cExports)); + + if (FFlag::LuauSolverV2) + CHECK_EQ("{ a: { hello: any }, b: { hello: any } }", toString(*cExports)); + else + CHECK_EQ("{| a: any, b: any |}", toString(*cExports)); } TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_disabled_in_nocheck") @@ -444,6 +456,63 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") CHECK_EQ(toString(tyB), "any"); } +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_exports") +{ + fileResolver.source["game/A"] = R"( +local b = require(game.B) +export type atype = { x: b.btype } +return {mod_a = 1} + )"; + + fileResolver.source["game/B"] = R"( +export type btype = { x: number } + +local function bf() + local a = require(game.A) + local bfl : a.atype = nil + return {bfl.x} +end +return {mod_b = 2} + )"; + + ToStringOptions opts; + opts.exhaustive = true; + + CheckResult resultA = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(resultA); + + CheckResult resultB = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(resultB); + + TypeId tyB = requireExportedType("game/B", "btype"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(tyB, opts), "{ x: number }"); + else + CHECK_EQ(toString(tyB, opts), "{| x: number |}"); + + TypeId tyA = requireExportedType("game/A", "atype"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(tyA, opts), "{ x: any }"); + else + CHECK_EQ(toString(tyA, opts), "{| x: any |}"); + + frontend.markDirty("game/B"); + resultB = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(resultB); + + tyB = requireExportedType("game/B", "btype"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(tyB, opts), "{ x: number }"); + else + CHECK_EQ(toString(tyB, opts), "{| x: number |}"); + + tyA = requireExportedType("game/A", "atype"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(tyA, opts), "{ x: any }"); + else + CHECK_EQ(toString(tyA, opts), "{| x: any |}"); +} + TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") { fileResolver.source["Modules/A"] = R"( @@ -456,16 +525,16 @@ TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") end )"; - frontend.check("Modules/A"); + configResolver.defaultConfig.enabledLint.enableWarning(LintWarning::Code_ForRange); + + lintModule("Modules/A"); fileResolver.source["Modules/A"] = R"( -- We have fixed the lint error, but we did not tell the Frontend that the file is changed! - -- Therefore, we expect Frontend to reuse the parse tree. + -- Therefore, we expect Frontend to reuse the results from previous lint. )"; - configResolver.defaultConfig.enabledLint.enableWarning(LintWarning::Code_ForRange); - - LintResult lintResult = frontend.lint("Modules/A"); + LintResult lintResult = lintModule("Modules/A"); CHECK_EQ(1, lintResult.warnings.size()); } @@ -486,7 +555,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "dont_recheck_script_that_hasnt_been_marked_d frontend.check("game/Gui/Modules/B"); - ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + ModulePtr bModule = frontend.moduleResolver.getModule("game/Gui/Modules/B"); CHECK(bModule->errors.empty()); Luau::dumpErrors(bModule); } @@ -507,14 +576,17 @@ TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") frontend.check("game/Gui/Modules/B"); - ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + ModulePtr bModule = frontend.moduleResolver.getModule("game/Gui/Modules/B"); CHECK(bModule->errors.empty()); Luau::dumpErrors(bModule); auto bExports = first(bModule->returnType); REQUIRE(!!bExports); - CHECK_EQ("{| b_value: string |}", toString(*bExports)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ b_value: string }", toString(*bExports)); + else + CHECK_EQ("{| b_value: string |}", toString(*bExports)); } TEST_CASE_FIXTURE(FrontendFixture, "mark_non_immediate_reverse_deps_as_dirty") @@ -642,9 +714,14 @@ TEST_CASE_FIXTURE(FrontendFixture, "report_syntax_error_in_required_file") CHECK_EQ("Modules/A", result.errors[0].moduleName); - bool b = std::any_of(begin(result.errors), end(result.errors), [](auto&& e) -> bool { - return get(e); - }); + bool b = std::any_of( + begin(result.errors), + end(result.errors), + [](auto&& e) -> bool + { + return get(e); + } + ); if (!b) { CHECK_MESSAGE(false, "Expected a syntax error!"); @@ -737,8 +814,10 @@ TEST_CASE_FIXTURE(FrontendFixture, "accumulate_cached_errors_in_consistent_order TEST_CASE_FIXTURE(FrontendFixture, "test_pruneParentSegments") { - CHECK_EQ(std::optional{"Modules/Enum/ButtonState"}, - pathExprToModuleName("", {"Modules", "LuaApp", "DeprecatedDarkTheme", "Parent", "Parent", "Enum", "ButtonState"})); + CHECK_EQ( + std::optional{"Modules/Enum/ButtonState"}, + pathExprToModuleName("", {"Modules", "LuaApp", "DeprecatedDarkTheme", "Parent", "Parent", "Enum", "ButtonState"}) + ); CHECK_EQ(std::optional{"workspace/Foo/Bar/Baz"}, pathExprToModuleName("workspace/Foo/Quux", {"script", "Parent", "Bar", "Baz"})); CHECK_EQ(std::nullopt, pathExprToModuleName("", {})); CHECK_EQ(std::optional{"script"}, pathExprToModuleName("", {"script"})); @@ -760,25 +839,49 @@ TEST_CASE_FIXTURE(FrontendFixture, "test_lint_uses_correct_config") configResolver.configFiles["Module/A"].enabledLint.enableWarning(LintWarning::Code_ForRange); - auto result = frontend.lint("Module/A"); + auto result = lintModule("Module/A"); CHECK_EQ(1, result.warnings.size()); configResolver.configFiles["Module/A"].enabledLint.disableWarning(LintWarning::Code_ForRange); + frontend.markDirty("Module/A"); - auto result2 = frontend.lint("Module/A"); + auto result2 = lintModule("Module/A"); CHECK_EQ(0, result2.warnings.size()); LintOptions overrideOptions; overrideOptions.enableWarning(LintWarning::Code_ForRange); - auto result3 = frontend.lint("Module/A", overrideOptions); + frontend.markDirty("Module/A"); + + auto result3 = lintModule("Module/A", overrideOptions); CHECK_EQ(1, result3.warnings.size()); overrideOptions.disableWarning(LintWarning::Code_ForRange); - auto result4 = frontend.lint("Module/A", overrideOptions); + frontend.markDirty("Module/A"); + + auto result4 = lintModule("Module/A", overrideOptions); CHECK_EQ(0, result4.warnings.size()); } +TEST_CASE_FIXTURE(FrontendFixture, "lint_results_are_only_for_checked_module") +{ + fileResolver.source["Module/A"] = R"( +local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 + )"; + + fileResolver.source["Module/B"] = R"( +require(script.Parent.A) +local _ = 0x10000000000000000 + )"; + + LintResult lintResult = lintModule("Module/B"); + CHECK_EQ(1, lintResult.warnings.size()); + + // Check cached result + lintResult = lintModule("Module/B"); + CHECK_EQ(1, lintResult.warnings.size()); +} + TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") { Frontend fe{&fileResolver, &configResolver, {false}}; @@ -814,12 +917,26 @@ TEST_CASE_FIXTURE(FrontendFixture, "it_should_be_safe_to_stringify_errors_when_f // When this test fails, it is because the TypeIds needed by the error have been deallocated. // It is thus basically impossible to predict what will happen when this assert is evaluated. // It could segfault, or you could see weird type names like the empty string or - REQUIRE_EQ( - "Table type 'a' not compatible with type '{| Count: number |}' because the former is missing field 'Count'", toString(result.errors[0])); + if (FFlag::LuauSolverV2) + REQUIRE_EQ( + R"(Type + '{ count: string }' +could not be converted into + '{ Count: number }')", + toString(result.errors[0]) + ); + else + REQUIRE_EQ( + "Table type 'a' not compatible with type '{| Count: number |}' because the former is missing field 'Count'", toString(result.errors[0]) + ); } TEST_CASE_FIXTURE(FrontendFixture, "trace_requires_in_nonstrict_mode") { + // The new non-strict mode is not currently expected to signal any errors here. + if (FFlag::LuauSolverV2) + return; + fileResolver.source["Module/A"] = R"( --!nonstrict local module = {} @@ -852,12 +969,17 @@ TEST_CASE_FIXTURE(FrontendFixture, "environments") { ScopePtr testScope = frontend.addEnvironment("test"); - unfreeze(typeChecker.globalTypes); - loadDefinitionFile(typeChecker, testScope, R"( + unfreeze(frontend.globals.globalTypes); + frontend.loadDefinitionFile( + frontend.globals, + testScope, + R"( export type Foo = number | string )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); fileResolver.source["A"] = R"( --!nonstrict @@ -869,13 +991,25 @@ TEST_CASE_FIXTURE(FrontendFixture, "environments") local foo: Foo = 1 )"; + fileResolver.source["C"] = R"( + --!strict + local foo: Foo = 1 + )"; + fileResolver.environments["A"] = "test"; CheckResult resultA = frontend.check("A"); LUAU_REQUIRE_NO_ERRORS(resultA); CheckResult resultB = frontend.check("B"); - LUAU_REQUIRE_ERROR_COUNT(1, resultB); + // In the new non-strict mode, we do not currently support error reporting for unknown symbols in type positions. + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(resultB); + else + LUAU_REQUIRE_ERROR_COUNT(1, resultB); + + CheckResult resultC = frontend.check("C"); + LUAU_REQUIRE_ERROR_COUNT(1, resultC); } TEST_CASE_FIXTURE(FrontendFixture, "ast_node_at_position") @@ -976,6 +1110,10 @@ TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") { + // This test describes non-strict mode behavior that is just not currently present in the new non-strict mode. + if (FFlag::LuauSolverV2) + return; + frontend.options.retainFullTypeGraphs = false; fileResolver.source["Module/A"] = R"( @@ -1015,7 +1153,7 @@ a:b() -- this should error, since A doesn't define a:b() TEST_CASE("no_use_after_free_with_type_fun_instantiation") { // This flag forces this test to crash if there's a UAF in this code. - ScopedFastFlag sff_DebugLuauFreezeArena("DebugLuauFreezeArena", true); + ScopedFastFlag sff_DebugLuauFreezeArena(FFlag::DebugLuauFreezeArena, true); FrontendFixture fix; @@ -1105,4 +1243,215 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "module_scope_check") +{ + frontend.prepareModuleScope = [this](const ModuleName& name, const ScopePtr& scope, bool forAutocomplete) + { + scope->bindings[Luau::AstName{"x"}] = Luau::Binding{frontend.globals.builtinTypes->numberType}; + }; + + fileResolver.source["game/A"] = R"( + local a = x + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + auto ty = requireType("game/A", "a"); + CHECK_EQ(toString(ty), "number"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "parse_only") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local a: number = 'oh no a type error' + return {a=a} + )"; + + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = script.Parent + local A = require(Modules.A) + local b: number = 2 + )"; + + frontend.parse("game/Gui/Modules/B"); + + REQUIRE(frontend.sourceNodes.count("game/Gui/Modules/A")); + REQUIRE(frontend.sourceNodes.count("game/Gui/Modules/B")); + + auto node = frontend.sourceNodes["game/Gui/Modules/B"]; + CHECK(node->requireSet.contains("game/Gui/Modules/A")); + REQUIRE_EQ(node->requireLocations.size(), 1); + CHECK_EQ(node->requireLocations[0].second, Luau::Location(Position(2, 18), Position(2, 36))); + + // Early parse doesn't cause typechecking to be skipped + CheckResult result = frontend.check("game/Gui/Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("game/Gui/Modules/A", result.errors[0].moduleName); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(FrontendFixture, "markdirty_early_return") +{ + constexpr char moduleName[] = "game/Gui/Modules/A"; + fileResolver.source[moduleName] = R"( + return 1 + )"; + + { + std::vector markedDirty; + frontend.markDirty(moduleName, &markedDirty); + CHECK(markedDirty.empty()); + } + + frontend.parse(moduleName); + + { + std::vector markedDirty; + frontend.markDirty(moduleName, &markedDirty); + CHECK(!markedDirty.empty()); + } +} + +TEST_CASE_FIXTURE(FrontendFixture, "attribute_ices_to_the_correct_module") +{ + ScopedFastFlag sff{FFlag::DebugLuauMagicTypes, true}; + + fileResolver.source["game/one"] = R"( + require(game.two) + )"; + + fileResolver.source["game/two"] = R"( + local a: _luau_ice + )"; + + try + { + frontend.check("game/one"); + } + catch (InternalCompilerError& err) + { + CHECK("game/two" == err.moduleName); + return; + } + + FAIL("Expected an InternalCompilerError!"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "checked_modules_have_the_correct_mode") +{ + fileResolver.source["game/A"] = R"( + --!nocheck + local a: number = "five" + )"; + + fileResolver.source["game/B"] = R"( + --!nonstrict + local a = math.abs("five") + )"; + + fileResolver.source["game/C"] = R"( + --!strict + local a = 10 + )"; + + frontend.check("game/A"); + frontend.check("game/B"); + frontend.check("game/C"); + + ModulePtr moduleA = frontend.moduleResolver.getModule("game/A"); + REQUIRE(moduleA); + CHECK(moduleA->mode == Mode::NoCheck); + + ModulePtr moduleB = frontend.moduleResolver.getModule("game/B"); + REQUIRE(moduleB); + CHECK(moduleB->mode == Mode::Nonstrict); + + ModulePtr moduleC = frontend.moduleResolver.getModule("game/C"); + REQUIRE(moduleC); + CHECK(moduleC->mode == Mode::Strict); +} + +TEST_CASE_FIXTURE(FrontendFixture, "separate_caches_for_autocomplete") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + fileResolver.source["game/A"] = R"( + --!nonstrict + local exports = {} + function exports.hello() end + return exports + )"; + + FrontendOptions opts; + opts.forAutocomplete = true; + + frontend.check("game/A", opts); + + CHECK(nullptr == frontend.moduleResolver.getModule("game/A")); + + ModulePtr acModule = frontend.moduleResolverForAutocomplete.getModule("game/A"); + REQUIRE(acModule != nullptr); + CHECK(acModule->mode == Mode::Strict); + + frontend.check("game/A"); + + ModulePtr module = frontend.moduleResolver.getModule("game/A"); + + REQUIRE(module != nullptr); + CHECK(module->mode == Mode::Nonstrict); +} + +TEST_CASE_FIXTURE(FrontendFixture, "no_separate_caches_with_the_new_solver") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( + --!nonstrict + local exports = {} + function exports.hello() end + return exports + )"; + + FrontendOptions opts; + opts.forAutocomplete = true; + + frontend.check("game/A", opts); + + CHECK(nullptr == frontend.moduleResolverForAutocomplete.getModule("game/A")); + + ModulePtr module = frontend.moduleResolver.getModule("game/A"); + + REQUIRE(module != nullptr); + CHECK(module->mode == Mode::Nonstrict); +} + +TEST_CASE_FIXTURE(Fixture, "exported_tables_have_position_metadata") +{ + CheckResult result = check(R"( + return { abc = 22 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr mm = getMainModule(); + + TypePackId retTp = mm->getModuleScope()->returnType; + auto retHead = flatten(retTp).first; + REQUIRE(1 == retHead.size()); + + const TableType* tt = get(retHead[0]); + REQUIRE(tt); + + CHECK("MainModule" == tt->definitionModuleName); + + CHECK(1 == tt->props.size()); + CHECK(tt->props.count("abc")); + + const Property& prop = tt->props.find("abc")->second; + + CHECK(Location{Position{1, 17}, Position{1, 20}} == prop.location); +} + TEST_SUITE_END(); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp new file mode 100644 index 000000000..1388b9005 --- /dev/null +++ b/tests/Generalization.test.cpp @@ -0,0 +1,253 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Generalization.h" +#include "Luau/Scope.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/Error.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) + +TEST_SUITE_BEGIN("Generalization"); + +struct GeneralizationFixture +{ + TypeArena arena; + BuiltinTypes builtinTypes; + ScopePtr globalScope = std::make_shared(builtinTypes.anyTypePack); + ScopePtr scope = std::make_shared(globalScope); + ToStringOptions opts; + + DenseHashSet generalizedTypes_{nullptr}; + NotNull> generalizedTypes{&generalizedTypes_}; + + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + std::pair freshType() + { + FreeType ft{scope.get(), builtinTypes.neverType, builtinTypes.unknownType}; + + TypeId ty = arena.addType(ft); + FreeType* ftv = getMutable(ty); + REQUIRE(ftv != nullptr); + + return {ty, ftv}; + } + + std::string toString(TypeId ty) + { + return ::Luau::toString(ty, opts); + } + + std::string toString(TypePackId ty) + { + return ::Luau::toString(ty, opts); + } + + std::optional generalize(TypeId ty) + { + return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{scope.get()}, generalizedTypes, ty); + } +}; + +TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t2generalized = generalize(t2); + REQUIRE(t2generalized); + + CHECK(follow(t1) == follow(t2)); + + auto t1generalized = generalize(t1); + REQUIRE(t1generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + +// Same as generalize_a_type_that_is_bounded_by_another_generalizable_type +// except that we generalize the types in the opposite order +TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type_in_reverse_order") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t1generalized = generalize(t1); + REQUIRE(t1generalized); + + CHECK(follow(t1) == follow(t2)); + + auto t2generalized = generalize(t2); + REQUIRE(t2generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "dont_traverse_into_class_types_when_generalizing") +{ + auto [propTy, _] = freshType(); + + TypeId cursedClass = arena.addType(ClassType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, "", {}}); + + auto genClass = generalize(cursedClass); + REQUIRE(genClass); + + auto genPropTy = get(*genClass)->props.at("oh_no").readTy; + CHECK(is(*genPropTy)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "cache_fully_generalized_types") +{ + CHECK(generalizedTypes->empty()); + + TypeId tinyTable = arena.addType( + TableType{TableType::Props{{"one", builtinTypes.numberType}, {"two", builtinTypes.stringType}}, std::nullopt, TypeLevel{}, TableState::Sealed} + ); + + generalize(tinyTable); + + CHECK(generalizedTypes->contains(tinyTable)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); + CHECK(generalizedTypes->contains(builtinTypes.stringType)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "dont_cache_types_that_arent_done_yet") +{ + TypeId freeTy = arena.addType(FreeType{NotNull{globalScope.get()}, builtinTypes.neverType, builtinTypes.stringType}); + + TypeId fnTy = arena.addType(FunctionType{builtinTypes.emptyTypePack, arena.addTypePack(TypePack{{builtinTypes.numberType}})}); + + TypeId tableTy = arena.addType( + TableType{TableType::Props{{"one", builtinTypes.numberType}, {"two", freeTy}, {"three", fnTy}}, std::nullopt, TypeLevel{}, TableState::Sealed} + ); + + generalize(tableTy); + + CHECK(generalizedTypes->contains(fnTy)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); + CHECK(generalizedTypes->contains(builtinTypes.neverType)); + CHECK(generalizedTypes->contains(builtinTypes.stringType)); + CHECK(!generalizedTypes->contains(freeTy)); + CHECK(!generalizedTypes->contains(tableTy)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can_be_cached") +{ + TypeId selfTy = arena.addType(BlockedType{}); + + TypeId methodTy = arena.addType(FunctionType{ + arena.addTypePack({selfTy}), + arena.addTypePack({builtinTypes.numberType}), + }); + + asMutable(selfTy)->ty.emplace( + TableType::Props{{"count", builtinTypes.numberType}, {"method", methodTy}}, std::nullopt, TypeLevel{}, TableState::Sealed + ); + + generalize(methodTy); + + CHECK(generalizedTypes->contains(methodTy)); + CHECK(generalizedTypes->contains(selfTy)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") +{ + // t1 where t1 = ('h <: (t1 <: 'i)) | ('j <: (t1 <: 'i)) + TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId unionType = arena.addType(UnionType{{h, j}}); + getMutable(h)->upperBound = i; + getMutable(h)->lowerBound = builtinTypes.neverType; + getMutable(i)->upperBound = builtinTypes.unknownType; + getMutable(i)->lowerBound = unionType; + getMutable(j)->upperBound = i; + getMutable(j)->lowerBound = builtinTypes.neverType; + + generalize(unionType); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "intersection_type_traversal_doesnt_crash") +{ + // t1 where t1 = ('h <: (t1 <: 'i)) & ('j <: (t1 <: 'i)) + TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId intersectionType = arena.addType(IntersectionType{{h, j}}); + + getMutable(h)->upperBound = i; + getMutable(h)->lowerBound = builtinTypes.neverType; + getMutable(i)->upperBound = builtinTypes.unknownType; + getMutable(i)->lowerBound = intersectionType; + getMutable(j)->upperBound = i; + getMutable(j)->lowerBound = builtinTypes.neverType; + + generalize(intersectionType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_traversal_should_re_traverse_unions_if_they_change_type") +{ + // This test case should just not assert + CheckResult result = check(R"( +function byId(p) + return p.id +end + +function foo() + + local productButtonPairs = {} + local func = byId + local dir = -1 + + local function updateSearch() + for product, button in pairs(productButtonPairs) do + button.LayoutOrder = func(product) * dir + end + end + + function(mode) + if mode == 'Name'then + else + if mode == 'New'then + func = function(p) + return p.id + end + elseif mode == 'Price'then + func = function(p) + return p.price + end + end + + end + end +end +)"); +} + +TEST_SUITE_END(); diff --git a/tests/InsertionOrderedMap.test.cpp b/tests/InsertionOrderedMap.test.cpp new file mode 100644 index 000000000..ca6f14994 --- /dev/null +++ b/tests/InsertionOrderedMap.test.cpp @@ -0,0 +1,140 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/InsertionOrderedMap.h" + +#include + +#include "doctest.h" + +using namespace Luau; + +struct MapFixture +{ + std::vector> ptrs; + + int* makePtr() + { + ptrs.push_back(std::make_unique(int{})); + return ptrs.back().get(); + } +}; + +TEST_SUITE_BEGIN("InsertionOrderedMap"); + +TEST_CASE_FIXTURE(MapFixture, "map_insertion") +{ + InsertionOrderedMap map; + + int* a = makePtr(); + int* b = makePtr(); + + map.insert(a, 1); + map.insert(b, 2); +} + +TEST_CASE_FIXTURE(MapFixture, "map_lookup") +{ + InsertionOrderedMap map; + + int* a = makePtr(); + map.insert(a, 1); + + int* r = map.get(a); + REQUIRE(r != nullptr); + CHECK(*r == 1); + + r = map.get(makePtr()); + CHECK(r == nullptr); +} + +TEST_CASE_FIXTURE(MapFixture, "insert_does_not_update") +{ + InsertionOrderedMap map; + + int* k = makePtr(); + map.insert(k, 1); + map.insert(k, 2); + + int* v = map.get(k); + REQUIRE(v != nullptr); + CHECK(*v == 1); +} + +TEST_CASE_FIXTURE(MapFixture, "insertion_order_is_iteration_order") +{ + // This one is a little hard to prove, in that if the ordering guarantees + // fail this test isn't guaranteed to fail, but it is strictly better than + // nothing. + + InsertionOrderedMap map; + int* a = makePtr(); + int* b = makePtr(); + int* c = makePtr(); + map.insert(a, 1); + map.insert(b, 1); + map.insert(c, 1); + + auto it = map.begin(); + REQUIRE(it != map.end()); + CHECK(it->first == a); + CHECK(it->second == 1); + + ++it; + REQUIRE(it != map.end()); + CHECK(it->first == b); + CHECK(it->second == 1); + + ++it; + REQUIRE(it != map.end()); + CHECK(it->first == c); + CHECK(it->second == 1); + + ++it; + CHECK(it == map.end()); +} + +TEST_CASE_FIXTURE(MapFixture, "destructuring_iterator_compiles") +{ + // This test's only purpose is to successfully compile. + InsertionOrderedMap map; + + for (auto [k, v] : map) + { + // Checks here solely to silence unused variable warnings. + CHECK(k); + CHECK(v > 0); + } +} + +TEST_CASE_FIXTURE(MapFixture, "map_erasure") +{ + InsertionOrderedMap map; + + int* a = makePtr(); + int* b = makePtr(); + + map.insert(a, 1); + map.insert(b, 2); + + map.erase(map.find(a)); + CHECK(map.size() == 1); + CHECK(!map.contains(a)); + CHECK(map.get(a) == nullptr); + + int* v = map.get(b); + REQUIRE(v); +} + +TEST_CASE_FIXTURE(MapFixture, "map_clear") +{ + InsertionOrderedMap map; + int* a = makePtr(); + + map.insert(a, 1); + + map.clear(); + CHECK(map.size() == 0); + CHECK(!map.contains(a)); + CHECK(map.get(a) == nullptr); +} + +TEST_SUITE_END(); diff --git a/tests/Instantiation2.test.cpp b/tests/Instantiation2.test.cpp new file mode 100644 index 000000000..fff98e601 --- /dev/null +++ b/tests/Instantiation2.test.cpp @@ -0,0 +1,51 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Instantiation2.h" + +#include "Fixture.h" +#include "ClassFixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("Instantiation2Test"); + +TEST_CASE_FIXTURE(Fixture, "weird_cyclic_instantiation") +{ + TypeArena arena; + Scope scope(builtinTypes->anyTypePack); + + TypeId genericT = arena.addType(GenericType{"T"}); + + TypeId idTy = arena.addType(FunctionType{ + /* generics */ {genericT}, + /* genericPacks */ {}, + /* argTypes */ arena.addTypePack({genericT}), + /* retTypes */ arena.addTypePack({genericT}) + }); + + DenseHashMap genericSubstitutions{nullptr}; + DenseHashMap genericPackSubstitutions{nullptr}; + + TypeId freeTy = arena.freshType(&scope); + FreeType* ft = getMutable(freeTy); + REQUIRE(ft); + ft->lowerBound = idTy; + ft->upperBound = builtinTypes->unknownType; + + genericSubstitutions[genericT] = freeTy; + + CHECK("(T) -> T" == toString(idTy)); + + std::optional res = instantiate2(&arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions), idTy); + + // Substitutions should not mutate the original type! + CHECK("(T) -> T" == toString(idTy)); + + REQUIRE(res); + CHECK("(T) -> T" == toString(*res)); +} + +TEST_SUITE_END(); diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index e0756badd..51122f380 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include #include @@ -21,4 +22,40 @@ auto operator<<(std::ostream& lhs, const std::optional& t) -> decltype(lhs << return lhs << "none"; } +template +auto operator<<(std::ostream& lhs, const std::vector& t) -> decltype(lhs << t[0]) +{ + lhs << "{ "; + bool first = true; + for (const T& element : t) + { + if (first) + first = false; + else + lhs << ", "; + + lhs << element; + } + + return lhs << " }"; +} + +template +auto operator<<(std::ostream& lhs, const Luau::DenseHashSet& set) -> decltype(lhs << *set.begin()) +{ + lhs << "{ "; + bool first = true; + for (const K& element : set) + { + if (first) + first = false; + else + lhs << ", "; + + lhs << element; + } + + return lhs << " }"; +} + } // namespace std diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 0896517f9..d02fd9f17 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -4,17 +4,26 @@ #include "Luau/IrDump.h" #include "Luau/IrUtils.h" #include "Luau/OptimizeConstProp.h" +#include "Luau/OptimizeDeadStore.h" #include "Luau/OptimizeFinalX64.h" +#include "ScopedFlags.h" #include "doctest.h" #include +LUAU_FASTFLAG(DebugLuauAbortingChecks) + using namespace Luau::CodeGen; class IrBuilderFixture { public: + IrBuilderFixture() + : build(hooks) + { + } + void constantFold() { for (IrBlock& block : build.function.blocks) @@ -42,7 +51,7 @@ class IrBuilderFixture f(a); build.beginBlock(a); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); }; template @@ -56,36 +65,63 @@ class IrBuilderFixture f(a, b); build.beginBlock(a); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(b); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); }; - void checkEq(IrOp lhs, IrOp rhs) - { - CHECK_EQ(lhs.kind, rhs.kind); - LUAU_ASSERT(lhs.kind != IrOpKind::Constant && "can't compare constants, each ref is unique"); - CHECK_EQ(lhs.index, rhs.index); - } - void checkEq(IrOp instOp, const IrInst& inst) { const IrInst& target = build.function.instOp(instOp); CHECK(target.cmd == inst.cmd); - checkEq(target.a, inst.a); - checkEq(target.b, inst.b); - checkEq(target.c, inst.c); - checkEq(target.d, inst.d); - checkEq(target.e, inst.e); + CHECK(target.a == inst.a); + CHECK(target.b == inst.b); + CHECK(target.c == inst.c); + CHECK(target.d == inst.d); + CHECK(target.e == inst.e); + CHECK(target.f == inst.f); + } + + void defineCfgTree(const std::vector>& successorSets) + { + for (const std::vector& successorSet : successorSets) + { + build.beginBlock(build.block(IrBlockKind::Internal)); + + build.function.cfg.successorsOffsets.push_back(uint32_t(build.function.cfg.successors.size())); + build.function.cfg.successors.insert(build.function.cfg.successors.end(), successorSet.begin(), successorSet.end()); + } + + // Brute-force the predecessor list + for (int i = 0; i < int(build.function.blocks.size()); i++) + { + build.function.cfg.predecessorsOffsets.push_back(uint32_t(build.function.cfg.predecessors.size())); + + for (int k = 0; k < int(build.function.blocks.size()); k++) + { + for (uint32_t succIdx : successors(build.function.cfg, k)) + { + if (succIdx == uint32_t(i)) + build.function.cfg.predecessors.push_back(k); + } + } + } + + computeCfgImmediateDominators(build.function); + computeCfgDominanceTreeChildren(build.function); } + HostIrHooks hooks; IrBuilder build; // Luau.VM headers are not accessible static const int tnil = 0; static const int tboolean = 1; static const int tnumber = 3; + static const int tstring = 5; + static const int ttable = 6; + static const int tfunction = 7; }; TEST_SUITE_BEGIN("Optimization"); @@ -100,23 +136,23 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(0), fallback); IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmConst(5)); build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(0), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into CHECK_TAG - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: CHECK_TAG R2, tnil, bb_fallback_1 CHECK_TAG K5, tnil, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -129,17 +165,17 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); build.inst(IrCmd::ADD_NUM, opA, opB); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into second argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_DOUBLE R1 %2 = ADD_NUM %0, R2 - LOP_RETURN 0u + RETURN 0u )"); } @@ -156,25 +192,25 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %1 = LOAD_TAG R2 JUMP_EQ_TAG R1, %1, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -192,27 +228,27 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into second argument is it can't be done for the first one // We also swap first and second argument to generate memory access on the LHS - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_TAG R1 STORE_TAG R6, %0 JUMP_EQ_TAG R2, %0, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -230,16 +266,16 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_POINTER R1 %1 = GET_ARR_ADDR %0, 0i @@ -247,10 +283,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") JUMP_EQ_TAG %2, tnil, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -267,25 +303,25 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") build.inst(IrCmd::JUMP_CMP_NUM, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %1 = LOAD_DOUBLE R2 JUMP_CMP_NUM R1, %1, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -300,189 +336,456 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.beginBlock(block); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(10), build.constInt(20))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(INT_MAX), build.constInt(1))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::ADD_INT, build.constInt(INT_MAX), build.constInt(1))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(10), build.constInt(20))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(INT_MIN), build.constInt(1))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::SUB_INT, build.constInt(10), build.constInt(20))); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::SUB_INT, build.constInt(INT_MIN), build.constInt(1))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::ADD_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::SUB_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MUL_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::POW_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.inst(IrCmd::ADD_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::SUB_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::MUL_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(8), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(10), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(12), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(13), build.inst(IrCmd::FLOOR_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(14), build.inst(IrCmd::CEIL_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(15), build.inst(IrCmd::ROUND_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(16), build.inst(IrCmd::SQRT_NUM, build.constDouble(16))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(17), build.inst(IrCmd::ABS_NUM, build.constDouble(-4))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnil), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnumber), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); + build.inst(IrCmd::STORE_INT, build.vmReg(18), build.inst(IrCmd::NOT_ANY, build.constTag(tnil), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); + build.inst( + IrCmd::STORE_INT, build.vmReg(19), build.inst(IrCmd::NOT_ANY, build.constTag(tnumber), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1))) + ); + build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(22), build.inst(IrCmd::SIGN_NUM, build.constDouble(-4))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_INT R0, 30i - STORE_INT R0, -2147483648i - STORE_INT R0, -10i - STORE_INT R0, 2147483647i - STORE_DOUBLE R0, 7 - STORE_DOUBLE R0, -3 - STORE_DOUBLE R0, 10 - STORE_DOUBLE R0, 0.40000000000000002 - STORE_DOUBLE R0, 1 - STORE_DOUBLE R0, 25 - STORE_DOUBLE R0, 2 - STORE_DOUBLE R0, 5 - STORE_DOUBLE R0, -5 - STORE_INT R0, 1i - STORE_INT R0, 0i - STORE_INT R0, 1i - STORE_INT R0, 0i + STORE_INT R1, -2147483648i + STORE_INT R2, -10i + STORE_INT R3, 2147483647i + STORE_DOUBLE R4, 7 + STORE_DOUBLE R5, -3 + STORE_DOUBLE R6, 10 + STORE_DOUBLE R7, 0.40000000000000002 + STORE_DOUBLE R8, 1 + STORE_DOUBLE R10, 2 + STORE_DOUBLE R11, 5 + STORE_DOUBLE R12, -5 + STORE_DOUBLE R13, 2 + STORE_DOUBLE R14, 3 + STORE_DOUBLE R15, 3 + STORE_DOUBLE R16, 4 + STORE_DOUBLE R17, 4 + STORE_INT R18, 1i + STORE_INT R19, 0i + STORE_INT R20, 1i + STORE_INT R21, 0i + STORE_DOUBLE R22, -1 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericConversions") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::UINT_TO_NUM, build.constInt(0xdeee0000u))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::NUM_TO_INT, build.constDouble(200.0))); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::NUM_TO_UINT, build.constDouble(3740139520.0))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: STORE_DOUBLE R0, 8 - LOP_RETURN 0u + STORE_DOUBLE R1, 3740139520 + STORE_INT R2, 200i + STORE_INT R3, -554827776i + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericConversionsBlocked") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INT, build.constDouble(1e20))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::NUM_TO_UINT, build.constDouble(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::NUM_TO_INT, nan)); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::NUM_TO_UINT, nan)); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %1 = NUM_TO_INT 1e+20 + STORE_INT R0, %1 + %3 = NUM_TO_UINT -10 + STORE_INT R1, %3 + %5 = NUM_TO_INT nan + STORE_INT R2, %5 + %7 = NUM_TO_UINT nan + STORE_INT R3, %7 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp unk = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::BITAND_UINT, build.constInt(0xfe), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::BITAND_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::BITAND_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::BITAND_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::BITAND_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(5), build.inst(IrCmd::BITXOR_UINT, build.constInt(0xfe), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(6), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(7), build.inst(IrCmd::BITXOR_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(8), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(9), build.inst(IrCmd::BITXOR_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITOR_UINT, build.constInt(0xf0), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(11), build.inst(IrCmd::BITOR_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(12), build.inst(IrCmd::BITOR_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(13), build.inst(IrCmd::BITOR_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(14), build.inst(IrCmd::BITOR_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(15), build.inst(IrCmd::BITNOT_UINT, build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(16), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf0), build.constInt(4))); + build.inst(IrCmd::STORE_INT, build.vmReg(17), build.inst(IrCmd::BITLSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(18), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(19), build.inst(IrCmd::BITRSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::BITARSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(22), build.inst(IrCmd::BITLROTATE_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(23), build.inst(IrCmd::BITLROTATE_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(24), build.inst(IrCmd::BITRROTATE_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(25), build.inst(IrCmd::BITRROTATE_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(26), build.inst(IrCmd::BITCOUNTLZ_UINT, build.constInt(0xff00))); + build.inst(IrCmd::STORE_INT, build.vmReg(27), build.inst(IrCmd::BITCOUNTLZ_UINT, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(28), build.inst(IrCmd::BITCOUNTRZ_UINT, build.constInt(0xff00))); + build.inst(IrCmd::STORE_INT, build.vmReg(29), build.inst(IrCmd::BITCOUNTRZ_UINT, build.constInt(0))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_INT R0 + STORE_INT R0, 14i + STORE_INT R1, 0i + STORE_INT R2, 0i + STORE_INT R3, %0 + STORE_INT R4, %0 + STORE_INT R5, 240i + STORE_INT R6, %0 + STORE_INT R7, %0 + %17 = BITNOT_UINT %0 + STORE_INT R8, %17 + %19 = BITNOT_UINT %0 + STORE_INT R9, %19 + STORE_INT R10, 254i + STORE_INT R11, %0 + STORE_INT R12, %0 + STORE_INT R13, -1i + STORE_INT R14, -1i + STORE_INT R15, -15i + STORE_INT R16, 3840i + STORE_INT R17, %0 + STORE_INT R18, 14609920i + STORE_INT R19, %0 + STORE_INT R20, -2167296i + STORE_INT R21, %0 + STORE_INT R22, -301989666i + STORE_INT R23, %0 + STORE_INT R24, 14609920i + STORE_INT R25, %0 + STORE_INT R26, 16i + STORE_INT R27, 32i + STORE_INT R28, 8i + STORE_INT R29, 32i + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32RangeReduction") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xffffff), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xffffff), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xffffff), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xffffff), build.constInt(140))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R10, 62914560i + STORE_INT R10, 61440i + STORE_INT R10, 3i + STORE_INT R10, 4095i + STORE_INT R10, 3i + STORE_INT R10, 4095i + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ReplacementPreservesUses") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp unk = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + build.inst(IrCmd::STORE_INT, build.vmReg(8), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(~0u))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, IncludeUseInfo::Yes) == R"( +bb_0: ; useCount: 0 + %0 = LOAD_INT R0 ; useCount: 1, lastUse: %0 + %1 = BITNOT_UINT %0 ; useCount: 1, lastUse: %0 + STORE_INT R8, %1 ; %2 + RETURN 0u ; %3 + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericNan") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, nan, build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(1), nan)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, nan, build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(1), nan)); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, nan + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, nan + RETURN 0u )"); } TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") { - withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_TAG, build.constTag(tnil), build.constTag(tnil), a, b); - }); + withTwoBlocks( + [this](IrOp a, IrOp b) + { + build.inst(IrCmd::JUMP_EQ_TAG, build.constTag(tnil), build.constTag(tnil), a, b); + } + ); - withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_TAG, build.constTag(tnil), build.constTag(tnumber), a, b); - }); + withTwoBlocks( + [this](IrOp a, IrOp b) + { + build.inst(IrCmd::JUMP_EQ_TAG, build.constTag(tnil), build.constTag(tnumber), a, b); + } + ); - withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(0), a, b); - }); + withTwoBlocks( + [this](IrOp a, IrOp b) + { + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(0), build.cond(IrCondition::Equal), a, b); + } + ); - withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), a, b); - }); + withTwoBlocks( + [this](IrOp a, IrOp b) + { + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), a, b); + } + ); updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u bb_3: JUMP bb_5 bb_5: - LOP_RETURN 2u + RETURN 2u bb_6: JUMP bb_7 bb_7: - LOP_RETURN 1u + RETURN 1u bb_9: JUMP bb_11 bb_11: - LOP_RETURN 2u + RETURN 2u )"); } TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") { - withOneBlock([this](IrOp a) { - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(4), a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - }); + withOneBlock( + [this](IrOp a) + { + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(4), a)); + build.inst(IrCmd::RETURN, build.constUint(0)); + } + ); - withOneBlock([this](IrOp a) { - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(1.2), a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - }); + withOneBlock( + [this](IrOp a) + { + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(1.2), a)); + build.inst(IrCmd::RETURN, build.constUint(0)); + } + ); - withOneBlock([this](IrOp a) { - IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, nan, a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - }); + withOneBlock( + [this](IrOp a) + { + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, nan, a)); + build.inst(IrCmd::RETURN, build.constUint(0)); + } + ); updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_INT R0, 4i - LOP_RETURN 0u + RETURN 0u bb_2: JUMP bb_3 bb_3: - LOP_RETURN 1u + RETURN 1u bb_4: JUMP bb_5 bb_5: - LOP_RETURN 1u + RETURN 1u )"); } TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") { - withOneBlock([this](IrOp a) { - build.inst(IrCmd::CHECK_TAG, build.constTag(tnumber), build.constTag(tnumber), a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - }); + withOneBlock( + [this](IrOp a) + { + build.inst(IrCmd::CHECK_TAG, build.constTag(tnumber), build.constTag(tnumber), a); + build.inst(IrCmd::RETURN, build.constUint(0)); + } + ); - withOneBlock([this](IrOp a) { - build.inst(IrCmd::CHECK_TAG, build.constTag(tnil), build.constTag(tnumber), a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - }); + withOneBlock( + [this](IrOp a) + { + build.inst(IrCmd::CHECK_TAG, build.constTag(tnil), build.constTag(tnumber), a); + build.inst(IrCmd::RETURN, build.constUint(0)); + } + ); updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: - LOP_RETURN 0u + RETURN 0u bb_2: JUMP bb_3 bb_3: - LOP_RETURN 1u + RETURN 1u )"); } TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum") { - auto compareFold = [this](IrOp lhs, IrOp rhs, IrCondition cond, bool result) { + auto compareFold = [this](IrOp lhs, IrOp rhs, IrCondition cond, bool result) + { IrOp instOp; IrInst instExpected; - withTwoBlocks([&](IrOp a, IrOp b) { - IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); - instOp = build.inst( - IrCmd::JUMP_CMP_NUM, lhs.kind == IrOpKind::None ? nan : lhs, rhs.kind == IrOpKind::None ? nan : rhs, build.cond(cond), a, b); - instExpected = IrInst{IrCmd::JUMP, result ? a : b}; - }); + withTwoBlocks( + [&](IrOp a, IrOp b) + { + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + instOp = build.inst( + IrCmd::JUMP_CMP_NUM, lhs.kind == IrOpKind::None ? nan : lhs, rhs.kind == IrOpKind::None ? nan : rhs, build.cond(cond), a, b + ); + instExpected = IrInst{IrCmd::JUMP, result ? a : b}; + } + ); updateUseCounts(build.function); constantFold(); @@ -574,12 +877,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber STORE_INT R1, 10i @@ -595,11 +898,9 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") STORE_DOUBLE R2, %16 %18 = LOAD_TAG R0 STORE_TAG R9, %18 - %20 = LOAD_INT R1 - STORE_INT R10, %20 - %22 = LOAD_DOUBLE R2 - STORE_DOUBLE R11, %22 - LOP_RETURN 0u + STORE_INT R10, %14 + STORE_DOUBLE R11, %16 + RETURN 0u )"); } @@ -620,20 +921,19 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(1))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber STORE_DOUBLE R0, 0.5 - %2 = LOAD_TVALUE R0 - STORE_TVALUE R1, %2 + STORE_SPLIT_TVALUE R1, tnumber, 0.5 STORE_TAG R3, tnumber STORE_DOUBLE R3, 0.5 - LOP_RETURN 0u + RETURN 0u )"); } @@ -647,18 +947,18 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber - LOP_RETURN 0u + RETURN 0u )"); } @@ -677,18 +977,18 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); // Can make env unsafe build.inst(IrCmd::CHECK_SAFE_ENV); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: CHECK_SAFE_ENV CHECK_GC DO_LEN R1, R2 CHECK_SAFE_ENV - LOP_RETURN 0u + RETURN 0u )"); } @@ -713,15 +1013,15 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_POINTER R0 CHECK_NO_METATABLE %0, bb_fallback_1 @@ -729,10 +1029,56 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") DO_LEN R1, R2 CHECK_NO_METATABLE %0, bb_fallback_1 CHECK_READONLY %0, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u + +bb_fallback_1: + RETURN 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RememberNewTableState") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp newtable = build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), newtable); + + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table, build.constInt(14), fallback); + + build.inst(IrCmd::SET_TABLE, build.vmReg(1), build.vmReg(0), build.constUint(13)); // Invalidate table knowledge + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table, build.constInt(14), fallback); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = NEW_TABLE 16u, 32u + STORE_POINTER R0, %0 + SET_TABLE R1, R0, 13u + CHECK_NO_METATABLE %0, bb_fallback_1 + CHECK_READONLY %0, bb_fallback_1 + CHECK_ARRAY_SIZE %0, 14i, bb_fallback_1 + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -745,18 +1091,18 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); - build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0)); + build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0), build.undef()); IrOp something = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); - build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0), build.undef()); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber - LOP_RETURN 0u + RETURN 0u )"); } @@ -770,31 +1116,35 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(2.0)); - build.inst(IrCmd::CONCAT, build.vmReg(0), build.vmReg(3)); // Concat invalidates more than the target register + build.inst(IrCmd::CONCAT, build.vmReg(0), build.constUint(3)); - build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(5), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber STORE_INT R1, 10i STORE_DOUBLE R2, 0.5 - CONCAT R0, R3 - %4 = LOAD_TAG R0 - STORE_TAG R3, %4 - %6 = LOAD_INT R1 - STORE_INT R4, %6 - %8 = LOAD_DOUBLE R2 - STORE_DOUBLE R5, %8 - LOP_RETURN 0u + STORE_DOUBLE R3, 2 + CONCAT R0, 3u + %5 = LOAD_TAG R0 + STORE_TAG R4, %5 + %7 = LOAD_INT R1 + STORE_INT R5, %7 + %9 = LOAD_DOUBLE R2 + STORE_DOUBLE R6, %9 + STORE_DOUBLE R7, 2 + RETURN 0u )"); } @@ -813,61 +1163,44 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(LBF_SETMETATABLE), build.vmReg(1), build.vmReg(2), build.vmReg(3), build.constInt(3), - build.constInt(1)); + build.inst( + IrCmd::INVOKE_FASTCALL, + build.constUint(LBF_SETMETATABLE), + build.vmReg(1), + build.vmReg(2), + build.vmReg(3), + build.undef(), + build.constInt(3), + build.constInt(1) + ); build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0))); // At least R0 wasn't touched - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_DOUBLE R0, 0.5 %1 = LOAD_POINTER R0 CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 - %4 = INVOKE_FASTCALL 61u, R1, R2, R3, 3i, 1i + %4 = INVOKE_FASTCALL 61u, R1, R2, R3, undef, 3i, 1i CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 STORE_DOUBLE R1, 0.5 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u - -)"); -} - -TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") -{ - IrOp block = build.block(IrBlockKind::Internal); - - build.beginBlock(block); - - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - - updateUseCounts(build.function); - constPropInBlockChains(build); - - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( -bb_0: - STORE_INT R0, 10i - STORE_DOUBLE R0, 0.5 - STORE_INT R0, 10i - LOP_RETURN 0u + RETURN 1u )"); } @@ -882,17 +1215,17 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_INT R0, 10i STORE_DOUBLE R0, 0.5 STORE_INT R0, 10i - LOP_RETURN 0u + RETURN 0u )"); } @@ -909,22 +1242,22 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -941,22 +1274,22 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnil), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 JUMP bb_fallback_1 bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -974,28 +1307,28 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(1), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(3)); + build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u bb_fallback_3: - LOP_RETURN 3u + RETURN 3u )"); } @@ -1013,28 +1346,28 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(1), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(3)); + build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u bb_fallback_3: - LOP_RETURN 3u + RETURN 3u )"); } @@ -1051,22 +1384,22 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") build.inst(IrCmd::JUMP_EQ_TAG, tag, build.constTag(tnumber), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tboolean JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u )"); } @@ -1078,26 +1411,26 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") IrOp falseBlock = build.block(IrBlockKind::Internal); build.beginBlock(block); - IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5)); - build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock); + IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); + build.inst(IrCmd::JUMP_CMP_INT, value, build.constInt(5), build.cond(IrCondition::Equal), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_INT R1, 5i JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -1109,26 +1442,26 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") IrOp falseBlock = build.block(IrBlockKind::Internal); build.beginBlock(block); - IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(4.0)); + IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); build.inst(IrCmd::JUMP_CMP_NUM, value, build.constDouble(8.0), build.cond(IrCondition::Greater), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_DOUBLE R1, 4 JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u )"); } @@ -1145,19 +1478,19 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor build.beginBlock(block2); build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 bb_1: STORE_TAG R1, tnumber - LOP_RETURN 1u + RETURN 1u )"); } @@ -1175,15 +1508,15 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique build.beginBlock(block2); build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(block3); build.inst(IrCmd::JUMP, block2); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1191,7 +1524,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique bb_1: %2 = LOAD_TAG R0 STORE_TAG R1, %2 - LOP_RETURN 1u + RETURN 1u bb_2: JUMP bb_1 @@ -1199,74 +1532,263 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique )"); } -TEST_SUITE_END(); +TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); -TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); -TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") -{ - IrOp block1 = build.block(IrBlockKind::Internal); - IrOp fallback1 = build.block(IrBlockKind::Fallback); - IrOp block2 = build.block(IrBlockKind::Internal); - IrOp fallback2 = build.block(IrBlockKind::Fallback); - IrOp block3 = build.block(IrBlockKind::Internal); - IrOp block4 = build.block(IrBlockKind::Internal); + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); - build.beginBlock(block1); + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, entry); - IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); - build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(tnumber), fallback1); - build.inst(IrCmd::JUMP, block2); + updateUseCounts(build.function); + constPropInBlockChains(build, true); - build.beginBlock(fallback1); - build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); - build.inst(IrCmd::JUMP, block2); + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_TAG R0, tnumber + JUMP bb_1 - build.beginBlock(block2); - IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); - build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback2); - build.inst(IrCmd::JUMP, block3); +bb_1: + RETURN R0, 0i - build.beginBlock(fallback2); - build.inst(IrCmd::DO_LEN, build.vmReg(0), build.vmReg(2)); - build.inst(IrCmd::JUMP, block3); +)"); +} - build.beginBlock(block3); - build.inst(IrCmd::JUMP, block4); +TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp block = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); - build.beginBlock(block4); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.beginBlock(entry); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); + + build.beginBlock(block); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, block); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: - %0 = LOAD_TAG R2 - CHECK_TAG %0, tnumber, bb_fallback_1 - JUMP bb_linear_6 + RETURN R0, 0i -bb_fallback_1: - DO_LEN R1, R2 +bb_1: + STORE_TAG R0, tnumber JUMP bb_2 bb_2: - %5 = LOAD_TAG R2 - CHECK_TAG %5, tnumber, bb_fallback_3 - JUMP bb_4 + RETURN R0, 0i -bb_fallback_3: - DO_LEN R0, R2 - JUMP bb_4 +)"); +} -bb_4: +TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit1 = build.block(IrBlockKind::Internal); + IrOp block = build.block(IrBlockKind::Internal); + IrOp exit2 = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), block, exit1); + + build.beginBlock(exit1); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); + + build.beginBlock(block); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit2, repeat); + + build.beginBlock(exit2); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, block); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + JUMP bb_1 + +bb_1: + RETURN R0, 0i + +bb_2: + STORE_TAG R0, tnumber + JUMP bb_3 + +bb_3: + RETURN R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "IntNumIntPeepholes") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp i1 = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + IrOp u1 = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); + IrOp ni1 = build.inst(IrCmd::INT_TO_NUM, i1); + IrOp nu1 = build.inst(IrCmd::UINT_TO_NUM, u1); + IrOp i2 = build.inst(IrCmd::NUM_TO_INT, ni1); + IrOp u2 = build.inst(IrCmd::NUM_TO_UINT, nu1); + build.inst(IrCmd::STORE_INT, build.vmReg(0), i2); + build.inst(IrCmd::STORE_INT, build.vmReg(1), u2); + build.inst(IrCmd::RETURN, build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_INT R0 + %1 = LOAD_INT R1 + STORE_INT R0, %0 + STORE_INT R1, %1 + RETURN 2u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "InvalidateReglinkVersion") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tstring)); + IrOp tv2 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), tv2); + IrOp ft = build.inst(IrCmd::NEW_TABLE, build.constUint(0), build.constUint(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), ft); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); + IrOp tv1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(0), tv1); + IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + build.inst(IrCmd::CHECK_TAG, tag, build.constTag(ttable), fallback); + build.inst(IrCmd::RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_TAG R2, tstring + %1 = LOAD_TVALUE R2 + STORE_TVALUE R1, %1 + %3 = NEW_TABLE 0u, 0u + STORE_POINTER R2, %3 + STORE_TAG R2, ttable + STORE_TVALUE R0, %1 + %8 = LOAD_TAG R0 + CHECK_TAG %8, ttable, bb_fallback_1 + RETURN 0u + +bb_fallback_1: + RETURN 1u + +)"); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") +{ + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp fallback1 = build.block(IrBlockKind::Fallback); + IrOp block2 = build.block(IrBlockKind::Internal); + IrOp fallback2 = build.block(IrBlockKind::Fallback); + IrOp block3 = build.block(IrBlockKind::Internal); + IrOp block4 = build.block(IrBlockKind::Internal); + + build.beginBlock(block1); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(tnumber), fallback1); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(fallback1); + build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(block2); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback2); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(fallback2); + build.inst(IrCmd::DO_LEN, build.vmReg(0), build.vmReg(2)); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(block3); + build.inst(IrCmd::JUMP, block4); + + build.beginBlock(block4); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_TAG R2 + CHECK_TAG %0, tnumber, bb_fallback_1 + JUMP bb_linear_6 + +bb_fallback_1: + DO_LEN R1, R2 + JUMP bb_2 + +bb_2: + %5 = LOAD_TAG R2 + CHECK_TAG %5, tnumber, bb_fallback_3 + JUMP bb_4 + +bb_fallback_3: + DO_LEN R0, R2 + JUMP bb_4 + +bb_4: JUMP bb_5 bb_5: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i bb_linear_6: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i )"); } @@ -1306,16 +1828,17 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" build.beginBlock(block4a); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block4b); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: %0 = LOAD_TAG R2 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -1340,11 +1863,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" bb_5: STORE_TAG R0, %10 - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i bb_6: STORE_TAG R0, %10 - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i )"); } @@ -1364,9 +1887,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") build.inst(IrCmd::JUMP, block2); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1378,4 +1902,2552 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "PartialStoreInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); // Should be reloaded + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_TVALUE R0 + STORE_TVALUE R1, %0 + STORE_DOUBLE R0, 0.5 + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 + STORE_TAG R0, tnumber + STORE_SPLIT_TVALUE R1, tnumber, 0.5 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VaridicRegisterRangeInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnumber)); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_TAG R2, tnumber + FALLBACK_GETVARARGS 0u, R1, -1i + STORE_TAG R2, tnumber + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "LoadPropagatesOnlyRightType") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(2)); + IrOp value1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), value1); + IrOp value2 = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), value2); + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R0, 2i + %1 = LOAD_DOUBLE R0 + STORE_DOUBLE R1, %1 + %3 = LOAD_INT R1 + STORE_INT R2, %3 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateHashSlotChecks") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t.a + t.a' + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp slot1 = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(3), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1, build.vmConst(1), fallback); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, slot1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + IrOp slot1b = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(8), build.vmConst(1)); // This will be removed + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1b, build.vmConst(1), fallback); // Key will be replaced with undef here + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, slot1b, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + // In the future, we might even see duplicate identical TValue loads go away + // In the future, we might even see loads of different VM regs with the same value go away + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = GET_SLOT_NODE_ADDR %0, 3u, K1 + CHECK_SLOT_MATCH %1, K1, bb_fallback_1 + %3 = LOAD_TVALUE %1, 0i + STORE_TVALUE R3, %3 + CHECK_NODE_VALUE %1, bb_fallback_1 + %7 = LOAD_TVALUE %1, 0i + STORE_TVALUE R4, %7 + %9 = LOAD_DOUBLE R3 + %10 = LOAD_DOUBLE R4 + %11 = ADD_NUM %9, %10 + STORE_DOUBLE R2, %11 + RETURN R2, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateHashSlotChecksAvoidNil") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp slot1 = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(3), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1, build.vmConst(1), fallback); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, slot1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + IrOp table2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + IrOp slot2 = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table2, build.constUint(6), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot2, build.vmConst(1), fallback); + build.inst(IrCmd::CHECK_READONLY, table2, fallback); + + build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.constTag(tnil)); + IrOp valueNil = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(4)); + build.inst(IrCmd::STORE_TVALUE, slot2, valueNil, build.constInt(0)); + + // In the future, we might get to track that value became 'nil' and that fallback will be taken + IrOp slot1b = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(8), build.vmConst(1)); // This will be removed + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1b, build.vmConst(1), fallback); // Key will be replaced with undef here + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, slot1b, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1b); + + IrOp slot2b = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table2, build.constUint(11), build.vmConst(1)); // This will be removed + build.inst(IrCmd::CHECK_SLOT_MATCH, slot2b, build.vmConst(1), fallback); // Key will be replaced with undef here + build.inst(IrCmd::CHECK_READONLY, table2, fallback); + + build.inst(IrCmd::STORE_SPLIT_TVALUE, slot2b, build.constTag(tnumber), build.constDouble(1), build.constInt(0)); + + build.inst(IrCmd::RETURN, build.vmReg(3), build.constUint(2)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = GET_SLOT_NODE_ADDR %0, 3u, K1 + CHECK_SLOT_MATCH %1, K1, bb_fallback_1 + %3 = LOAD_TVALUE %1, 0i + STORE_TVALUE R3, %3 + %5 = LOAD_POINTER R2 + %6 = GET_SLOT_NODE_ADDR %5, 6u, K1 + CHECK_SLOT_MATCH %6, K1, bb_fallback_1 + CHECK_READONLY %5, bb_fallback_1 + STORE_TAG R4, tnil + %10 = LOAD_TVALUE R4 + STORE_TVALUE %6, %10, 0i + CHECK_NODE_VALUE %1, bb_fallback_1 + %14 = LOAD_TVALUE %1, 0i + STORE_TVALUE R3, %14 + CHECK_NODE_VALUE %6, bb_fallback_1 + STORE_SPLIT_TVALUE %6, tnumber, 1, 0i + RETURN R3, 2u + +bb_fallback_1: + RETURN R1, 2u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateHashSlotChecksInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t.a + t.a' with a stange GC assist in the middle + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp slot1 = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(3), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1, build.vmConst(1), fallback); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, slot1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + build.inst(IrCmd::CHECK_GC); + + IrOp slot1b = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(8), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1b, build.vmConst(1), fallback); + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, slot1b, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + // In the future, we might even see duplicate identical TValue loads go away + // In the future, we might even see loads of different VM regs with the same value go away + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = GET_SLOT_NODE_ADDR %0, 3u, K1 + CHECK_SLOT_MATCH %1, K1, bb_fallback_1 + %3 = LOAD_TVALUE %1, 0i + STORE_TVALUE R3, %3 + CHECK_GC + %6 = GET_SLOT_NODE_ADDR %0, 8u, K1 + CHECK_SLOT_MATCH %6, K1, bb_fallback_1 + %8 = LOAD_TVALUE %6, 0i + STORE_TVALUE R4, %8 + %10 = LOAD_DOUBLE R3 + %11 = LOAD_DOUBLE R4 + %12 = ADD_NUM %10, %11 + STORE_DOUBLE R2, %12 + RETURN R2, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksSameIndex") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t[1] + t[1]' + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(0), fallback); + IrOp elem1 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, elem1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(0), fallback); // This will be removed + IrOp elem2 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); // And this will be substituted + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, elem2, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + // In the future, we might even see duplicate identical TValue loads go away + // In the future, we might even see loads of different VM regs with the same value go away + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + CHECK_ARRAY_SIZE %0, 0i, bb_fallback_1 + %2 = GET_ARR_ADDR %0, 0i + %3 = LOAD_TVALUE %2, 0i + STORE_TVALUE R3, %3 + %7 = LOAD_TVALUE %2, 0i + STORE_TVALUE R4, %7 + %9 = LOAD_DOUBLE R3 + %10 = LOAD_DOUBLE R4 + %11 = ADD_NUM %9, %10 + STORE_DOUBLE R2, %11 + RETURN R2, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksSameValue") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t[i] + t[i]' + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp index = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + IrOp validIndex = build.inst(IrCmd::TRY_NUM_TO_INDEX, index, fallback); + IrOp validOffset = build.inst(IrCmd::SUB_INT, validIndex, build.constInt(1)); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, validOffset, fallback); + IrOp elem1 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, elem1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + IrOp validIndex2 = build.inst(IrCmd::TRY_NUM_TO_INDEX, index, fallback); + IrOp validOffset2 = build.inst(IrCmd::SUB_INT, validIndex2, build.constInt(1)); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, validOffset2, fallback); // This will be removed + IrOp elem2 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); // And this will be substituted + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, elem2, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + // In the future, we might even see duplicate identical TValue loads go away + // In the future, we might even see loads of different VM regs with the same value go away + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = LOAD_DOUBLE R2 + %2 = TRY_NUM_TO_INDEX %1, bb_fallback_1 + %3 = SUB_INT %2, 1i + CHECK_ARRAY_SIZE %0, %3, bb_fallback_1 + %5 = GET_ARR_ADDR %0, 0i + %6 = LOAD_TVALUE %5, 0i + STORE_TVALUE R3, %6 + %12 = LOAD_TVALUE %5, 0i + STORE_TVALUE R4, %12 + %14 = LOAD_DOUBLE R3 + %15 = LOAD_DOUBLE R4 + %16 = ADD_NUM %14, %15 + STORE_DOUBLE R2, %16 + RETURN R2, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksLowerIndex") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t[2] + t[1]' + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(1), fallback); + IrOp elem1 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(1)); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, elem1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(0), fallback); // This will be removed + IrOp elem2 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, elem2, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + CHECK_ARRAY_SIZE %0, 1i, bb_fallback_1 + %2 = GET_ARR_ADDR %0, 1i + %3 = LOAD_TVALUE %2, 0i + STORE_TVALUE R3, %3 + %6 = GET_ARR_ADDR %0, 0i + %7 = LOAD_TVALUE %6, 0i + STORE_TVALUE R4, %7 + %9 = LOAD_DOUBLE R3 + %10 = LOAD_DOUBLE R4 + %11 = ADD_NUM %9, %10 + STORE_DOUBLE R2, %11 + RETURN R2, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksInvalidations") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t[1] + t[1]' with a strange table.insert in the middle + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(0), fallback); + IrOp elem1 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, elem1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + build.inst(IrCmd::TABLE_SETNUM, table1, build.constInt(2)); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(0), fallback); // This will be removed + IrOp elem2 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); // And this will be substituted + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, elem2, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + CHECK_ARRAY_SIZE %0, 0i, bb_fallback_1 + %2 = GET_ARR_ADDR %0, 0i + %3 = LOAD_TVALUE %2, 0i + STORE_TVALUE R3, %3 + %5 = TABLE_SETNUM %0, 2i + CHECK_ARRAY_SIZE %0, 0i, bb_fallback_1 + %7 = GET_ARR_ADDR %0, 0i + %8 = LOAD_TVALUE %7, 0i + STORE_TVALUE R4, %8 + %10 = LOAD_DOUBLE R3 + %11 = LOAD_DOUBLE R4 + %12 = ADD_NUM %10, %11 + STORE_DOUBLE R2, %12 + RETURN R2, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ArrayElemChecksNegativeIndex") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t[1] + t[0]' + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(0), fallback); + IrOp elem1 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, elem1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(-1), fallback); // This will jump directly to fallback + IrOp elem2 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(-1)); + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, elem2, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R1 + CHECK_ARRAY_SIZE %0, 0i, bb_fallback_1 + %2 = GET_ARR_ADDR %0, 0i + %3 = LOAD_TVALUE %2, 0i + STORE_TVALUE R3, %3 + JUMP bb_fallback_1 + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateBufferLengthChecks") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp sourceBuf = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), sourceBuf); + IrOp buffer1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + build.inst(IrCmd::CHECK_BUFFER_LEN, buffer1, build.constInt(12), build.constInt(4), fallback); + build.inst(IrCmd::BUFFER_WRITEI32, buffer1, build.constInt(12), build.constInt(32)); + + // Now with lower index, should be removed + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), sourceBuf); + IrOp buffer2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + build.inst(IrCmd::CHECK_BUFFER_LEN, buffer2, build.constInt(8), build.constInt(4), fallback); + build.inst(IrCmd::BUFFER_WRITEI32, buffer2, build.constInt(8), build.constInt(30)); + + // Now with higher index, should raise the initial check bound + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), sourceBuf); + IrOp buffer3 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + build.inst(IrCmd::CHECK_BUFFER_LEN, buffer3, build.constInt(16), build.constInt(4), fallback); + build.inst(IrCmd::BUFFER_WRITEI32, buffer3, build.constInt(16), build.constInt(60)); + + // Now with different access size, should not reuse previous checks (can be improved in the future) + build.inst(IrCmd::CHECK_BUFFER_LEN, buffer3, build.constInt(16), build.constInt(2), fallback); + build.inst(IrCmd::BUFFER_WRITEI16, buffer3, build.constInt(16), build.constInt(55)); + + // Now with same, but unknown index value + IrOp index = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); + build.inst(IrCmd::CHECK_BUFFER_LEN, buffer3, index, build.constInt(2), fallback); + build.inst(IrCmd::BUFFER_WRITEI16, buffer3, index, build.constInt(1)); + build.inst(IrCmd::CHECK_BUFFER_LEN, buffer3, index, build.constInt(2), fallback); + build.inst(IrCmd::BUFFER_WRITEI16, buffer3, index, build.constInt(2)); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_TVALUE R0 + STORE_TVALUE R2, %0 + %2 = LOAD_POINTER R2 + CHECK_BUFFER_LEN %2, 16i, 4i, bb_fallback_1 + BUFFER_WRITEI32 %2, 12i, 32i + BUFFER_WRITEI32 %2, 8i, 30i + BUFFER_WRITEI32 %2, 16i, 60i + CHECK_BUFFER_LEN %2, 16i, 2i, bb_fallback_1 + BUFFER_WRITEI16 %2, 16i, 55i + %15 = LOAD_INT R1 + CHECK_BUFFER_LEN %2, %15, 2i, bb_fallback_1 + BUFFER_WRITEI16 %2, %15, 1i + BUFFER_WRITEI16 %2, %15, 2i + RETURN R1, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "BufferLenghtChecksNegativeIndex") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp sourceBuf = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), sourceBuf); + IrOp buffer1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + build.inst(IrCmd::CHECK_BUFFER_LEN, buffer1, build.constInt(-4), build.constInt(4), fallback); + build.inst(IrCmd::BUFFER_WRITEI32, buffer1, build.constInt(-4), build.constInt(32)); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_TVALUE R0 + STORE_TVALUE R2, %0 + JUMP bb_fallback_1 + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TagVectorSkipErrorFix") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp a = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + IrOp b = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1)); + + IrOp mul = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::MUL_VEC, a, b)); + + IrOp t1 = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::ADD_VEC, mul, mul)); + IrOp t2 = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::SUB_VEC, mul, mul)); + + IrOp t3 = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::DIV_VEC, t1, build.inst(IrCmd::UNM_VEC, t2))); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(0), t3); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::Yes) == R"( +bb_0: ; useCount: 0 + %0 = LOAD_TVALUE R0 ; useCount: 1, lastUse: %0 + %1 = LOAD_TVALUE R1 ; useCount: 1, lastUse: %0 + %2 = MUL_VEC %0, %1 ; useCount: 4, lastUse: %0 + %4 = ADD_VEC %2, %2 ; useCount: 1, lastUse: %0 + %6 = SUB_VEC %2, %2 ; useCount: 1, lastUse: %0 + %8 = UNM_VEC %6 ; useCount: 1, lastUse: %0 + %9 = DIV_VEC %4, %8 ; useCount: 1, lastUse: %0 + %10 = TAG_VECTOR %9 ; useCount: 1, lastUse: %0 + STORE_TVALUE R0, %10 ; %11 + RETURN R0, 1u ; %12 + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp followup = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp tbl = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::CHECK_READONLY, tbl, build.vmExit(1)); + + build.inst(IrCmd::FALLBACK_FORGPREP, build.constUint(2), build.vmReg(1), followup); + + build.beginBlock(followup); + build.inst(IrCmd::CHECK_READONLY, tbl, build.vmExit(2)); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(3)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1, R2, R3 +; out regs: R1, R2, R3 + %0 = LOAD_POINTER R0 + CHECK_READONLY %0, exit(1) + FALLBACK_FORGPREP 2u, R1, bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R1, R2, R3 + CHECK_READONLY %0, exit(2) + RETURN R1, 3i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects1") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_FREXP), build.vmReg(1), build.vmReg(2), build.constInt(2)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(2)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R2 + FASTCALL 14u, R1, R2, 2i + RETURN R1, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_MODF), build.vmReg(1), build.vmReg(2), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(2)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R2 + FASTCALL 20u, R1, R2, 1i + %3 = LOAD_TAG R2 + CHECK_TAG %3, tnumber, exit(1) + RETURN R1, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "InferNumberTagFromLimitedContext") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(ttable), build.vmExit(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_DOUBLE R0, 2 + JUMP exit(1) + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore1") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(ttable), build.vmExit(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R0, 1i + %1 = LOAD_TAG R0 + CHECK_TAG %1, ttable, exit(1) + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R0, 1i + %1 = LOAD_TAG R0 + CHECK_TAG %1, tnumber, exit(1) + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Analysis"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::JUMP_EQ_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), a, b); + + build.beginBlock(a); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(b); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0, R1, R2, R3 +; out regs: R1, R2, R3 + %0 = LOAD_TAG R0 + JUMP_EQ_TAG %0, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R3 +; out regs: R2, R3 + %2 = LOAD_TVALUE R1 + STORE_TVALUE R2, %2 + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R2 +; out regs: R2, R3 + %5 = LOAD_TVALUE R1 + STORE_TVALUE R3, %5 + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R2, R3 + RETURN R2, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(5)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(5)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1, R2 +; out regs: R0, R1, R2, R3, R4 + FALLBACK_GETVARARGS 0u, R3, -1i + CALL R0, -1i, 5i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0, R1, R2, R3, R4 + RETURN R0, 5i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + IrOp results = build.inst( + IrCmd::INVOKE_FASTCALL, + build.constUint(0), + build.vmReg(0), + build.vmReg(1), + build.vmReg(2), + build.undef(), + build.constInt(-1), + build.constInt(-1) + ); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(0), results); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 +; out regs: R0... + FALLBACK_GETVARARGS 0u, R1, -1i + %1 = INVOKE_FASTCALL 0u, R0, R1, R2, undef, -1i, -1i + ADJUST_STACK_TO_REG R0, %1 + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0... + RETURN R0, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(0), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1 +; out regs: R0... + CALL R1, 0i, -1i + CALL R0, -1i, -1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0... + RETURN R0, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(fallback); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0... + FALLBACK_GETVARARGS 0u, R1, -1i + %1 = LOAD_TAG R0 + CHECK_TAG %1, tnumber, bb_fallback_1 + CALL R0, -1i, -1i + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1... +; out regs: R0... + CALL R0, -1i, -1i + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0... + RETURN R0, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); + build.inst(IrCmd::JUMP_EQ_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), a, b); + + build.beginBlock(a); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(b); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0, R1 +; out regs: R0, R1, R3... + FALLBACK_GETVARARGS 0u, R3, -1i + %1 = LOAD_TAG R0 + JUMP_EQ_TAG %1, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R0, R3... +; out regs: R2... + %3 = LOAD_TVALUE R0 + STORE_TVALUE R2, %3 + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R3... +; out regs: R2... + %6 = LOAD_TVALUE R1 + STORE_TVALUE R2, %6 + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R2... + RETURN R2, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinVariadicStart") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(2.0)); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(2), build.constInt(1)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(-1), build.constInt(1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 +; in regs: R0 +; out regs: R0, R1 + STORE_DOUBLE R1, 1 + STORE_DOUBLE R2, 2 + ADJUST_STACK_TO_REG R2, 1i + CALL R1, -1i, 1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepImplicitUse") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp direct = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(10.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(1.0)); + IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + build.inst(IrCmd::JUMP_EQ_TAG, tag, build.constTag(tnumber), direct, fallback); + + build.beginBlock(direct); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::FALLBACK_FORGPREP, build.constUint(0), build.vmReg(1), exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(3)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0 +; out regs: R0, R1, R2, R3 + STORE_DOUBLE R1, 1 + STORE_DOUBLE R2, 10 + STORE_DOUBLE R3, 1 + %3 = LOAD_TAG R0 + JUMP_EQ_TAG %3, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; in regs: R0 + RETURN R0, 1i + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R2, R3 +; out regs: R1, R2, R3 + FALLBACK_FORGPREP 0u, R1, bb_3 + +bb_3: +; predecessors: bb_2 +; in regs: R1, R2, R3 + RETURN R1, 3i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::SET_TABLE, build.vmReg(0), build.vmReg(1), build.constUint(1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R1 + SET_TABLE R0, R1, 1u + RETURN R0, 1i + +)"); +} + +// 'A Simple, Fast Dominance Algorithm' [Keith D. Cooper, et al]. Figure 2. +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification1") +{ + defineCfgTree({{1, 2}, {3}, {4}, {4}, {3}}); + + CHECK(build.function.cfg.idoms == std::vector{{~0u, 0, 0, 0, 0}}); +} + +// 'A Linear Time Algorithm for Placing Phi-Nodes' [Vugranam C.Sreedhar]. Figure 1. +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification2") +{ + defineCfgTree({{1, 16}, {2, 3, 4}, {4, 7}, {9}, {5}, {6}, {2, 8}, {8}, {7, 15}, {10, 11}, {12}, {12}, {13}, {3, 14, 15}, {12}, {16}, {}}); + + CHECK(build.function.cfg.idoms == std::vector{~0u, 0, 1, 1, 1, 4, 5, 1, 1, 3, 9, 9, 9, 12, 13, 1, 0}); +} + +// 'A Linear Time Algorithm for Placing Phi-Nodes' [Vugranam C.Sreedhar]. Figure 4. +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification3") +{ + defineCfgTree({{1, 2}, {3}, {3, 4}, {5}, {5, 6}, {7}, {7}, {}}); + + CHECK(build.function.cfg.idoms == std::vector{~0u, 0, 0, 0, 2, 0, 4, 0}); +} + +// 'Static Single Assignment Book' Figure 4.1 +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification4") +{ + defineCfgTree({{1}, {2, 10}, {3, 7}, {4}, {5}, {4, 6}, {1}, {8}, {5, 9}, {7}, {}}); + + IdfContext ctx; + + computeIteratedDominanceFrontierForDefs(ctx, build.function, {0, 2, 3, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + CHECK(ctx.idf == std::vector{1, 4, 5}); +} + +// 'Static Single Assignment Book' Figure 4.5 +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification4") +{ + defineCfgTree({{1}, {2}, {3, 7}, {4, 5}, {6}, {6}, {8}, {8}, {9}, {10, 11}, {11}, {9, 12}, {2}}); + + IdfContext ctx; + + computeIteratedDominanceFrontierForDefs(ctx, build.function, {4, 5, 7, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + CHECK(ctx.idf == std::vector{2, 6, 8}); + + // Pruned form, when variable is only live-in in limited set of blocks + computeIteratedDominanceFrontierForDefs(ctx, build.function, {4, 5, 7, 12}, {6, 8, 9}); + CHECK(ctx.idf == std::vector{6, 8}); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ValueNumbering"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "RemoveDuplicateCalculation") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::UNM_NUM, op1); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op2); + IrOp op3 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); // Load propagation is tested here + IrOp op4 = build.inst(IrCmd::UNM_NUM, op3); // And allows value numbering to trigger here + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), op4); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = UNM_NUM %0 + STORE_DOUBLE R1, %1 + STORE_DOUBLE R2, %1 + RETURN R1, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "LateTableStateLink") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp tmp = build.inst(IrCmd::DUP_TABLE, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), tmp); // Late tmp -> R0 link is tested here + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); // Store to load propagation test + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = DUP_TABLE R0 + STORE_POINTER R0, %0 + CHECK_NO_METATABLE %0, bb_fallback_1 + CHECK_READONLY %0, bb_fallback_1 + RETURN 0u + +bb_fallback_1: + RETURN 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RegisterVersioning") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::UNM_NUM, op1); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), op2); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); // Doesn't prevent previous store propagation + IrOp op3 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); // No longer 'op1' + IrOp op4 = build.inst(IrCmd::UNM_NUM, op3); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op4); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = UNM_NUM %0 + STORE_DOUBLE R0, %1 + STORE_TAG R0, tnumber + %5 = UNM_NUM %1 + STORE_DOUBLE R1, %5 + RETURN R0, 2i + +)"); +} + +// This can be relaxed in the future when SETLIST becomes aware of register allocator +TEST_CASE_FIXTURE(IrBuilderFixture, "SetListIsABlocker") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::SETLIST); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), sum); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + SETLIST + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %0, %2 + STORE_DOUBLE R0, %3 + RETURN R0, 1i + +)"); +} + +// Luau call will reuse the same stack and spills will be lost +// However, in the future we might propagate values that can be rematerialized +TEST_CASE_FIXTURE(IrBuilderFixture, "CallIsABlocker") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(1), build.vmReg(2), build.constInt(1)); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + CALL R1, 1i, R2, 1i + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %0, %2 + STORE_DOUBLE R1, %3 + RETURN R1, 2i + +)"); +} + +// While constant propagation correctly versions captured registers, IrValueLocationTracking doesn't (yet) +TEST_CASE_FIXTURE(IrBuilderFixture, "NoPropagationOfCapturedRegs") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CAPTURE, build.vmReg(0), build.constUint(1)); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +; captured regs: R0 + +bb_0: +; in regs: R0 + CAPTURE R0, 1u + %1 = LOAD_DOUBLE R0 + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %1, %2 + STORE_DOUBLE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NoDeadLoadReuse") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op1i = build.inst(IrCmd::NUM_TO_INT, op1); + IrOp res = build.inst(IrCmd::BITAND_UINT, op1i, build.constInt(0)); + IrOp resd = build.inst(IrCmd::INT_TO_NUM, res); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, resd, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %4 = LOAD_DOUBLE R0 + %5 = ADD_NUM 0, %4 + STORE_DOUBLE R1, %5 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NoDeadValueReuse") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op1i = build.inst(IrCmd::NUM_TO_INT, op1); + IrOp res = build.inst(IrCmd::BITAND_UINT, op1i, build.constInt(0)); + IrOp op2i = build.inst(IrCmd::NUM_TO_INT, op1); + IrOp sum = build.inst(IrCmd::ADD_INT, res, op2i); + IrOp resd = build.inst(IrCmd::INT_TO_NUM, sum); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), resd); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %3 = NUM_TO_INT %0 + %4 = ADD_INT 0i, %3 + %5 = INT_TO_NUM %4 + STORE_DOUBLE R1, %5 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TValueLoadToSplitStore") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op1v2 = build.inst(IrCmd::ADD_NUM, op1, build.constDouble(4.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op1v2); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + + // Check that this TValue store will be replaced by a split store + IrOp tv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), tv); + + // Check that tag and value can be extracted from R2 now (removing the fallback) + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), op2); + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = ADD_NUM %0, 4 + STORE_DOUBLE R1, %1 + STORE_TAG R1, tnumber + STORE_SPLIT_TVALUE R2, tnumber, %1 + STORE_DOUBLE R3, %1 + STORE_TAG R3, tnumber + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TagStoreUpdatesValueVersion") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + + IrOp op1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), op1); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tstring)); + + IrOp str = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), str); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tstring)); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_POINTER R0 + STORE_POINTER R1, %0 + STORE_TAG R1, tstring + STORE_POINTER R2, %0 + STORE_TAG R2, tstring + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TagStoreUpdatesSetUpval") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); + + build.inst(IrCmd::SET_UPVALUE, build.vmUpvalue(0), build.vmReg(0), build.undef()); + + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_TAG R0, tnumber + STORE_DOUBLE R0, 0.5 + SET_UPVALUE U0, R0, tnumber + RETURN R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TagSelfEqualityCheckRemoval") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + build.inst(IrCmd::JUMP_EQ_TAG, tag1, tag2, trueBlock, falseBlock); + + build.beginBlock(trueBlock); + build.inst(IrCmd::RETURN, build.constUint(1)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::RETURN, build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + JUMP bb_1 + +bb_1: + RETURN 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TaggedValuePropagationIntoTvalueChecksRegisterVersion") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp a1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp b1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); + IrOp sum1 = build.inst(IrCmd::ADD_NUM, a1, b1); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), sum1); + build.inst(IrCmd::STORE_TAG, build.vmReg(7), build.constTag(tnumber)); + + IrOp a2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + IrOp b2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp sum2 = build.inst(IrCmd::ADD_NUM, a2, b2); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(8), sum2); + build.inst(IrCmd::STORE_TAG, build.vmReg(8), build.constTag(tnumber)); + + IrOp old7 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(7), build.constInt(0), build.constTag(tnumber)); + IrOp old8 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(8), build.constInt(0), build.constTag(tnumber)); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(8), old7); // Invalidate R8 + build.inst(IrCmd::STORE_TVALUE, build.vmReg(9), old8); // Old R8 cannot be substituted as it was invalidated + + build.inst(IrCmd::RETURN, build.vmReg(8), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R1, R2, R3 + %0 = LOAD_DOUBLE R0 + %1 = LOAD_DOUBLE R1 + %2 = ADD_NUM %0, %1 + STORE_DOUBLE R7, %2 + STORE_TAG R7, tnumber + %5 = LOAD_DOUBLE R2 + %6 = LOAD_DOUBLE R3 + %7 = ADD_NUM %5, %6 + STORE_DOUBLE R8, %7 + STORE_TAG R8, tnumber + %11 = LOAD_TVALUE R8, 0i, tnumber + STORE_SPLIT_TVALUE R8, tnumber, %2 + STORE_TVALUE R9, %11 + RETURN R8, 2i + +)"); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("DeadStoreRemoval"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2.0)); // Should remove previous store + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnumber)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tboolean)); // Should remove previous store of different type + + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnil)); + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(4.0)); + + build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.constTag(tnil)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.constDouble(1.0)); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(4), build.constTag(tnumber), build.constDouble(2.0)); // Should remove two previous stores + + IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(5), build.constTag(tnil)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(5), someTv); // Should remove two previous stores + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(5)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0 + STORE_SPLIT_TVALUE R1, tnumber, 2 + STORE_SPLIT_TVALUE R2, tboolean, 4i + STORE_TAG R3, tnumber + STORE_DOUBLE R3, 4 + STORE_SPLIT_TVALUE R4, tnumber, 2 + %13 = LOAD_TVALUE R0 + STORE_TVALUE R5, %13 + RETURN R1, 5i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturn") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tboolean)); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(4), build.constTag(tnumber), build.constDouble(2.0)); + + IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(5), someTv); + + build.inst(IrCmd::STORE_TAG, build.vmReg(6), build.constTag(tnil)); + + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturnPartial") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Partial stores cannot be removed, even if unused + // Existance of an unpaired partial store means that the other valid part is a block live in (even if not present is this test) + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0 + STORE_DOUBLE R1, 1 + STORE_INT R2, 4i + STORE_TAG R3, tnumber + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse1") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp somePtr = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtr); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::CALL, build.vmReg(2), build.constInt(0), build.constInt(1)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R2 + %0 = LOAD_POINTER R0 + STORE_POINTER R1, %0 + STORE_TAG R1, ttable + CALL R2, 0i, 1i + RETURN R2, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp somePtrA = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrA); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::CALL, build.vmReg(2), build.constInt(0), build.constInt(1)); + IrOp somePtrB = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrB); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Stores to pointers can be safely removed at 'return' point, but have to preserved for any GC assist trigger (such as a call) + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R2 + %0 = LOAD_POINTER R0 + STORE_POINTER R1, %0 + STORE_TAG R1, ttable + CALL R2, 0i, 1i + RETURN R2, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse3") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp somePtrA = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrA); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), someTv); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Stores to pointers can be safely removed if there are no potential implicit uses by any GC assists + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R2 + %3 = LOAD_TVALUE R2 + STORE_TVALUE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse4") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // It is important for tag overwrite to TNIL to kill not only the previous tag store, but the value as well + // This is important in a following scenario: + // - R0 might have been a GCO on entry to bb_0 + // - R0 is overwritten by a number + // - Stack is visited by GC assist + // - R0 is overwritten by nil + // If only number tag write would have been killed, there will be a GCO tag with a double value on stack + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + CHECK_GC + STORE_TAG R0, tnil + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "PartialVsFullStoresWithRecombination") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(0), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 1 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(-1.0)); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(1), build.constInt(1)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + ADJUST_STACK_TO_REG R1, 1i + STORE_SPLIT_TVALUE R1, tnumber, 1 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp next = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Even though bb_0 doesn't have R1 as a live out, chain optimization used the knowledge of those writes happening to optimize duplicate stores + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 + STORE_TAG R1, tnumber + STORE_DOUBLE R1, 1 + JUMP bb_1 + +bb_1: +; predecessors: bb_0 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "KeepCapturedRegisterStores") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CAPTURE, build.vmReg(1), build.constUint(1)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::DO_ARITH, build.vmReg(0), build.vmReg(2), build.vmReg(3), build.constInt(0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(-1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::DO_ARITH, build.vmReg(1), build.vmReg(4), build.vmReg(5), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Captured registers may be modified from called user functions (plain or hidden in metamethods) + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +; captured regs: R1 + +bb_0: +; in regs: R1, R2, R3, R4, R5 + CAPTURE R1, 1u + STORE_DOUBLE R1, 1 + STORE_TAG R1, tnumber + DO_ARITH R0, R2, R3, 0i + STORE_DOUBLE R1, -1 + STORE_TAG R1, tnumber + DO_ARITH R1, R4, R5, 0i + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "StoreCannotBeReplacedWithCheck") +{ + ScopedFastFlag debugLuauAbortingChecks{FFlag::DebugLuauAbortingChecks, true}; + + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp ptr = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), ptr); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); + + build.inst(IrCmd::CHECK_READONLY, ptr, fallback); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), build.inst(IrCmd::LOAD_POINTER, build.vmReg(0))); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); + + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnil)); + + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + IrOp fallbackPtr = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), fallbackPtr); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(3)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0, R1 +; out regs: R0, R1, R2 + %0 = LOAD_POINTER R1 + CHECK_READONLY %0, bb_fallback_1 + STORE_TAG R2, tnil + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1 +; out regs: R0, R1, R2 + %9 = LOAD_POINTER R1 + STORE_POINTER R2, %9 + STORE_TAG R2, ttable + CHECK_GC + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1, R2 + RETURN R0, 3i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32))); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32))); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Even though R1 is not live in of the fallback, stack state cannot be left in a partial store state + // Either tag+pointer store should both remain before the guard, or they both have to be made after + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_SAFE_ENV bb_fallback_1 + %4 = NEW_TABLE 16u, 32u + STORE_SPLIT_TVALUE R1, ttable, %4 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_GC + STORE_SPLIT_TVALUE R1, tnumber, 1 + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); // Tag store unpaired to a visible value store + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2))); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // If table tag store at the start is removed, GC assists in the fallback can observe value with a wrong tag + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0, R2 +; out regs: R0, R1 + STORE_TAG R1, tnumber + CHECK_SAFE_ENV bb_fallback_1 + %2 = LOAD_TVALUE R2 + STORE_TVALUE R1, %2 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_GC + STORE_SPLIT_TVALUE R1, tnumber, 1 + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks3") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tfunction), fallback); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::LOAD_POINTER, build.vmConst(10))); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + // Tag check establishes that at that point, the tag of the value IS a function (as an exit here has to be with well-formed stack) + // Later additional function pointer store can be removed, even if it observable from the GC in the fallback + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_fallback_1, bb_2 +; in regs: R0, R1 +; out regs: R0, R1 + %0 = LOAD_TAG R1 + CHECK_TAG %0, tfunction, bb_fallback_1 + CHECK_SAFE_ENV bb_fallback_1 + STORE_DOUBLE R1, 1 + STORE_TAG R1, tnumber + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0, bb_0 +; successors: bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_GC + STORE_SPLIT_TVALUE R1, tnumber, 1 + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1)); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); // While R1 has to be observed in full by the fallback + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2)); // This partial store is safe to remove because number tag is established + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(3)); // And so is this + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(4)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // If table tag store at the start is removed, GC assists in the fallback can observe value with a wrong tag + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0, R1 + STORE_SPLIT_TVALUE R1, tnumber, 1 + CHECK_SAFE_ENV bb_fallback_1 + STORE_DOUBLE R1, 4 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1 +; out regs: R0, R1 + CHECK_GC + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1)); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); // While R1 has to be observed in full by the fallback + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2)); // This partial store is safe to remove because tag is established + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(4)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // If table tag store at the start is removed, GC assists in the fallback can observe value with a wrong tag + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0, R1 + STORE_SPLIT_TVALUE R1, tnumber, 1 + CHECK_SAFE_ENV bb_fallback_1 + STORE_SPLIT_TVALUE R1, tnumber, 4 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1 +; out regs: R0, R1 + CHECK_GC + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotReturnWithPartialStores") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp success = build.block(IrBlockKind::Internal); + IrOp fail = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::NEW_TABLE, build.constUint(0), build.constUint(0))); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + IrOp toUint = build.inst(IrCmd::NUM_TO_UINT, build.constDouble(-1)); + IrOp bitAnd = build.inst(IrCmd::BITAND_UINT, toUint, build.constInt(4)); + build.inst(IrCmd::JUMP_CMP_INT, bitAnd, build.constInt(0), build.cond(IrCondition::Equal), success, fail); + + build.beginBlock(success); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(0)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(fail); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tboolean)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Even though R1 is not live out at return, we stored table tag followed by an integer value + // Boolean tag store has to remain, even if unused, because all stack slots are visible to GC + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0 +; out regs: R0 + %0 = NEW_TABLE 0u, 0u + STORE_POINTER R1, %0 + STORE_TAG R1, ttable + %3 = NUM_TO_UINT -1 + %4 = BITAND_UINT %3, 4i + JUMP_CMP_INT %4, 0i, eq, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R0 +; out regs: R0 + STORE_INT R1, 0i + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R0 +; out regs: R0 + STORE_INT R1, 1i + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R0 + STORE_TAG R1, tboolean + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(4.0)); + build.inst( + IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(ttable), build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32)) + ); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), build.inst(IrCmd::NEW_TABLE, build.constUint(8), build.constUint(16))); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), build.inst(IrCmd::NEW_TABLE, build.constUint(4), build.constUint(8))); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tstring)); + IrOp newtable = build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(ttable)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), newtable); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %11 = NEW_TABLE 16u, 32u + STORE_SPLIT_TVALUE R0, ttable, %11 + RETURN R0, 1i + +)"); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Dump"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "ToDot") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::JUMP_EQ_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), a, b); + + build.beginBlock(a); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(b); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + // note: we don't validate the output of these to avoid test churn when formatting changes; we run these to make sure they don't assert/crash + toDot(build.function, /* includeInst= */ true); + toDotCfg(build.function); + toDotDjGraph(build.function); +} + TEST_SUITE_END(); diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp new file mode 100644 index 000000000..8336a634b --- /dev/null +++ b/tests/IrCallWrapperX64.test.cpp @@ -0,0 +1,562 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrCallWrapperX64.h" +#include "Luau/IrRegAllocX64.h" + +#include "doctest.h" + +using namespace Luau::CodeGen; +using namespace Luau::CodeGen::X64; + +class IrCallWrapperX64Fixture +{ +public: + IrCallWrapperX64Fixture(ABIX64 abi = ABIX64::Windows) + : build(/* logText */ true, abi) + , regs(build, function, nullptr) + , callWrap(regs, build, ~0u) + { + } + + void checkMatch(std::string expected) + { + regs.assertAllFree(); + + build.finalize(); + + CHECK("\n" + build.text == expected); + } + + AssemblyBuilderX64 build; + IrFunction function; + IrRegAllocX64 regs; + IrCallWrapperX64 callWrap; + + // Tests rely on these to force interference between registers + static constexpr RegisterX64 rArg1 = rcx; + static constexpr RegisterX64 rArg1d = ecx; + static constexpr RegisterX64 rArg2 = rdx; + static constexpr RegisterX64 rArg2d = edx; + static constexpr RegisterX64 rArg3 = r8; + static constexpr RegisterX64 rArg3d = r8d; + static constexpr RegisterX64 rArg4 = r9; + static constexpr RegisterX64 rArg4d = r9d; +}; + +class IrCallWrapperX64FixtureSystemV : public IrCallWrapperX64Fixture +{ +public: + IrCallWrapperX64FixtureSystemV() + : IrCallWrapperX64Fixture(ABIX64::SystemV) + { + } +}; + +TEST_SUITE_BEGIN("IrCallWrapperX64"); + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); // Already in its place + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1.reg); // Already in its place + callWrap.addArgument(SizeX64::qword, tmp1.release()); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,rcx + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg]); + callWrap.addArgument(SizeX64::qword, tmp1.release()); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,rcx + mov rcx,qword ptr [rcx] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleMemImm") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::dword, 32); + callWrap.addArgument(SizeX64::dword, -1); + callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + tmp2.release()]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov r8,qword ptr [r14+020h] + mov r9,qword ptr [rax+rsi] + mov ecx,20h + mov edx,FFFFFFFFh + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleStackArgs") +{ + ScopedRegX64 tmp{regs, regs.takeReg(rax, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.addArgument(SizeX64::qword, qword[r14 + 16]); + callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); + callWrap.addArgument(SizeX64::qword, qword[r14 + 48]); + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::qword, qword[r13]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,qword ptr [r13] + mov qword ptr [rsp+028h],rdx + mov rcx,rax + mov rdx,qword ptr [r14+010h] + mov r8,qword ptr [r14+020h] + mov r9,qword ptr [r14+030h] + mov dword ptr [rsp+020h],1 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FixedRegisters") +{ + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::qword, 2); + callWrap.addArgument(SizeX64::qword, 3); + callWrap.addArgument(SizeX64::qword, 4); + callWrap.addArgument(SizeX64::qword, r14); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov qword ptr [rsp+020h],r14 + mov ecx,1 + mov rdx,2 + mov r8,3 + mov r9,4 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "EasyInterference") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rdi, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.addArgument(SizeX64::qword, tmp3); + callWrap.addArgument(SizeX64::qword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov r8,rdx + mov rdx,rsi + mov r9,rcx + mov rcx,rdi + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeInterference") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + 8]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,qword ptr [rcx+8] + mov rdx,qword ptr [rdx+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.addArgument(SizeX64::qword, tmp3); + callWrap.addArgument(SizeX64::qword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,r9 + mov r9,rcx + mov rcx,rax + mov rax,r8 + mov r8,rdx + mov rdx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4d, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3d, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2d, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1d, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::dword, tmp1); + callWrap.addArgument(SizeX64::dword, tmp2); + callWrap.addArgument(SizeX64::dword, tmp3); + callWrap.addArgument(SizeX64::dword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov eax,r9d + mov r9d,ecx + mov ecx,eax + mov eax,r8d + mov r8d,edx + mov edx,eax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceFp") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(xmm1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(xmm0, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::xmmword, tmp1); + callWrap.addArgument(SizeX64::xmmword, tmp2); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm2,xmm1,xmm1 + vmovsd xmm1,xmm0,xmm0 + vmovsd xmm0,xmm2,xmm2 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceBoth") +{ + ScopedRegX64 int1{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 int2{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 fp1{regs, regs.takeReg(xmm3, kInvalidInstIdx)}; + ScopedRegX64 fp2{regs, regs.takeReg(xmm2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, int1); + callWrap.addArgument(SizeX64::qword, int2); + callWrap.addArgument(SizeX64::xmmword, fp1); + callWrap.addArgument(SizeX64::xmmword, fp2); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rdx + mov rdx,rcx + mov rcx,rax + vmovsd xmm0,xmm3,xmm3 + vmovsd xmm3,xmm2,xmm2 + vmovsd xmm2,xmm0,xmm0 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeMultiuseInterferenceMem") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,qword ptr [rcx+rdx+8] + mov rdx,qword ptr [rdx+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+rdx+8] + mov rdx,qword ptr [rax+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+rdx+8] + mov rdx,qword ptr [rax+rdx+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem3") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg3, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + tmp3.reg + 16]); + callWrap.addArgument(SizeX64::qword, qword[tmp3.reg + tmp1.reg + 16]); + tmp1.release(); + tmp2.release(); + tmp3.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,r8 + mov r8,qword ptr [rcx+rax+010h] + mov rbx,rdx + mov rdx,qword ptr [rbx+rcx+010h] + mov rcx,qword ptr [rax+rbx+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 8]); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+8] + call qword ptr [rax+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + mov rax,rcx + mov rcx,rdx + call qword ptr [rax+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg3") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1.reg); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + call qword ptr [rcx+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse1") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); // Already in its place + callWrap.addArgument(SizeX64::xmmword, qword[r12 + 8]); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,qword ptr [r12+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse2") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, qword[r12 + 8]); + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,xmm0,xmm0 + vmovsd xmm0,qword ptr [r12+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse3") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,xmm0,xmm0 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse4") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(rax, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + ScopedRegX64 tmp{regs, regs.takeReg(rdx, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, r15); + callWrap.addArgument(SizeX64::qword, irInst1.regX64, irOp1); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,r15 + mov r8,rdx + mov rdx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, addr[r12 + 8]); + callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); + callWrap.addArgument(SizeX64::xmmword, xmmword[r13]); + callWrap.call(qword[tmp1.release() + tmp2.release()]); + + checkMatch(R"( + vmovups xmm2,xmmword ptr [r13] + mov rax,rcx + lea rcx,[r12+8] + mov rbx,rdx + lea rdx,[r12+010h] + call qword ptr [rax+rbx] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "AddressInStackArguments") +{ + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.addArgument(SizeX64::dword, 3); + callWrap.addArgument(SizeX64::dword, 4); + callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); + callWrap.call(qword[r14]); + + checkMatch(R"( + lea rax,[r12+010h] + mov qword ptr [rsp+020h],rax + mov ecx,1 + mov edx,2 + mov r8d,3 + mov r9d,4 + call qword ptr [r14] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ImmediateConflictWithFunction") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.call(qword[tmp1.release() + tmp2.release()]); + + checkMatch(R"( + mov rax,rcx + mov ecx,1 + mov rbx,rdx + mov edx,2 + call qword ptr [rax+rbx] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64FixtureSystemV, "SuggestedConflictWithReserved") +{ + ScopedRegX64 tmp{regs, regs.takeReg(r9, kInvalidInstIdx)}; + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, r12); + callWrap.addArgument(SizeX64::qword, r13); + callWrap.addArgument(SizeX64::qword, r14); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.addArgument(SizeX64::qword, 1); + + RegisterX64 reg = callWrap.suggestNextArgumentRegister(SizeX64::dword); + build.mov(reg, 10); + callWrap.addArgument(SizeX64::dword, reg); + + callWrap.call(tmp.release()); + + checkMatch(R"( + mov eax,Ah + mov rdi,r12 + mov rsi,r13 + mov rdx,r14 + mov rcx,r9 + mov r9d,eax + mov rax,rcx + mov ecx,2 + mov r8,1 + call rax +)"); +} + +TEST_SUITE_END(); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp new file mode 100644 index 000000000..396678468 --- /dev/null +++ b/tests/IrLowering.test.cpp @@ -0,0 +1,1997 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Luau/BytecodeBuilder.h" +#include "Luau/CodeGen.h" +#include "Luau/Compiler.h" +#include "Luau/Parser.h" +#include "Luau/IrBuilder.h" + +#include "doctest.h" +#include "ScopedFlags.h" +#include "ConformanceIrHooks.h" + +#include +#include + +static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) +{ + Luau::CodeGen::AssemblyOptions options; + + options.compilationOptions.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + options.compilationOptions.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + options.compilationOptions.hooks.vectorAccess = vectorAccess; + options.compilationOptions.hooks.vectorNamecall = vectorNamecall; + + options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + options.compilationOptions.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + options.compilationOptions.hooks.userdataAccess = userdataAccess; + options.compilationOptions.hooks.userdataMetamethod = userdataMetamethod; + options.compilationOptions.hooks.userdataNamecall = userdataNamecall; + + // For IR, we don't care about assembly, but we want a stable target + options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV; + + options.outputBinary = false; + options.includeAssembly = false; + options.includeIr = true; + options.includeOutlinedCode = false; + options.includeIrTypes = includeIrTypes; + + options.includeIrPrefix = Luau::CodeGen::IncludeIrPrefix::No; + options.includeUseInfo = Luau::CodeGen::IncludeUseInfo::No; + options.includeCfgInfo = Luau::CodeGen::IncludeCfgInfo::No; + options.includeRegFlowInfo = Luau::CodeGen::IncludeRegFlowInfo::No; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source, strlen(source), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + Luau::CompileOptions copts = {}; + + copts.optimizationLevel = 2; + copts.debugLevel = debugLevel; + copts.typeInfoLevel = 1; + copts.vectorCtor = "vector"; + copts.vectorType = "vector"; + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, result, names, copts); + + std::string bytecode = bcb.getBytecode(); + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + + if (Luau::CodeGen::isSupported()) + { + // Type remapper requires the codegen runtime + Luau::CodeGen::create(L); + + Luau::CodeGen::setUserdataRemapper( + L, + kUserdataRunTypes, + [](void* context, const char* str, size_t len) -> uint8_t + { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + } + ); + } + + if (luau_load(L, "name", bytecode.data(), bytecode.size(), 0) == 0) + return Luau::CodeGen::getAssembly(L, -1, options, nullptr); + + FAIL("Failed to load bytecode"); + return ""; +} + +static std::string getCodegenHeader(const char* source) +{ + std::string assembly = getCodegenAssembly(source, /* includeIrTypes */ true, /* debugLevel */ 2); + + auto bytecodeStart = assembly.find("bb_bytecode_0:"); + + if (bytecodeStart == std::string::npos) + bytecodeStart = assembly.find("bb_0:"); + + REQUIRE(bytecodeStart != std::string::npos); + + return assembly.substr(0, bytecodeStart); +} + +TEST_SUITE_BEGIN("IrLowering"); + +TEST_CASE("VectorReciprocal") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vecrcp(a: vector) + return 1 / a +end +)"), + R"( +; function vecrcp($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = NUM_TO_VEC 1 + %7 = LOAD_TVALUE R0 + %8 = DIV_VEC %6, %7 + %9 = TAG_VECTOR %8 + STORE_TVALUE R1, %9 + INTERRUPT 1u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorComponentRead") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function compsum(a: vector) + return a.X + a.Y + a.Z +end +)"), + R"( +; function compsum($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_FLOAT R0, 0i + %11 = LOAD_FLOAT R0, 4i + %20 = ADD_NUM %6, %11 + %25 = LOAD_FLOAT R0, 8i + %34 = ADD_NUM %20, %25 + STORE_DOUBLE R1, %34 + STORE_TAG R1, tnumber + INTERRUPT 8u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorAdd") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vec3add(a: vector, b: vector) + return a + b +end +)"), + R"( +; function vec3add($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %10 = LOAD_TVALUE R0 + %11 = LOAD_TVALUE R1 + %12 = ADD_VEC %10, %11 + %13 = TAG_VECTOR %12 + STORE_TVALUE R2, %13 + INTERRUPT 1u + RETURN R2, 1i +)" + ); +} + +TEST_CASE("VectorMinus") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vec3minus(a: vector) + return -a +end +)"), + R"( +; function vec3minus($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R0 + %7 = UNM_VEC %6 + %8 = TAG_VECTOR %7 + STORE_TVALUE R1, %8 + INTERRUPT 1u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorSubMulDiv") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vec3combo(a: vector, b: vector, c: vector, d: vector) + return a * b - c / d +end +)"), + R"( +; function vec3combo($arg0, $arg1, $arg2, $arg3) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + CHECK_TAG R2, tvector, exit(entry) + CHECK_TAG R3, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %14 = LOAD_TVALUE R0 + %15 = LOAD_TVALUE R1 + %16 = MUL_VEC %14, %15 + %23 = LOAD_TVALUE R2 + %24 = LOAD_TVALUE R3 + %25 = DIV_VEC %23, %24 + %34 = SUB_VEC %16, %25 + %35 = TAG_VECTOR %34 + STORE_TVALUE R4, %35 + INTERRUPT 3u + RETURN R4, 1i +)" + ); +} + +TEST_CASE("VectorSubMulDiv2") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vec3combo(a: vector) + local tmp = a * a + return (tmp - tmp) / (tmp + tmp) +end +)"), + R"( +; function vec3combo($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_TVALUE R0 + %10 = MUL_VEC %8, %8 + %19 = SUB_VEC %10, %10 + %28 = ADD_VEC %10, %10 + %37 = DIV_VEC %19, %28 + %38 = TAG_VECTOR %37 + STORE_TVALUE R2, %38 + INTERRUPT 4u + RETURN R2, 1i +)" + ); +} + +TEST_CASE("VectorMulDivMixed") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vec3combo(a: vector, b: vector, c: vector, d: vector) + return a * 2 + b / 4 + 0.5 * c + 40 / d +end +)"), + R"( +; function vec3combo($arg0, $arg1, $arg2, $arg3) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + CHECK_TAG R2, tvector, exit(entry) + CHECK_TAG R3, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %12 = LOAD_TVALUE R0 + %13 = NUM_TO_VEC 2 + %14 = MUL_VEC %12, %13 + %19 = LOAD_TVALUE R1 + %20 = NUM_TO_VEC 4 + %21 = DIV_VEC %19, %20 + %30 = ADD_VEC %14, %21 + %40 = NUM_TO_VEC 0.5 + %41 = LOAD_TVALUE R2 + %42 = MUL_VEC %40, %41 + %51 = ADD_VEC %30, %42 + %56 = NUM_TO_VEC 40 + %57 = LOAD_TVALUE R3 + %58 = DIV_VEC %56, %57 + %67 = ADD_VEC %51, %58 + %68 = TAG_VECTOR %67 + STORE_TVALUE R4, %68 + INTERRUPT 8u + RETURN R4, 1i +)" + ); +} + +TEST_CASE("ExtraMathMemoryOperands") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function foo(a: number, b: number, c: number, d: number, e: number) + return math.floor(a) + math.ceil(b) + math.round(c) + math.sqrt(d) + math.abs(e) +end +)"), + R"( +; function foo($arg0, $arg1, $arg2, $arg3, $arg4) line 2 +bb_0: + CHECK_TAG R0, tnumber, exit(entry) + CHECK_TAG R1, tnumber, exit(entry) + CHECK_TAG R2, tnumber, exit(entry) + CHECK_TAG R3, tnumber, exit(entry) + CHECK_TAG R4, tnumber, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + CHECK_SAFE_ENV exit(1) + %16 = FLOOR_NUM R0 + %23 = CEIL_NUM R1 + %32 = ADD_NUM %16, %23 + %39 = ROUND_NUM R2 + %48 = ADD_NUM %32, %39 + %55 = SQRT_NUM R3 + %64 = ADD_NUM %48, %55 + %71 = ABS_NUM R4 + %80 = ADD_NUM %64, %71 + STORE_DOUBLE R5, %80 + STORE_TAG R5, tnumber + INTERRUPT 29u + RETURN R5, 1i +)" + ); +} + +TEST_CASE("DseInitialStackState") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function foo() + while {} do + local _ = not _,{} + _ = nil + end +end +)"), + R"( +; function foo() line 2 +bb_bytecode_0: + SET_SAVEDPC 1u + %1 = NEW_TABLE 0u, 0u + STORE_POINTER R0, %1 + STORE_TAG R0, ttable + CHECK_GC + JUMP bb_2 +bb_2: + CHECK_SAFE_ENV exit(3) + JUMP_EQ_TAG K1, tnil, bb_fallback_4, bb_3 +bb_3: + %9 = LOAD_TVALUE K1 + STORE_TVALUE R1, %9 + JUMP bb_5 +bb_5: + SET_SAVEDPC 7u + %21 = NEW_TABLE 0u, 0u + STORE_POINTER R1, %21 + STORE_TAG R1, ttable + CHECK_GC + STORE_TAG R0, tnil + INTERRUPT 9u + JUMP bb_bytecode_0 +)" + ); +} + +TEST_CASE("DseInitialStackState2") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function foo(a) + math.frexp(a) + return a +end +)"), + R"( +; function foo($arg0) line 2 +bb_bytecode_0: + CHECK_SAFE_ENV exit(1) + CHECK_TAG R0, tnumber, exit(1) + FASTCALL 14u, R1, R0, 2i + INTERRUPT 5u + RETURN R0, 1i +)" + ); +} + +TEST_CASE("VectorConstantTag") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vecrcp(a: vector) + return vector(1, 2, 3) + a +end +)"), + R"( +; function vecrcp($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0, 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = ADD_VEC %4, %11 + %13 = TAG_VECTOR %12 + STORE_TVALUE R1, %13 + INTERRUPT 2u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorNamecall") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function abs(a: vector) + return a:Abs() +end +)"), + R"( +; function abs($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_NAMECALL 0u, R1, R0, K0 + INTERRUPT 2u + SET_SAVEDPC 3u + CALL R1, 1i, -1i + INTERRUPT 3u + RETURN R1, -1i +)" + ); +} + +TEST_CASE("VectorRandomProp") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function foo(a: vector) + return a.XX + a.YY + a.ZZ +end +)"), + R"( +; function foo($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R3, R0, K0 + FALLBACK_GETTABLEKS 2u, R4, R0, K1 + CHECK_TAG R3, tnumber, bb_fallback_3 + CHECK_TAG R4, tnumber, bb_fallback_3 + %14 = LOAD_DOUBLE R3 + %16 = ADD_NUM %14, R4 + STORE_DOUBLE R2, %16 + STORE_TAG R2, tnumber + JUMP bb_4 +bb_4: + CHECK_TAG R0, tvector, exit(5) + FALLBACK_GETTABLEKS 5u, R3, R0, K2 + CHECK_TAG R2, tnumber, bb_fallback_5 + CHECK_TAG R3, tnumber, bb_fallback_5 + %30 = LOAD_DOUBLE R2 + %32 = ADD_NUM %30, R3 + STORE_DOUBLE R1, %32 + STORE_TAG R1, tnumber + JUMP bb_6 +bb_6: + INTERRUPT 8u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorCustomAccess") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vec3magn(a: vector) + return a.Magnitude * 2 +end +)"), + R"( +; function vec3magn($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_FLOAT R0, 0i + %7 = LOAD_FLOAT R0, 4i + %8 = LOAD_FLOAT R0, 8i + %9 = MUL_NUM %6, %6 + %10 = MUL_NUM %7, %7 + %11 = MUL_NUM %8, %8 + %12 = ADD_NUM %9, %10 + %13 = ADD_NUM %12, %11 + %14 = SQRT_NUM %13 + %20 = MUL_NUM %14, 2 + STORE_DOUBLE R1, %20 + STORE_TAG R1, tnumber + INTERRUPT 3u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorCustomNamecall") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function vec3dot(a: vector, b: vector) + return (a:Dot(b)) +end +)"), + R"( +; function vec3dot($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %12 = LOAD_FLOAT R0, 0i + %13 = LOAD_FLOAT R4, 0i + %14 = MUL_NUM %12, %13 + %15 = LOAD_FLOAT R0, 4i + %16 = LOAD_FLOAT R4, 4i + %17 = MUL_NUM %15, %16 + %18 = LOAD_FLOAT R0, 8i + %19 = LOAD_FLOAT R4, 8i + %20 = MUL_NUM %18, %19 + %21 = ADD_NUM %14, %17 + %22 = ADD_NUM %21, %20 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + INTERRUPT 4u + RETURN R2, 1i +)" + ); +} + +TEST_CASE("VectorCustomAccessChain") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function foo(a: vector, b: vector) + return a.Unit * b.Magnitude +end +)"), + R"( +; function foo($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_FLOAT R0, 0i + %9 = LOAD_FLOAT R0, 4i + %10 = LOAD_FLOAT R0, 8i + %11 = MUL_NUM %8, %8 + %12 = MUL_NUM %9, %9 + %13 = MUL_NUM %10, %10 + %14 = ADD_NUM %11, %12 + %15 = ADD_NUM %14, %13 + %16 = SQRT_NUM %15 + %17 = DIV_NUM 1, %16 + %18 = MUL_NUM %8, %17 + %19 = MUL_NUM %9, %17 + %20 = MUL_NUM %10, %17 + STORE_VECTOR R3, %18, %19, %20 + STORE_TAG R3, tvector + %25 = LOAD_FLOAT R1, 0i + %26 = LOAD_FLOAT R1, 4i + %27 = LOAD_FLOAT R1, 8i + %28 = MUL_NUM %25, %25 + %29 = MUL_NUM %26, %26 + %30 = MUL_NUM %27, %27 + %31 = ADD_NUM %28, %29 + %32 = ADD_NUM %31, %30 + %33 = SQRT_NUM %32 + %40 = LOAD_TVALUE R3 + %42 = NUM_TO_VEC %33 + %43 = MUL_VEC %40, %42 + %44 = TAG_VECTOR %43 + STORE_TVALUE R2, %44 + INTERRUPT 5u + RETURN R2, 1i +)" + ); +} + +TEST_CASE("VectorCustomNamecallChain") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function foo(n: vector, b: vector, t: vector) + return n:Cross(t):Dot(b) + 1 +end +)"), + R"( +; function foo($arg0, $arg1, $arg2) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + CHECK_TAG R2, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_TVALUE R2 + STORE_TVALUE R6, %8 + %14 = LOAD_FLOAT R0, 0i + %15 = LOAD_FLOAT R6, 0i + %16 = LOAD_FLOAT R0, 4i + %17 = LOAD_FLOAT R6, 4i + %18 = LOAD_FLOAT R0, 8i + %19 = LOAD_FLOAT R6, 8i + %20 = MUL_NUM %16, %19 + %21 = MUL_NUM %18, %17 + %22 = SUB_NUM %20, %21 + %23 = MUL_NUM %18, %15 + %24 = MUL_NUM %14, %19 + %25 = SUB_NUM %23, %24 + %26 = MUL_NUM %14, %17 + %27 = MUL_NUM %16, %15 + %28 = SUB_NUM %26, %27 + STORE_VECTOR R4, %22, %25, %28 + STORE_TAG R4, tvector + %31 = LOAD_TVALUE R1 + STORE_TVALUE R6, %31 + %37 = LOAD_FLOAT R4, 0i + %38 = LOAD_FLOAT R6, 0i + %39 = MUL_NUM %37, %38 + %40 = LOAD_FLOAT R4, 4i + %41 = LOAD_FLOAT R6, 4i + %42 = MUL_NUM %40, %41 + %43 = LOAD_FLOAT R4, 8i + %44 = LOAD_FLOAT R6, 8i + %45 = MUL_NUM %43, %44 + %46 = ADD_NUM %39, %42 + %47 = ADD_NUM %46, %45 + %53 = ADD_NUM %47, 1 + STORE_DOUBLE R3, %53 + STORE_TAG R3, tnumber + INTERRUPT 9u + RETURN R3, 1i +)" + ); +} + +TEST_CASE("VectorCustomNamecallChain2") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +type Vertex = {n: vector, b: vector} + +local function foo(v: Vertex, t: vector) + return v.n:Cross(t):Dot(v.b) + 1 +end +)"), + R"( +; function foo($arg0, $arg1) line 4 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_POINTER R0 + %9 = GET_SLOT_NODE_ADDR %8, 0u, K1 + CHECK_SLOT_MATCH %9, K1, bb_fallback_3 + %11 = LOAD_TVALUE %9, 0i + STORE_TVALUE R3, %11 + JUMP bb_4 +bb_4: + %16 = LOAD_TVALUE R1 + STORE_TVALUE R5, %16 + CHECK_TAG R3, tvector, exit(3) + CHECK_TAG R5, tvector, exit(3) + %22 = LOAD_FLOAT R3, 0i + %23 = LOAD_FLOAT R5, 0i + %24 = LOAD_FLOAT R3, 4i + %25 = LOAD_FLOAT R5, 4i + %26 = LOAD_FLOAT R3, 8i + %27 = LOAD_FLOAT R5, 8i + %28 = MUL_NUM %24, %27 + %29 = MUL_NUM %26, %25 + %30 = SUB_NUM %28, %29 + %31 = MUL_NUM %26, %23 + %32 = MUL_NUM %22, %27 + %33 = SUB_NUM %31, %32 + %34 = MUL_NUM %22, %25 + %35 = MUL_NUM %24, %23 + %36 = SUB_NUM %34, %35 + STORE_VECTOR R3, %30, %33, %36 + CHECK_TAG R0, ttable, exit(6) + %41 = LOAD_POINTER R0 + %42 = GET_SLOT_NODE_ADDR %41, 6u, K3 + CHECK_SLOT_MATCH %42, K3, bb_fallback_5 + %44 = LOAD_TVALUE %42, 0i + STORE_TVALUE R5, %44 + JUMP bb_6 +bb_6: + CHECK_TAG R3, tvector, exit(8) + CHECK_TAG R5, tvector, exit(8) + %53 = LOAD_FLOAT R3, 0i + %54 = LOAD_FLOAT R5, 0i + %55 = MUL_NUM %53, %54 + %56 = LOAD_FLOAT R3, 4i + %57 = LOAD_FLOAT R5, 4i + %58 = MUL_NUM %56, %57 + %59 = LOAD_FLOAT R3, 8i + %60 = LOAD_FLOAT R5, 8i + %61 = MUL_NUM %59, %60 + %62 = ADD_NUM %55, %58 + %63 = ADD_NUM %62, %61 + %69 = ADD_NUM %63, 1 + STORE_DOUBLE R2, %69 + STORE_TAG R2, tnumber + INTERRUPT 12u + RETURN R2, 1i +)" + ); +} + +TEST_CASE("UserDataGetIndex") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function getxy(a: Point) + return a.x + a.y +end +)"), + R"( +; function getxy($arg0) line 2 +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R2, R0, K0 + FALLBACK_GETTABLEKS 2u, R3, R0, K1 + CHECK_TAG R2, tnumber, bb_fallback_3 + CHECK_TAG R3, tnumber, bb_fallback_3 + %14 = LOAD_DOUBLE R2 + %16 = ADD_NUM %14, R3 + STORE_DOUBLE R1, %16 + STORE_TAG R1, tnumber + JUMP bb_4 +bb_4: + INTERRUPT 5u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("UserDataSetIndex") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function setxy(a: Point) + a.x = 3 + a.y = 4 +end +)"), + R"( +; function setxy($arg0) line 2 +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + STORE_DOUBLE R1, 3 + STORE_TAG R1, tnumber + FALLBACK_SETTABLEKS 1u, R1, R0, K0 + STORE_DOUBLE R1, 4 + FALLBACK_SETTABLEKS 4u, R1, R0, K1 + INTERRUPT 6u + RETURN R0, 0i +)" + ); +} + +TEST_CASE("UserDataNamecall") +{ + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function getxy(a: Point) + return a:GetX() + a:GetY() +end +)"), + R"( +; function getxy($arg0) line 2 +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_NAMECALL 0u, R2, R0, K0 + INTERRUPT 2u + SET_SAVEDPC 3u + CALL R2, 1i, 1i + FALLBACK_NAMECALL 3u, R3, R0, K1 + INTERRUPT 5u + SET_SAVEDPC 6u + CALL R3, 1i, 1i + CHECK_TAG R2, tnumber, bb_fallback_3 + CHECK_TAG R3, tnumber, bb_fallback_3 + %20 = LOAD_DOUBLE R2 + %22 = ADD_NUM %20, R3 + STORE_DOUBLE R1, %22 + STORE_TAG R1, tnumber + JUMP bb_4 +bb_4: + INTERRUPT 7u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("ExplicitUpvalueAndLocalTypes") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local y: vector = ... + +local function getsum(t) + local x: vector = t + return x.X + x.Y + y.X + y.Y +end +)", + /* includeIrTypes */ true + ), + R"( +; function getsum($arg0) line 4 +; U0: vector +; R0: vector from 0 to 14 +bb_bytecode_0: + CHECK_TAG R0, tvector, exit(0) + %2 = LOAD_FLOAT R0, 0i + STORE_DOUBLE R4, %2 + STORE_TAG R4, tnumber + %7 = LOAD_FLOAT R0, 4i + %16 = ADD_NUM %2, %7 + STORE_DOUBLE R3, %16 + STORE_TAG R3, tnumber + GET_UPVALUE R5, U0 + CHECK_TAG R5, tvector, exit(6) + %22 = LOAD_FLOAT R5, 0i + %31 = ADD_NUM %16, %22 + STORE_DOUBLE R2, %31 + STORE_TAG R2, tnumber + GET_UPVALUE R4, U0 + CHECK_TAG R4, tvector, exit(10) + %37 = LOAD_FLOAT R4, 4i + %46 = ADD_NUM %31, %37 + STORE_DOUBLE R1, %46 + STORE_TAG R1, tnumber + INTERRUPT 13u + RETURN R1, 1i +)" + ); +} + +#if LUA_VECTOR_SIZE == 3 +TEST_CASE("FastcallTypeInferThroughLocal") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function getsum(x, c) + local v = vector(x, 2, 3) + if c then + return v.X + v.Y + else + return v.Z + end +end +)", + /* includeIrTypes */ true + ), + R"( +; function getsum($arg0, $arg1) line 2 +; R2: vector from 0 to 18 +bb_bytecode_0: + STORE_DOUBLE R4, 2 + STORE_TAG R4, tnumber + STORE_DOUBLE R5, 3 + STORE_TAG R5, tnumber + CHECK_SAFE_ENV exit(4) + CHECK_TAG R0, tnumber, exit(4) + %11 = LOAD_DOUBLE R0 + STORE_VECTOR R2, %11, 2, 3 + STORE_TAG R2, tvector + JUMP_IF_FALSY R1, bb_bytecode_1, bb_3 +bb_3: + CHECK_TAG R2, tvector, exit(9) + %19 = LOAD_FLOAT R2, 0i + %24 = LOAD_FLOAT R2, 4i + %33 = ADD_NUM %19, %24 + STORE_DOUBLE R3, %33 + STORE_TAG R3, tnumber + INTERRUPT 14u + RETURN R3, 1i +bb_bytecode_1: + CHECK_TAG R2, tvector, exit(15) + %40 = LOAD_FLOAT R2, 8i + STORE_DOUBLE R3, %40 + STORE_TAG R3, tnumber + INTERRUPT 17u + RETURN R3, 1i +)" + ); +} + +TEST_CASE("FastcallTypeInferThroughUpvalue") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local v = ... + +local function getsum(x, c) + v = vector(x, 2, 3) + if c then + return v.X + v.Y + else + return v.Z + end +end +)", + /* includeIrTypes */ true + ), + R"( +; function getsum($arg0, $arg1) line 4 +; U0: vector +bb_bytecode_0: + STORE_DOUBLE R4, 2 + STORE_TAG R4, tnumber + STORE_DOUBLE R5, 3 + STORE_TAG R5, tnumber + CHECK_SAFE_ENV exit(4) + CHECK_TAG R0, tnumber, exit(4) + %11 = LOAD_DOUBLE R0 + STORE_VECTOR R2, %11, 2, 3 + STORE_TAG R2, tvector + SET_UPVALUE U0, R2, tvector + JUMP_IF_FALSY R1, bb_bytecode_1, bb_3 +bb_3: + GET_UPVALUE R4, U0 + CHECK_TAG R4, tvector, exit(11) + %21 = LOAD_FLOAT R4, 0i + STORE_DOUBLE R3, %21 + STORE_TAG R3, tnumber + GET_UPVALUE R5, U0 + CHECK_TAG R5, tvector, exit(14) + %27 = LOAD_FLOAT R5, 4i + %36 = ADD_NUM %21, %27 + STORE_DOUBLE R2, %36 + STORE_TAG R2, tnumber + INTERRUPT 17u + RETURN R2, 1i +bb_bytecode_1: + GET_UPVALUE R3, U0 + CHECK_TAG R3, tvector, exit(19) + %44 = LOAD_FLOAT R3, 8i + STORE_DOUBLE R2, %44 + STORE_TAG R2, tnumber + INTERRUPT 21u + RETURN R2, 1i +)" + ); +} +#endif + +TEST_CASE("LoadAndMoveTypePropagation") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function getsum(n) + local seqsum = 0 + for i = 1,n do + if i < 10 then + seqsum += i + else + seqsum *= i + end + end + + return seqsum +end +)", + /* includeIrTypes */ true + ), + R"( +; function getsum($arg0) line 2 +; R1: number from 0 to 13 +; R4: number from 1 to 11 +bb_bytecode_0: + STORE_DOUBLE R1, 0 + STORE_TAG R1, tnumber + STORE_DOUBLE R4, 1 + STORE_TAG R4, tnumber + %4 = LOAD_TVALUE R0 + STORE_TVALUE R2, %4 + STORE_DOUBLE R3, 1 + STORE_TAG R3, tnumber + CHECK_TAG R2, tnumber, exit(4) + %12 = LOAD_DOUBLE R2 + JUMP_CMP_NUM 1, %12, not_le, bb_bytecode_4, bb_bytecode_1 +bb_bytecode_1: + INTERRUPT 5u + STORE_DOUBLE R5, 10 + STORE_TAG R5, tnumber + CHECK_TAG R4, tnumber, bb_fallback_6 + JUMP_CMP_NUM R4, 10, not_lt, bb_bytecode_2, bb_5 +bb_5: + CHECK_TAG R1, tnumber, exit(8) + CHECK_TAG R4, tnumber, exit(8) + %32 = LOAD_DOUBLE R1 + %34 = ADD_NUM %32, R4 + STORE_DOUBLE R1, %34 + JUMP bb_bytecode_3 +bb_bytecode_2: + CHECK_TAG R1, tnumber, exit(10) + CHECK_TAG R4, tnumber, exit(10) + %41 = LOAD_DOUBLE R1 + %43 = MUL_NUM %41, R4 + STORE_DOUBLE R1, %43 + JUMP bb_bytecode_3 +bb_bytecode_3: + %46 = LOAD_DOUBLE R2 + %47 = LOAD_DOUBLE R4 + %48 = ADD_NUM %47, 1 + STORE_DOUBLE R4, %48 + JUMP_CMP_NUM %48, %46, le, bb_bytecode_1, bb_bytecode_4 +bb_bytecode_4: + INTERRUPT 12u + RETURN R1, 1i +)" + ); +} + +#if LUA_VECTOR_SIZE == 3 +TEST_CASE("ArgumentTypeRefinement") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function getsum(x, y) + x = vector(1, y, 3) + return x.Y + x.Z +end +)", + /* includeIrTypes */ true + ), + R"( +; function getsum($arg0, $arg1) line 2 +; R0: vector [argument] +bb_bytecode_0: + STORE_DOUBLE R3, 1 + STORE_TAG R3, tnumber + STORE_DOUBLE R5, 3 + STORE_TAG R5, tnumber + CHECK_SAFE_ENV exit(4) + CHECK_TAG R1, tnumber, exit(4) + %12 = LOAD_DOUBLE R1 + STORE_VECTOR R2, 1, %12, 3 + STORE_TAG R2, tvector + %16 = LOAD_TVALUE R2 + STORE_TVALUE R0, %16 + %20 = LOAD_FLOAT R0, 4i + %25 = LOAD_FLOAT R0, 8i + %34 = ADD_NUM %20, %25 + STORE_DOUBLE R2, %34 + STORE_TAG R2, tnumber + INTERRUPT 14u + RETURN R2, 1i +)" + ); +} +#endif + +TEST_CASE("InlineFunctionType") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function inl(v: vector, s: number) + return v.Y * s +end + +local function getsum(x) + return inl(x, 2) + inl(x, 5) +end +)", + /* includeIrTypes */ true + ), + R"( +; function inl($arg0, $arg1) line 2 +; R0: vector [argument] +; R1: number [argument] +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tnumber, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_FLOAT R0, 4i + %17 = MUL_NUM %8, R1 + STORE_DOUBLE R2, %17 + STORE_TAG R2, tnumber + INTERRUPT 3u + RETURN R2, 1i +; function getsum($arg0) line 6 +; R0: vector from 0 to 3 +; R0: vector from 3 to 6 +bb_bytecode_0: + CHECK_TAG R0, tvector, exit(0) + %2 = LOAD_FLOAT R0, 4i + %8 = MUL_NUM %2, 2 + %13 = LOAD_FLOAT R0, 4i + %19 = MUL_NUM %13, 5 + %28 = ADD_NUM %8, %19 + STORE_DOUBLE R1, %28 + STORE_TAG R1, tnumber + INTERRUPT 7u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("ResolveTablePathTypes") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(arr: {Vertex}, i) + local v = arr[i] + + return v.pos.Y +end +)", + /* includeIrTypes */ true, + /* debugLevel */ 2 + ), + R"( +; function foo(arr, i) line 4 +; R0: table [argument 'arr'] +; R2: table from 0 to 6 [local 'v'] +; R4: vector from 3 to 5 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + CHECK_TAG R1, tnumber, bb_fallback_3 + %8 = LOAD_POINTER R0 + %9 = LOAD_DOUBLE R1 + %10 = TRY_NUM_TO_INDEX %9, bb_fallback_3 + %11 = SUB_INT %10, 1i + CHECK_ARRAY_SIZE %8, %11, bb_fallback_3 + CHECK_NO_METATABLE %8, bb_fallback_3 + %14 = GET_ARR_ADDR %8, %11 + %15 = LOAD_TVALUE %14 + STORE_TVALUE R2, %15 + JUMP bb_4 +bb_4: + CHECK_TAG R2, ttable, exit(1) + %23 = LOAD_POINTER R2 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 + CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %26 = LOAD_TVALUE %24, 0i + STORE_TVALUE R4, %26 + JUMP bb_6 +bb_6: + CHECK_TAG R4, tvector, exit(3) + %33 = LOAD_FLOAT R4, 4i + STORE_DOUBLE R3, %33 + STORE_TAG R3, tnumber + INTERRUPT 5u + RETURN R3, 1i +)" + ); +} + +TEST_CASE("ResolvableSimpleMath") +{ + CHECK_EQ( + "\n" + getCodegenHeader(R"( +type Vertex = { p: vector, uv: vector, n: vector, t: vector, b: vector, h: number } +local mesh: { vertices: {Vertex}, indices: {number} } = ... + +local function compute() + for i = 1,#mesh.indices,3 do + local a = mesh.vertices[mesh.indices[i]] + local b = mesh.vertices[mesh.indices[i + 1]] + local c = mesh.vertices[mesh.indices[i + 2]] + + local vba = b.p - a.p + local vca = c.p - a.p + + local uvba = b.uv - a.uv + local uvca = c.uv - a.uv + + local r = 1.0 / (uvba.X * uvca.Y - uvca.X * uvba.Y); + + local sdir = (uvca.Y * vba - uvba.Y * vca) * r + + a.t += sdir + end +end +)"), + R"( +; function compute() line 5 +; U0: table ['mesh'] +; R2: number from 0 to 78 [local 'i'] +; R3: table from 7 to 78 [local 'a'] +; R4: table from 15 to 78 [local 'b'] +; R5: table from 24 to 78 [local 'c'] +; R6: vector from 33 to 78 [local 'vba'] +; R7: vector from 37 to 38 +; R7: vector from 38 to 78 [local 'vca'] +; R8: vector from 37 to 38 +; R8: vector from 42 to 43 +; R8: vector from 43 to 78 [local 'uvba'] +; R9: vector from 42 to 43 +; R9: vector from 47 to 48 +; R9: vector from 48 to 78 [local 'uvca'] +; R10: vector from 47 to 48 +; R10: vector from 52 to 53 +; R10: number from 53 to 78 [local 'r'] +; R11: vector from 52 to 53 +; R11: vector from 65 to 78 [local 'sdir'] +; R12: vector from 72 to 73 +; R12: vector from 75 to 76 +; R13: vector from 71 to 72 +; R14: vector from 71 to 72 +)" + ); +} + +TEST_CASE("ResolveVectorNamecalls") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(arr: {Vertex}, i) + return arr[i].normal:Dot(vector(0.707, 0, 0.707)) +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0, $arg1) line 4 +; R0: table [argument] +; R2: vector from 4 to 6 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + CHECK_TAG R1, tnumber, bb_fallback_3 + %8 = LOAD_POINTER R0 + %9 = LOAD_DOUBLE R1 + %10 = TRY_NUM_TO_INDEX %9, bb_fallback_3 + %11 = SUB_INT %10, 1i + CHECK_ARRAY_SIZE %8, %11, bb_fallback_3 + CHECK_NO_METATABLE %8, bb_fallback_3 + %14 = GET_ARR_ADDR %8, %11 + %15 = LOAD_TVALUE %14 + STORE_TVALUE R3, %15 + JUMP bb_4 +bb_4: + CHECK_TAG R3, ttable, bb_fallback_5 + %23 = LOAD_POINTER R3 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 + CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %26 = LOAD_TVALUE %24, 0i + STORE_TVALUE R2, %26 + JUMP bb_6 +bb_6: + %31 = LOAD_TVALUE K1, 0i, tvector + STORE_TVALUE R4, %31 + CHECK_TAG R2, tvector, exit(4) + %37 = LOAD_FLOAT R2, 0i + %38 = LOAD_FLOAT R4, 0i + %39 = MUL_NUM %37, %38 + %40 = LOAD_FLOAT R2, 4i + %41 = LOAD_FLOAT R4, 4i + %42 = MUL_NUM %40, %41 + %43 = LOAD_FLOAT R2, 8i + %44 = LOAD_FLOAT R4, 8i + %45 = MUL_NUM %43, %44 + %46 = ADD_NUM %39, %42 + %47 = ADD_NUM %46, %45 + STORE_DOUBLE R2, %47 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 7u + RETURN R2, -1i +)" + ); +} + +TEST_CASE("ImmediateTypeAnnotationHelp") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(arr, i) + return (arr[i] :: vector) / 5 +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0, $arg1) line 2 +; R3: vector from 1 to 2 +bb_bytecode_0: + CHECK_TAG R0, ttable, bb_fallback_1 + CHECK_TAG R1, tnumber, bb_fallback_1 + %4 = LOAD_POINTER R0 + %5 = LOAD_DOUBLE R1 + %6 = TRY_NUM_TO_INDEX %5, bb_fallback_1 + %7 = SUB_INT %6, 1i + CHECK_ARRAY_SIZE %4, %7, bb_fallback_1 + CHECK_NO_METATABLE %4, bb_fallback_1 + %10 = GET_ARR_ADDR %4, %7 + %11 = LOAD_TVALUE %10 + STORE_TVALUE R3, %11 + JUMP bb_2 +bb_2: + CHECK_TAG R3, tvector, exit(1) + %19 = LOAD_TVALUE R3 + %20 = NUM_TO_VEC 5 + %21 = DIV_VEC %19, %20 + %22 = TAG_VECTOR %21 + STORE_TVALUE R2, %22 + INTERRUPT 2u + RETURN R2, 1i +)" + ); +} + +#if LUA_VECTOR_SIZE == 3 +TEST_CASE("UnaryTypeResolve") +{ + CHECK_EQ( + "\n" + getCodegenHeader(R"( +local function foo(a, b: vector, c) + local d = not a + local e = -b + local f = #c + return (if d then e else vector(f, 2, 3)).X +end +)"), + R"( +; function foo(a, b, c) line 2 +; R1: vector [argument 'b'] +; R3: boolean from 0 to 17 [local 'd'] +; R4: vector from 1 to 17 [local 'e'] +; R5: number from 2 to 17 [local 'f'] +; R7: vector from 14 to 16 +)" + ); +} +#endif + +TEST_CASE("ForInManualAnnotation") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v: Vertex in ipairs(a) do + sum += v.pos.X + end + return sum +end +)", + /* includeIrTypes */ true, + /* debugLevel */ 2 + ), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: number from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R8: vector from 8 to 10 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_4 +bb_4: + JUMP bb_bytecode_1 +bb_bytecode_1: + STORE_DOUBLE R1, 0 + STORE_TAG R1, tnumber + CHECK_SAFE_ENV exit(1) + JUMP_EQ_TAG K1, tnil, bb_fallback_6, bb_5 +bb_5: + %9 = LOAD_TVALUE K1 + STORE_TVALUE R2, %9 + JUMP bb_7 +bb_7: + %15 = LOAD_TVALUE R0 + STORE_TVALUE R3, %15 + INTERRUPT 4u + SET_SAVEDPC 5u + CALL R2, 1i, 3i + CHECK_SAFE_ENV exit(5) + CHECK_TAG R3, ttable, bb_fallback_8 + CHECK_TAG R4, tnumber, bb_fallback_8 + JUMP_CMP_NUM R4, 0, not_eq, bb_fallback_8, bb_9 +bb_9: + STORE_TAG R2, tnil + STORE_POINTER R4, 0i + STORE_EXTRA R4, 128i + STORE_TAG R4, tlightuserdata + JUMP bb_bytecode_3 +bb_bytecode_2: + CHECK_TAG R6, ttable, exit(6) + %35 = LOAD_POINTER R6 + %36 = GET_SLOT_NODE_ADDR %35, 6u, K2 + CHECK_SLOT_MATCH %36, K2, bb_fallback_10 + %38 = LOAD_TVALUE %36, 0i + STORE_TVALUE R8, %38 + JUMP bb_11 +bb_11: + CHECK_TAG R8, tvector, exit(8) + %45 = LOAD_FLOAT R8, 0i + STORE_DOUBLE R7, %45 + STORE_TAG R7, tnumber + CHECK_TAG R1, tnumber, exit(10) + %52 = LOAD_DOUBLE R1 + %54 = ADD_NUM %52, %45 + STORE_DOUBLE R1, %54 + JUMP bb_bytecode_3 +bb_bytecode_3: + INTERRUPT 11u + CHECK_TAG R2, tnil, bb_fallback_13 + %60 = LOAD_POINTER R3 + %61 = LOAD_INT R4 + %62 = GET_ARR_ADDR %60, %61 + CHECK_ARRAY_SIZE %60, %61, bb_12 + %64 = LOAD_TAG %62 + JUMP_EQ_TAG %64, tnil, bb_12, bb_14 +bb_14: + %66 = ADD_INT %61, 1i + STORE_INT R4, %66 + %68 = INT_TO_NUM %66 + STORE_DOUBLE R5, %68 + STORE_TAG R5, tnumber + %71 = LOAD_TVALUE %62 + STORE_TVALUE R6, %71 + JUMP bb_bytecode_2 +bb_12: + INTERRUPT 13u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("ForInAutoAnnotationIpairs") +{ + CHECK_EQ( + "\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v in ipairs(a) do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: number from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R7: number from 6 to 11 [local 'n'] +; R8: vector from 8 to 10 +)" + ); +} + +TEST_CASE("ForInAutoAnnotationPairs") +{ + CHECK_EQ( + "\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {[string]: Vertex}) + local sum = 0 + for k, v in pairs(a) do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: string from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R7: number from 6 to 11 [local 'n'] +; R8: vector from 8 to 10 +)" + ); +} + +TEST_CASE("ForInAutoAnnotationGeneric") +{ + CHECK_EQ( + "\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v in a do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 13 [local 'sum'] +; R5: number from 4 to 10 [local 'k'] +; R6: table from 4 to 10 [local 'v'] +; R7: number from 5 to 10 [local 'n'] +; R8: vector from 7 to 9 +)" + ); +} + +TEST_CASE("CustomUserdataTypes") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenHeader(R"( +local function foo(v: vec2, x: mat3) + return v.X * x +end +)"), + R"( +; function foo(v, x) line 2 +; R0: vec2 [argument 'v'] +; R1: mat3 [argument 'x'] +)" + ); +} + +TEST_CASE("CustomUserdataPropertyAccess") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(v: vec2) + return v.X + v.Y +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %6, 12i, exit(0) + %8 = BUFFER_READF32 %6, 0i, tuserdata + %15 = BUFFER_READF32 %6, 4i, tuserdata + %24 = ADD_NUM %8, %15 + STORE_DOUBLE R1, %24 + STORE_TAG R1, tnumber + INTERRUPT 5u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("CustomUserdataPropertyAccess2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: mat3) + return a.Row1 * a.Row2 +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R2, R0, K0 + FALLBACK_GETTABLEKS 2u, R3, R0, K1 + CHECK_TAG R2, tvector, exit(4) + CHECK_TAG R3, tvector, exit(4) + %14 = LOAD_TVALUE R2 + %15 = LOAD_TVALUE R3 + %16 = MUL_VEC %14, %15 + %17 = TAG_VECTOR %16 + STORE_TVALUE R1, %17 + INTERRUPT 5u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("CustomUserdataNamecall1") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vec2, b: vec2) + return a:Dot(b) +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MUL_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MUL_NUM %19, %20 + %22 = ADD_NUM %18, %21 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)" + ); +} + +TEST_CASE("CustomUserdataNamecall2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vec2, b: vec2) + return a:Min(b) +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MIN_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MIN_NUM %19, %20 + CHECK_GC + %23 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %23, 0i, %18, tuserdata + BUFFER_WRITEF32 %23, 4i, %21, tuserdata + STORE_POINTER R2, %23 + STORE_TAG R2, tuserdata + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)" + ); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: mat3, b: mat3) + return a * b +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0, $arg1) line 2 +; R0: mat3 [argument] +; R1: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R2, R0, R1, 10i + INTERRUPT 1u + RETURN R2, 1i +)" + ); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: mat3) + return -a +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R1, R0, R0, 15i + INTERRUPT 1u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow3") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: sequence) + return #a +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: userdata [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_LEN R1, R0 + INTERRUPT 1u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("CustomUserdataMetamethod") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vec2, b: vec2, c: vec2) + return -c + a * b +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0, $arg1, $arg2) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +; R2: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + CHECK_TAG R2, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %10 = LOAD_POINTER R2 + CHECK_USERDATA_TAG %10, 12i, exit(0) + %12 = BUFFER_READF32 %10, 0i, tuserdata + %13 = BUFFER_READF32 %10, 4i, tuserdata + %14 = UNM_NUM %12 + %15 = UNM_NUM %13 + CHECK_GC + %17 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %17, 0i, %14, tuserdata + BUFFER_WRITEF32 %17, 4i, %15, tuserdata + STORE_POINTER R4, %17 + STORE_TAG R4, tuserdata + %26 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %26, 12i, exit(1) + %28 = LOAD_POINTER R1 + CHECK_USERDATA_TAG %28, 12i, exit(1) + %30 = BUFFER_READF32 %26, 0i, tuserdata + %31 = BUFFER_READF32 %28, 0i, tuserdata + %32 = MUL_NUM %30, %31 + %33 = BUFFER_READF32 %26, 4i, tuserdata + %34 = BUFFER_READF32 %28, 4i, tuserdata + %35 = MUL_NUM %33, %34 + %37 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %37, 0i, %32, tuserdata + BUFFER_WRITEF32 %37, 4i, %35, tuserdata + STORE_POINTER R5, %37 + STORE_TAG R5, tuserdata + %50 = BUFFER_READF32 %17, 0i, tuserdata + %51 = BUFFER_READF32 %37, 0i, tuserdata + %52 = ADD_NUM %50, %51 + %53 = BUFFER_READF32 %17, 4i, tuserdata + %54 = BUFFER_READF32 %37, 4i, tuserdata + %55 = ADD_NUM %53, %54 + %57 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %57, 0i, %52, tuserdata + BUFFER_WRITEF32 %57, 4i, %55, tuserdata + STORE_POINTER R3, %57 + STORE_TAG R3, tuserdata + INTERRUPT 3u + RETURN R3, 1i +)" + ); +} + +TEST_SUITE_END(); diff --git a/tests/IrRegAllocX64.test.cpp b/tests/IrRegAllocX64.test.cpp new file mode 100644 index 000000000..b4b63f4bf --- /dev/null +++ b/tests/IrRegAllocX64.test.cpp @@ -0,0 +1,58 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrRegAllocX64.h" + +#include "doctest.h" + +using namespace Luau::CodeGen; +using namespace Luau::CodeGen::X64; + +class IrRegAllocX64Fixture +{ +public: + IrRegAllocX64Fixture() + : build(/* logText */ true, ABIX64::Windows) + , regs(build, function, nullptr) + { + } + + void checkMatch(std::string expected) + { + build.finalize(); + + CHECK("\n" + build.text == expected); + } + + AssemblyBuilderX64 build; + IrFunction function; + IrRegAllocX64 regs; +}; + +TEST_SUITE_BEGIN("IrRegAllocX64"); + +TEST_CASE_FIXTURE(IrRegAllocX64Fixture, "RelocateFix") +{ + IrInst irInst0{IrCmd::LOAD_DOUBLE}; + irInst0.lastUse = 2; + function.instructions.push_back(irInst0); + + IrInst irInst1{IrCmd::LOAD_DOUBLE}; + irInst1.lastUse = 2; + function.instructions.push_back(irInst1); + + function.instructions[0].regX64 = regs.takeReg(rax, 0); + regs.preserve(function.instructions[0]); + + function.instructions[1].regX64 = regs.takeReg(rax, 1); + regs.restore(function.instructions[0], true); + + LUAU_ASSERT(function.instructions[0].regX64 == rax); + LUAU_ASSERT(function.instructions[1].spilled); + + checkMatch(R"( + vmovsd qword ptr [rsp+048h],rax + vmovsd qword ptr [rsp+050h],rax + vmovsd rax,qword ptr [rsp+048h] +)"); +} + +TEST_SUITE_END(); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp index c71d97d16..931c3d59a 100644 --- a/tests/LValue.test.cpp +++ b/tests/LValue.test.cpp @@ -10,23 +10,28 @@ using namespace Luau; static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) { - Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { - // TODO: normalize here also. - std::unordered_set s; - - if (auto utv = get(follow(a))) - s.insert(begin(utv), end(utv)); - else - s.insert(a); - - if (auto utv = get(follow(b))) - s.insert(begin(utv), end(utv)); - else - s.insert(b); - - std::vector options(s.begin(), s.end()); - return options.size() == 1 ? options[0] : arena.addType(UnionType{std::move(options)}); - }); + Luau::merge( + l, + r, + [&arena](TypeId a, TypeId b) -> TypeId + { + // TODO: normalize here also. + std::unordered_set s; + + if (auto utv = get(follow(a))) + s.insert(begin(utv), end(utv)); + else + s.insert(a); + + if (auto utv = get(follow(b))) + s.insert(begin(utv), end(utv)); + else + s.insert(b); + + std::vector options(s.begin(), s.end()); + return options.size() == 1 ? options[0] : arena.addType(UnionType{std::move(options)}); + } + ); } static LValue mkSymbol(const std::string& s) diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 7fcc1e542..e0716e4c5 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -157,8 +157,6 @@ TEST_CASE("string_interpolation_basic") TEST_CASE("string_interpolation_full") { - ScopedFastFlag sff("LuauFixInterpStringMid", true); - const std::string testInput = R"(`foo {"bar"} {"baz"} end`)"; Luau::Allocator alloc; AstNameTable table(alloc); @@ -194,13 +192,13 @@ TEST_CASE("string_interpolation_double_brace") auto brokenInterpBegin = lexer.next(); CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace); - CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo")); + CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.getLength()), std::string("foo")); CHECK_EQ(lexer.next().type, Lexeme::Name); auto interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); - CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar")); + CHECK_EQ(std::string(interpEnd.data, interpEnd.getLength()), std::string("}bar")); } TEST_CASE("string_interpolation_double_but_unmatched_brace") diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c716982ee..8647777aa 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -7,6 +7,10 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauNativeAttribute); +LUAU_FASTFLAG(LintRedundantNativeAttribute); + using namespace Luau; TEST_SUITE_BEGIN("Linter"); @@ -18,7 +22,18 @@ function fib(n) return n < 2 and 1 or fib(n-1) + fib(n-2) end -return math.max(fib(5), 1) +)"); + + REQUIRE(0 == result.warnings.size()); +} + +TEST_CASE_FIXTURE(Fixture, "type_function_fully_reduces") +{ + LintResult result = lint(R"( +function fib(n) + return n < 2 or fib(n-2) +end + )"); REQUIRE(0 == result.warnings.size()); @@ -35,7 +50,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(frontend, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); + addGlobalBinding(frontend.globals, "Wait", Binding{builtinTypes->anyType, {}, true, "wait", "@test/global/Wait"}); LintResult result = lint("Wait(5)"); @@ -47,7 +62,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") { // Normally this would be defined externally, so hack it in for testing const char* deprecationReplacementString = ""; - addGlobalBinding(frontend, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); + addGlobalBinding(frontend.globals, "Version", Binding{builtinTypes->anyType, {}, true, deprecationReplacementString}); LintResult result = lint("Version()"); @@ -298,8 +313,9 @@ fnB() -- prints "false", "nil" )"); REQUIRE(1 == result.warnings.size()); - CHECK_EQ(result.warnings[0].text, - "Global 'moreInternalLogic' is only used in the enclosing function defined at line 2; consider changing it to local"); + CHECK_EQ( + result.warnings[0].text, "Global 'moreInternalLogic' is only used in the enclosing function defined at line 2; consider changing it to local" + ); } TEST_CASE_FIXTURE(Fixture, "LocalShadowLocal") @@ -373,7 +389,7 @@ return bar() TEST_CASE_FIXTURE(Fixture, "ImportUnused") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(frontend, "game", typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@test"); LintResult result = lint(R"( local Roact = require(game.Packages.Roact) @@ -604,16 +620,16 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - unfreeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); TableType::Props instanceProps{ - {"ClassName", {typeChecker.anyType}}, + {"ClassName", {builtinTypes->anyType}}, }; - TableType instanceTable{instanceProps, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}; - TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); + TableType instanceTable{instanceProps, std::nullopt, frontend.globals.globalScope->level, Luau::TableState::Sealed}; + TypeId instanceType = frontend.globals.globalTypes.addType(instanceTable); TypeFun instanceTypeFun{{}, instanceType}; - typeChecker.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; + frontend.globals.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; LintResult result = lint(R"( local game = ... @@ -699,8 +715,10 @@ end CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop starts at 0, but arrays start at 1"); CHECK_EQ(result.warnings[1].location.begin.line, 7); - CHECK_EQ(result.warnings[1].text, - "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"); + CHECK_EQ( + result.warnings[1].text, + "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1" + ); } TEST_CASE_FIXTURE(Fixture, "UnbalancedAssignment") @@ -733,6 +751,7 @@ end TEST_CASE_FIXTURE(Fixture, "ImplicitReturn") { LintResult result = lint(R"( +--!nonstrict function f1(a) if not a then return 5 @@ -789,20 +808,27 @@ return f1,f2,f3,f4,f5,f6,f7 )"); REQUIRE(3 == result.warnings.size()); - CHECK_EQ(result.warnings[0].location.begin.line, 4); - CHECK_EQ(result.warnings[0].text, - "Function 'f1' can implicitly return no values even though there's an explicit return at line 4; add explicit return to silence"); - CHECK_EQ(result.warnings[1].location.begin.line, 28); - CHECK_EQ(result.warnings[1].text, - "Function 'f4' can implicitly return no values even though there's an explicit return at line 25; add explicit return to silence"); - CHECK_EQ(result.warnings[2].location.begin.line, 44); - CHECK_EQ(result.warnings[2].text, - "Function can implicitly return no values even though there's an explicit return at line 44; add explicit return to silence"); + CHECK_EQ(result.warnings[0].location.begin.line, 5); + CHECK_EQ( + result.warnings[0].text, + "Function 'f1' can implicitly return no values even though there's an explicit return at line 5; add explicit return to silence" + ); + CHECK_EQ(result.warnings[1].location.begin.line, 29); + CHECK_EQ( + result.warnings[1].text, + "Function 'f4' can implicitly return no values even though there's an explicit return at line 26; add explicit return to silence" + ); + CHECK_EQ(result.warnings[2].location.begin.line, 45); + CHECK_EQ( + result.warnings[2].text, + "Function can implicitly return no values even though there's an explicit return at line 45; add explicit return to silence" + ); } TEST_CASE_FIXTURE(Fixture, "ImplicitReturnInfiniteLoop") { LintResult result = lint(R"( +--!nonstrict function f1(a) while true do if math.random() > 0.5 then @@ -845,12 +871,16 @@ return f1,f2,f3,f4 )"); REQUIRE(2 == result.warnings.size()); - CHECK_EQ(result.warnings[0].location.begin.line, 25); - CHECK_EQ(result.warnings[0].text, - "Function 'f3' can implicitly return no values even though there's an explicit return at line 21; add explicit return to silence"); - CHECK_EQ(result.warnings[1].location.begin.line, 36); - CHECK_EQ(result.warnings[1].text, - "Function 'f4' can implicitly return no values even though there's an explicit return at line 32; add explicit return to silence"); + CHECK_EQ(result.warnings[0].location.begin.line, 26); + CHECK_EQ( + result.warnings[0].text, + "Function 'f3' can implicitly return no values even though there's an explicit return at line 22; add explicit return to silence" + ); + CHECK_EQ(result.warnings[1].location.begin.line, 37); + CHECK_EQ( + result.warnings[1].text, + "Function 'f4' can implicitly return no values even though there's an explicit return at line 33; add explicit return to silence" + ); } TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") @@ -1164,7 +1194,7 @@ os.date("!*t") TEST_CASE_FIXTURE(Fixture, "FormatStringTyped") { - LintResult result = lintTyped(R"~( + LintResult result = lint(R"~( local s: string, nons = ... string.match(s, "[]") @@ -1233,6 +1263,30 @@ _ = { CHECK_EQ(result.warnings[5].text, "Table index 1 is a duplicate; previously defined at line 36"); } +TEST_CASE_FIXTURE(Fixture, "read_write_table_props") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + LintResult result = lint(R"(-- line 1 + type A = {x: number} + type B = {read x: number, write x: number} + type C = {x: number, read x: number} -- line 4 + type D = {x: number, write x: number} + type E = {read x: number, x: boolean} + type F = {read x: number, read x: number} + type G = {write x: number, x: boolean} + type H = {write x: number, write x: boolean} + )"); + + REQUIRE(6 == result.warnings.size()); + CHECK(result.warnings[0].text == "Table type field 'x' is already read-write; previously defined at line 4"); + CHECK(result.warnings[1].text == "Table type field 'x' is already read-write; previously defined at line 5"); + CHECK(result.warnings[2].text == "Table type field 'x' already has a read type defined at line 6"); + CHECK(result.warnings[3].text == "Table type field 'x' is a duplicate; previously defined at line 7"); + CHECK(result.warnings[4].text == "Table type field 'x' already has a write type defined at line 8"); + CHECK(result.warnings[5].text == "Table type field 'x' is a duplicate; previously defined at line 9"); +} + TEST_CASE_FIXTURE(Fixture, "ImportOnlyUsedInTypeAnnotation") { LintResult result = lint(R"( @@ -1270,12 +1324,17 @@ TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") { ScopePtr testScope = frontend.addEnvironment("Test"); - unfreeze(typeChecker.globalTypes); - loadDefinitionFile(frontend.typeChecker, testScope, R"( + unfreeze(frontend.globals.globalTypes); + frontend.loadDefinitionFile( + frontend.globals, + testScope, + R"( declare Foo: number )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); fileResolver.environments["A"] = "Test"; @@ -1285,7 +1344,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") local _bar: typeof(os.clock) = os.clock )"; - LintResult result = frontend.lint("A"); + LintResult result = lintModule("A"); REQUIRE(0 == result.warnings.size()); } @@ -1354,7 +1413,8 @@ TEST_CASE_FIXTURE(Fixture, "DuplicateLocalFunction") options.enableWarning(LintWarning::Code_DuplicateFunction); options.enableWarning(LintWarning::Code_LocalShadow); - LintResult result = lint(R"( + LintResult result = lint( + R"( local function x() end print(x) @@ -1363,7 +1423,8 @@ TEST_CASE_FIXTURE(Fixture, "DuplicateLocalFunction") return x )", - options); + options + ); REQUIRE_EQ(1, result.warnings.size()); @@ -1442,35 +1503,34 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { - ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - - unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); + unfreeze(frontend.globals.globalTypes); + TypeId instanceType = frontend.globals.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test", {}}); persist(instanceType); - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + frontend.globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { - {"Name", {typeChecker.stringType}}, - {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, - {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, + {"Name", {builtinTypes->stringType}}, + {"DataCost", {builtinTypes->numberType, /* deprecated= */ true}}, + {"Wait", {builtinTypes->anyType, /* deprecated= */ true}}, }; - TypeId colorType = typeChecker.globalTypes.addType(TableType{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); + TypeId colorType = + frontend.globals.globalTypes.addType(TableType{{}, std::nullopt, frontend.globals.globalScope->level, Luau::TableState::Sealed}); - getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; + getMutable(colorType)->props = {{"toHSV", {builtinTypes->anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; - addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); + addGlobalBinding(frontend.globals, "Color3", Binding{colorType, {}}); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; ttv->props["getn"].deprecated = true; ttv->props["getn"].deprecatedSuggestion = "#"; } - freeze(typeChecker.globalTypes); + freeze(frontend.globals.globalTypes); - LintResult result = lintTyped(R"( + LintResult result = lint(R"( return function (i: Instance) i:Wait(1.0) print(i.Name) @@ -1493,9 +1553,7 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") { - ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; ttv->props["getn"].deprecated = true; @@ -1503,6 +1561,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") } LintResult result = lint(R"( +-- TODO return function () print(table.getn({})) table.foreach({}, function() end) @@ -1515,9 +1574,36 @@ end CHECK_EQ(result.warnings[1].text, "Member 'table.foreach' is deprecated"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiFenv") +{ + LintResult result = lint(R"( +local f, g, h = ... + +getfenv(1) +getfenv(f :: () -> ()) +getfenv(g :: number) +getfenv(h :: any) + +setfenv(1, {}) +setfenv(f :: () -> (), {}) +setfenv(g :: number, {}) +setfenv(h :: any, {}) +)"); + + REQUIRE(4 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Function 'getfenv' is deprecated; consider using 'debug.info' instead"); + CHECK_EQ(result.warnings[0].location.begin.line + 1, 4); + CHECK_EQ(result.warnings[1].text, "Function 'getfenv' is deprecated; consider using 'debug.info' instead"); + CHECK_EQ(result.warnings[1].location.begin.line + 1, 6); + CHECK_EQ(result.warnings[2].text, "Function 'setfenv' is deprecated"); + CHECK_EQ(result.warnings[2].location.begin.line + 1, 9); + CHECK_EQ(result.warnings[3].text, "Function 'setfenv' is deprecated"); + CHECK_EQ(result.warnings[3].location.begin.line + 1, 11); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") { - LintResult result = lintTyped(R"( + LintResult result = lint(R"( local t = {} local tt = {} @@ -1543,21 +1629,89 @@ table.create(42, {} :: {}) )"); REQUIRE(10 == result.warnings.size()); - CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " - "second argument or wrap it in parentheses to silence"); + CHECK_EQ( + result.warnings[0].text, + "table.insert will insert the value before the last element, which is likely a bug; consider removing the " + "second argument or wrap it in parentheses to silence" + ); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); CHECK_EQ(result.warnings[2].text, "table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"); CHECK_EQ(result.warnings[3].text, "table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"); - CHECK_EQ(result.warnings[4].text, "table.remove will remove the value before the last element, which is likely a bug; consider removing the " - "second argument or wrap it in parentheses to silence"); - CHECK_EQ(result.warnings[5].text, - "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); + CHECK_EQ( + result.warnings[4].text, + "table.remove will remove the value before the last element, which is likely a bug; consider removing the " + "second argument or wrap it in parentheses to silence" + ); + CHECK_EQ( + result.warnings[5].text, + "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument" + ); CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); CHECK_EQ( - result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead" + ); CHECK_EQ( - result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperationsIndexer") +{ + // CLI-116824 Linter incorrectly issues false positive when taking the length of a unannotated string function argument + if (FFlag::LuauSolverV2) + return; + + LintResult result = lint(R"( +local t1 = {} -- ok: empty +local t2 = {1, 2} -- ok: array +local t3 = { a = 1, b = 2 } -- not ok: dictionary +local t4: {[number]: number} = {} -- ok: array +local t5: {[string]: number} = {} -- not ok: dictionary +local t6: typeof(setmetatable({1, 2}, {})) = {} -- ok: table with metatable +local t7: string = "hello" -- ok: string +local t8: {number} | {n: number} = {} -- ok: union + +-- not ok +print(#t3) +print(#t5) +ipairs(t5) + +-- disabled +-- ipairs(t3) adds indexer to t3, silencing error on #t3 + +-- ok +print(#t1) +print(#t2) +print(#t4) +print(#t6) +print(#t7) +print(#t8) + +ipairs(t1) +ipairs(t2) +ipairs(t4) +ipairs(t6) +ipairs(t7) +ipairs(t8) + +-- ok, subtle: text is a string here implicitly, but the type annotation isn't available +-- type checker assigns a type of generic table with the 'sub' member; we don't emit warnings on generic tables +-- to avoid generating a false positive here +function _impliedstring(element, text) + for i = 1, #text do + element:sendText(text:sub(i, i)) + end +end +)"); + + REQUIRE(3 == result.warnings.size()); + CHECK_EQ(result.warnings[0].location.begin.line + 1, 12); + CHECK_EQ(result.warnings[0].text, "Using '#' on a table without an array part is likely a bug"); + CHECK_EQ(result.warnings[1].location.begin.line + 1, 13); + CHECK_EQ(result.warnings[1].text, "Using '#' on a table with string keys is likely a bug"); + CHECK_EQ(result.warnings[2].location.begin.line + 1, 14); + CHECK_EQ(result.warnings[2].text, "Using 'ipairs' on a table with string keys is likely a bug"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") @@ -1605,8 +1759,8 @@ TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsExpr") LintResult result = lint(R"( local correct, opaque = ... -if correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then -elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +if correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", `string {opaque}`)}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", `string {opaque}`)}) then elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", false)}) then end )"); @@ -1650,10 +1804,16 @@ _ = (math.random() < 0.5 and false) or 42 -- currently ignored )"); REQUIRE(2 == result.warnings.size()); - CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; " - "consider using if-then-else expression instead"); - CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; " - "consider using if-then-else expression instead"); + CHECK_EQ( + result.warnings[0].text, + "The and-or expression always evaluates to the second alternative because the first alternative is false; " + "consider using if-then-else expression instead" + ); + CHECK_EQ( + result.warnings[1].text, + "The and-or expression always evaluates to the second alternative because the first alternative is nil; " + "consider using if-then-else expression instead" + ); } TEST_CASE_FIXTURE(Fixture, "WrongComment") @@ -1667,17 +1827,19 @@ TEST_CASE_FIXTURE(Fixture, "WrongComment") --!nolint UnknownGlobal --! no more lint --!strict here +--!native on do end --!nolint )"); - REQUIRE(6 == result.warnings.size()); + REQUIRE(7 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Unknown comment directive 'struct'; did you mean 'strict'?"); CHECK_EQ(result.warnings[1].text, "Unknown comment directive 'nolintGlobal'"); CHECK_EQ(result.warnings[2].text, "nolint directive refers to unknown lint rule 'Global'"); CHECK_EQ(result.warnings[3].text, "nolint directive refers to unknown lint rule 'KnownGlobal'; did you mean 'UnknownGlobal'?"); CHECK_EQ(result.warnings[4].text, "Comment directive with the type checking mode has extra symbols at the end of the line"); - CHECK_EQ(result.warnings[5].text, "Comment directive is ignored because it is placed after the first non-comment token"); + CHECK_EQ(result.warnings[5].text, "native directive has extra symbols at the end of the line"); + CHECK_EQ(result.warnings[6].text, "Comment directive is ignored because it is placed after the first non-comment token"); } TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") @@ -1740,8 +1902,67 @@ local _ = 0x10000000000000000 )"); REQUIRE(2 == result.warnings.size()); - CHECK_EQ(result.warnings[0].text, "Binary number literal exceeded available precision and has been truncated to 2^64"); - CHECK_EQ(result.warnings[1].text, "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); + CHECK_EQ(result.warnings[0].text, "Binary number literal exceeded available precision and was truncated to 2^64"); + CHECK_EQ(result.warnings[1].text, "Hexadecimal number literal exceeded available precision and was truncated to 2^64"); +} + +TEST_CASE_FIXTURE(Fixture, "IntegerParsingDecimalImprecise") +{ + LintResult result = lint(R"( +local _ = 10000000000000000000000000000000000000000000000000000000000000000 +local _ = 10000000000000001 +local _ = -10000000000000001 + +-- 10^16 = 2^16 * 5^16, 5^16 only requires 38 bits +local _ = 10000000000000000 +local _ = -10000000000000000 + +-- smallest possible number that is parsed imprecisely +local _ = 9007199254740993 +local _ = -9007199254740993 + +-- note that numbers before and after parse precisely (number after is even => 1 more mantissa bit) +local _ = 9007199254740992 +local _ = 9007199254740994 + +-- large powers of two should work as well (this is 2^63) +local _ = -9223372036854775808 +)"); + + REQUIRE(5 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Number literal exceeded available precision and was truncated to closest representable number"); + CHECK_EQ(result.warnings[0].location.begin.line, 1); + CHECK_EQ(result.warnings[1].text, "Number literal exceeded available precision and was truncated to closest representable number"); + CHECK_EQ(result.warnings[1].location.begin.line, 2); + CHECK_EQ(result.warnings[2].text, "Number literal exceeded available precision and was truncated to closest representable number"); + CHECK_EQ(result.warnings[2].location.begin.line, 3); + CHECK_EQ(result.warnings[3].text, "Number literal exceeded available precision and was truncated to closest representable number"); + CHECK_EQ(result.warnings[3].location.begin.line, 10); + CHECK_EQ(result.warnings[4].text, "Number literal exceeded available precision and was truncated to closest representable number"); + CHECK_EQ(result.warnings[4].location.begin.line, 11); +} + +TEST_CASE_FIXTURE(Fixture, "IntegerParsingHexImprecise") +{ + LintResult result = lint(R"( +local _ = 0x1234567812345678 + +-- smallest possible number that is parsed imprecisely +local _ = 0x20000000000001 + +-- note that numbers before and after parse precisely (number after is even => 1 more mantissa bit) +local _ = 0x20000000000000 +local _ = 0x20000000000002 + +-- large powers of two should work as well (this is 2^63) +local _ = 0x80000000000000 +)"); + + REQUIRE(2 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Number literal exceeded available precision and was truncated to closest representable number"); + CHECK_EQ(result.warnings[0].location.begin.line, 1); + CHECK_EQ(result.warnings[1].text, "Number literal exceeded available precision and was truncated to closest representable number"); + CHECK_EQ(result.warnings[1].location.begin.line, 4); } TEST_CASE_FIXTURE(Fixture, "ComparisonPrecedence") @@ -1776,4 +1997,32 @@ local _ = a <= (b == 0) CHECK_EQ(result.warnings[4].text, "X <= Y <= Z is equivalent to (X <= Y) <= Z; did you mean X <= Y and Y <= Z?"); } +TEST_CASE_FIXTURE(Fixture, "RedundantNativeAttribute") +{ + ScopedFastFlag sff[] = {{FFlag::LuauNativeAttribute, true}, {FFlag::LintRedundantNativeAttribute, true}}; + + LintResult result = lint(R"( +--!native + +@native +local function f(a) + @native + local function g(b) + return (a + b) + end + return g +end + +f(3)(4) +)"); + + REQUIRE(2 == result.warnings.size()); + + CHECK_EQ(result.warnings[0].text, "native attribute on a function is redundant in a native module; consider removing it"); + CHECK_EQ(result.warnings[0].location, Location(Position(3, 0), Position(3, 7))); + + CHECK_EQ(result.warnings[1].text, "native attribute on a function is redundant in a native module; consider removing it"); + CHECK_EQ(result.warnings[1].location, Location(Position(5, 4), Position(5, 11))); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2c45cc385..4519ba823 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,16 +1,19 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/Module.h" -#include "Luau/Scope.h" -#include "Luau/RecursionCounter.h" +#include "Luau/Parser.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(DebugLuauFreezeArena); +LUAU_FASTINT(LuauTypeCloneIterationLimit); TEST_SUITE_BEGIN("ModuleTests"); @@ -42,25 +45,57 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment") CHECK(!isWithinComment(*sm, Position{7, 11})); } +TEST_CASE_FIXTURE(Fixture, "is_within_comment_parse_result") +{ + std::string src = R"( + --!strict + local foo = {} + function foo:bar() end + + --[[ + foo: + ]] foo:bar() + + --[[]]--[[]] -- Two distinct comments that have zero characters of space between them. + )"; + + Luau::Allocator alloc; + Luau::AstNameTable names{alloc}; + Luau::ParseOptions parseOptions; + parseOptions.captureComments = true; + Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), names, alloc, parseOptions); + + CHECK_EQ(5, parseResult.commentLocations.size()); + + CHECK(isWithinComment(parseResult, Position{1, 15})); + CHECK(isWithinComment(parseResult, Position{6, 16})); + CHECK(isWithinComment(parseResult, Position{9, 13})); + CHECK(isWithinComment(parseResult, Position{9, 14})); + + CHECK(!isWithinComment(parseResult, Position{2, 15})); + CHECK(!isWithinComment(parseResult, Position{7, 10})); + CHECK(!isWithinComment(parseResult, Position{7, 11})); +} + TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); - CHECK_EQ(newNumber, typeChecker.numberType); + TypeId newNumber = clone(builtinTypes->numberType, dest, cloneState); + CHECK_EQ(newNumber, builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; // Create a new number type that isn't persistent - unfreeze(typeChecker.globalTypes); - TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveType{PrimitiveType::Number}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldNumber = frontend.globals.globalTypes.addType(PrimitiveType{PrimitiveType::Number}); + freeze(frontend.globals.globalTypes); TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); @@ -76,7 +111,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") // not, but it's tangental to the core purpose of this test. ScopedFastFlag sff[] = { - {"DebugLuauDeferredConstraintResolution", false}, + {FFlag::LuauSolverV2, false}, }; CheckResult result = check(R"( @@ -96,7 +131,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId ty = requireType("Cyclic"); TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloneTy = clone(ty, dest, cloneState); TableType* ttv = getMutable(cloneTy); @@ -104,7 +139,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CHECK_EQ(std::optional{"Cyclic"}, ttv->syntheticName); - TypeId methodType = ttv->props["get"].type; + TypeId methodType = ttv->props["get"].type(); REQUIRE(methodType != nullptr); const FunctionType* ftv = get(methodType); @@ -128,16 +163,16 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2") TypeId methodTy = src.addType(FunctionType{src.addTypePack({}), src.addTypePack({tableTy})}); - tt->props["get"].type = methodTy; + tt->props["get"].setType(methodTy); TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloneTy = clone(tableTy, dest, cloneState); TableType* ctt = getMutable(cloneTy); REQUIRE(ctt); - TypeId clonedMethodType = ctt->props["get"].type; + TypeId clonedMethodType = ctt->props["get"].type(); REQUIRE(clonedMethodType); const FunctionType* cmf = get(clonedMethodType); @@ -166,24 +201,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") TableType* exportsTable = getMutable(*exports); REQUIRE(exportsTable != nullptr); - TypeId signType = exportsTable->props["sign"].type; + TypeId signType = exportsTable->props["sign"].type(); REQUIRE(signType != nullptr); CHECK(!isInArena(signType, module->interfaceTypes)); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK(isInArena(signType, frontend.globalTypes)); - else - CHECK(isInArena(signType, typeChecker.globalTypes)); + CHECK(isInArena(signType, frontend.globals.globalTypes)); } TEST_CASE_FIXTURE(Fixture, "deepClone_union") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; - unfreeze(typeChecker.globalTypes); - TypeId oldUnion = typeChecker.globalTypes.addType(UnionType{{typeChecker.numberType, typeChecker.stringType}}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldUnion = frontend.globals.globalTypes.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + freeze(frontend.globals.globalTypes); TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); @@ -194,11 +226,11 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; - unfreeze(typeChecker.globalTypes); - TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionType{{typeChecker.numberType, typeChecker.stringType}}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldIntersection = frontend.globals.globalTypes.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); + freeze(frontend.globals.globalTypes); TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); @@ -208,20 +240,34 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") TEST_CASE_FIXTURE(Fixture, "clone_class") { - Type exampleMetaClass{ClassType{"ExampleClassMeta", + Type exampleMetaClass{ClassType{ + "ExampleClassMeta", { - {"__add", {typeChecker.anyType}}, + {"__add", {builtinTypes->anyType}}, }, - std::nullopt, std::nullopt, {}, {}, "Test"}}; - Type exampleClass{ClassType{"ExampleClass", + std::nullopt, + std::nullopt, + {}, + {}, + "Test", + {} + }}; + Type exampleClass{ClassType{ + "ExampleClass", { - {"PropOne", {typeChecker.numberType}}, - {"PropTwo", {typeChecker.stringType}}, + {"PropOne", {builtinTypes->numberType}}, + {"PropTwo", {builtinTypes->stringType}}, }, - std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; + std::nullopt, + &exampleMetaClass, + {}, + {}, + "Test", + {} + }}; TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloned = clone(&exampleClass, dest, cloneState); const ClassType* ctv = get(cloned); @@ -237,16 +283,19 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - Type freeTy(FreeType{TypeLevel{}}); + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + TypeArena arena; + TypeId freeTy = freshType(NotNull{&arena}, builtinTypes, nullptr); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; - TypeId clonedTy = clone(&freeTy, dest, cloneState); + TypeId clonedTy = clone(freeTy, dest, cloneState); CHECK(get(clonedTy)); - cloneState = {}; + cloneState = {builtinTypes}; TypePackId clonedTp = clone(&freeTp, dest, cloneState); CHECK(get(clonedTp)); } @@ -258,7 +307,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_tables") ttv->state = TableState::Free; TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloned = clone(&tableTy, dest, cloneState); const TableType* clonedTtv = get(cloned); @@ -267,6 +316,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_tables") TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { + // CLI-117082 ModuleTests.clone_self_property we don't infer self correctly, instead replacing it with unknown. + if (FFlag::LuauSolverV2) + return; fileResolver.source["Module/A"] = R"( --!nonstrict local a = {} @@ -292,32 +344,59 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") +TEST_CASE_FIXTURE(Fixture, "clone_iteration_limit") { -#if defined(_DEBUG) || defined(_NOOPT) - int limit = 250; -#else - int limit = 400; -#endif - ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + ScopedFastInt sfi{FInt::LuauTypeCloneIterationLimit, 2000}; TypeArena src; TypeId table = src.addType(TableType{}); TypeId nested = table; - for (int i = 0; i < limit + 100; i++) + int nesting = 2500; + for (int i = 0; i < nesting; i++) { TableType* ttv = getMutable(nested); - - ttv->props["a"].type = src.addType(TableType{}); - nested = ttv->props["a"].type; + ttv->props["a"].setType(src.addType(TableType{})); + nested = ttv->props["a"].type(); } TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; + + TypeId ty = clone(table, dest, cloneState); + CHECK(get(ty)); + + // Cloning it again is an important test. + TypeId ty2 = clone(table, dest, cloneState); + CHECK(get(ty2)); +} + +// Unions should never be cyclic, but we should clone them correctly even if +// they are. +TEST_CASE_FIXTURE(Fixture, "clone_cyclic_union") +{ + TypeArena src; + + TypeId u = src.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + UnionType* uu = getMutable(u); + REQUIRE(uu); + + uu->options.push_back(u); + + TypeArena dest; + CloneState cloneState{builtinTypes}; + + TypeId cloned = clone(u, dest, cloneState); + REQUIRE(cloned); + + const UnionType* clonedUnion = get(cloned); + REQUIRE(clonedUnion); + REQUIRE(3 == clonedUnion->options.size()); - CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); + CHECK(builtinTypes->numberType == clonedUnion->options[0]); + CHECK(builtinTypes->stringType == clonedUnion->options[1]); + CHECK(cloned == clonedUnion->options[2]); } TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") @@ -335,27 +414,24 @@ type B = A auto mod = frontend.moduleResolver.getModule("Module/A"); auto it = mod->exportedTypeBindings.find("A"); REQUIRE(it != mod->exportedTypeBindings.end()); - CHECK(toString(it->second.type) == "any"); + + if (FFlag::LuauSolverV2) + CHECK(toString(it->second.type) == "any"); + else + CHECK(toString(it->second.type) == "*error-type*"); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") { - ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess", true}, - {"LuauSubstitutionReentrant", true}, - {"LuauClassTypeVarsInSubstitution", true}, - {"LuauSubstitutionFixMissingFields", true}, - }; - fileResolver.source["Module/A"] = R"( -export type A = {p : number} -return {} + export type A = {p : number} + return {} )"; fileResolver.source["Module/B"] = R"( -local a = require(script.Parent.A) -export type B = {q : a.A} -return {} + local a = require(script.Parent.A) + export type B = {q : a.A} + return {} )"; CheckResult result = frontend.check("Module/B"); @@ -373,43 +449,97 @@ return {} TypeId typeB = modBiter->second.type; TableType* tableB = getMutable(typeB); REQUIRE(tableB); - CHECK(typeA == tableB->props["q"].type); + CHECK(typeA == tableB->props["q"].type()); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") { - ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess", true}, - {"LuauSubstitutionReentrant", true}, - {"LuauClassTypeVarsInSubstitution", true}, - {"LuauSubstitutionFixMissingFields", true}, - }; - fileResolver.source["Module/A"] = R"( -local exports = {a={p=5}} -return exports + local exports = {a={p=5}} + return exports )"; fileResolver.source["Module/B"] = R"( -local a = require(script.Parent.A) -local exports = {b=a.a} -return exports + local a = require(script.Parent.A) + local exports = {b=a.a} + return exports )"; CheckResult result = frontend.check("Module/B"); LUAU_REQUIRE_NO_ERRORS(result); ModulePtr modA = frontend.moduleResolver.getModule("Module/A"); - ModulePtr modB = frontend.moduleResolver.getModule("Module/B"); REQUIRE(modA); + ModulePtr modB = frontend.moduleResolver.getModule("Module/B"); REQUIRE(modB); + std::optional typeA = first(modA->returnType); - std::optional typeB = first(modB->returnType); REQUIRE(typeA); + std::optional typeB = first(modB->returnType); REQUIRE(typeB); + TableType* tableA = getMutable(*typeA); + REQUIRE_MESSAGE(tableA, "Expected a table, but got " << toString(*typeA)); TableType* tableB = getMutable(*typeB); - CHECK(tableA->props["a"].type == tableB->props["b"].type); + REQUIRE_MESSAGE(tableB, "Expected a table, but got " << toString(*typeB)); + + CHECK(tableA->props["a"].type() == tableB->props["b"].type()); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_table_bound_to_table_bound_to_table") +{ + TypeArena arena; + + TypeId a = arena.addType(TableType{TableState::Free, TypeLevel{}}); + getMutable(a)->name = "a"; + + TypeId b = arena.addType(TableType{TableState::Free, TypeLevel{}}); + getMutable(b)->name = "b"; + + TypeId c = arena.addType(TableType{TableState::Free, TypeLevel{}}); + getMutable(c)->name = "c"; + + getMutable(a)->boundTo = b; + getMutable(b)->boundTo = c; + + TypeArena dest; + CloneState state{builtinTypes}; + TypeId res = clone(a, dest, state); + + REQUIRE(dest.types.size() == 1); + + auto tableA = get(res); + REQUIRE_MESSAGE(tableA, "Expected table, got " << res); + REQUIRE(tableA->name == "c"); + REQUIRE(!tableA->boundTo); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_type_to_a_persistent_type") +{ + TypeArena arena; + + TypeId boundTo = arena.addType(BoundType{builtinTypes->numberType}); + REQUIRE(builtinTypes->numberType->persistent); + + TypeArena dest; + CloneState state{builtinTypes}; + TypeId res = clone(boundTo, dest, state); + + REQUIRE(res == follow(boundTo)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_typepack_to_a_persistent_typepack") +{ + TypeArena arena; + + TypePackId boundTo = arena.addTypePack(BoundTypePack{builtinTypes->neverTypePack}); + REQUIRE(builtinTypes->neverTypePack->persistent); + + TypeArena dest; + CloneState state{builtinTypes}; + TypePackId res = clone(boundTo, dest, state); + + REQUIRE(res == follow(boundTo)); } TEST_SUITE_END(); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp new file mode 100644 index 000000000..ffb440492 --- /dev/null +++ b/tests/NonStrictTypeChecker.test.cpp @@ -0,0 +1,579 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/NonStrictTypeChecker.h" + +#include "Fixture.h" + +#include "Luau/Ast.h" +#include "Luau/Common.h" +#include "Luau/IostreamHelpers.h" +#include "Luau/ModuleResolver.h" +#include "Luau/VisitType.h" + +#include "ScopedFlags.h" +#include "doctest.h" +#include + +using namespace Luau; + +#define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \ + do \ + { \ + auto pos_ = (pos); \ + bool foundErr = false; \ + int index = 0; \ + for (const auto& err : result.errors) \ + { \ + if (err.location.begin == pos_) \ + { \ + foundErr = true; \ + break; \ + } \ + index++; \ + } \ + REQUIRE_MESSAGE(foundErr, "Expected error at " << pos_); \ + idx = index; \ + } while (false) + +#define NONSTRICT_REQUIRE_CHECKED_ERR(pos, name, result) \ + do \ + { \ + int errIndex; \ + NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, errIndex); \ + auto err = get(result.errors[errIndex]); \ + REQUIRE(err != nullptr); \ + CHECK_EQ(err->checkedFunctionName, name); \ + } while (false) + +#define NONSTRICT_REQUIRE_FUNC_DEFINITION_ERR(pos, argname, result) \ + do \ + { \ + int errIndex; \ + NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, errIndex); \ + auto err = get(result.errors[errIndex]); \ + REQUIRE(err != nullptr); \ + CHECK_EQ(err->argument, argname); \ + } while (false) + + +struct NonStrictTypeCheckerFixture : Fixture +{ + + NonStrictTypeCheckerFixture() + { + registerHiddenTypes(&frontend); + registerTestTypes(); + } + + CheckResult checkNonStrict(const std::string& code) + { + ScopedFastFlag flags[] = { + {FFlag::LuauSolverV2, true}, + }; + LoadDefinitionFileResult res = loadDefinition(definitions); + LUAU_ASSERT(res.success); + return check(Mode::Nonstrict, code); + } + + CheckResult checkNonStrictModule(const std::string& moduleName) + { + ScopedFastFlag flags[] = { + {FFlag::LuauSolverV2, true}, + }; + LoadDefinitionFileResult res = loadDefinition(definitions); + LUAU_ASSERT(res.success); + return frontend.check(moduleName); + } + + std::string definitions = R"BUILTIN_SRC( +@checked declare function abs(n: number): number +@checked declare function lower(s: string): string +declare function cond() : boolean +@checked declare function contrived(n : Not) : number + +-- interesting types of things that we would like to mark as checked +@checked declare function onlyNums(...: number) : number +@checked declare function mixedArgs(x: string, ...: number) : number +@checked declare function optionalArg(x: string?) : number +declare foo: { + bar: @checked (number) -> number, +} + +@checked declare function optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number +@checked declare function optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number + +type DateTypeArg = { + year: number, + month: number, + day: number, + hour: number?, + min: number?, + sec: number?, + isdst: boolean?, +} + +declare os : { + time: @checked (time: DateTypeArg?) -> number +} + +@checked declare function require(target : any) : any +)BUILTIN_SRC"; +}; + +TEST_SUITE_BEGIN("NonStrictTypeCheckerTest"); + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "interesting_checked_functions") +{ + CheckResult result = checkNonStrict(R"( +onlyNums(1,1,1) +onlyNums(1, "a") + +mixedArgs("a", 1, 2) +mixedArgs(1, 1, 1) +mixedArgs("a", true) + +optionalArg(nil) +optionalArg("a") +optionalArg(3) +)"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 12), "onlyNums", result); // onlyNums(1, "a") + + NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 10), "mixedArgs", result); // mixedArgs(1, 1, 1) + NONSTRICT_REQUIRE_CHECKED_ERR(Position(6, 15), "mixedArgs", result); // mixedArgs("a", true) + + NONSTRICT_REQUIRE_CHECKED_ERR(Position(10, 12), "optionalArg", result); // optionalArg(3) +} + + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "simple_negation_caching_example") +{ + CheckResult result = checkNonStrict(R"( +local x = 3 +abs(x) +abs(x) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = checkNonStrict(R"( +local x = 3 +contrived(x) +contrived(x) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 10), "contrived", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 10), "contrived", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "simple_non_strict_failure") +{ + CheckResult result = checkNonStrict(R"BUILTIN_SRC( +abs("hi") +)BUILTIN_SRC"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(1, 4), "abs", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nested_function_calls_constant") +{ + CheckResult result = checkNonStrict(R"( +local x +abs(lower(x)) +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 4), "abs", result); +} + + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_warns_with_never_local") +{ + CheckResult result = checkNonStrict(R"( +local x : never +if cond() then + abs(x) +else + lower(x) +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 8), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 10), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_warns_nil_branches") +{ + auto result = checkNonStrict(R"( +local x +if cond() then + abs(x) +else + lower(x) +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 8), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 10), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_doesnt_warn_else_branch") +{ + auto result = checkNonStrict(R"( +local x : string = "hi" +if cond() then + abs(x) +else + lower(x) +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 8), "abs", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_no_else") +{ + CheckResult result = checkNonStrict(R"( +local x : string +if cond() then + abs(x) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_no_else_err_in_cond") +{ + CheckResult result = checkNonStrict(R"( +local x : string = "" +if abs(x) then + lower(x) +end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 7), "abs", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_expr_should_warn") +{ + CheckResult result = checkNonStrict(R"( +local x : never +local y = if cond() then abs(x) else lower(x) +)"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 29), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 43), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_expr_doesnt_warn_else_branch") +{ + CheckResult result = checkNonStrict(R"( +local x : string = "hi" +local y = if cond() then abs(x) else lower(x) +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 29), "abs", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "sequencing_if_checked_call") +{ + CheckResult result = checkNonStrict(R"( +local x +if cond() then + x = 5 +else + x = nil +end +lower(x) +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(7, 6), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_unrelated_checked_calls") +{ + CheckResult result = checkNonStrict(R"( +function h(x, y) + abs(x) + lower(y) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_basic_no_errors") +{ + CheckResult result = checkNonStrict(R"( +function f(x) + abs(x) +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_basic_errors") +{ + CheckResult result = checkNonStrict(R"( +function f(x : string) + abs(x) +end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 8), "abs", result); +} + + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_failure") +{ + CheckResult result = checkNonStrict(R"( +function f(x) + abs(lower(x)) +end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 8), "abs", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_sequencing_errors") +{ + CheckResult result = checkNonStrict(R"( +function f(x) + abs(x) + lower(x) +end +)"); + LUAU_REQUIRE_ERROR_COUNT(3, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 8), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 10), "lower", result); + NONSTRICT_REQUIRE_FUNC_DEFINITION_ERR(Position(1, 11), "x", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "local_fn_produces_error") +{ + CheckResult result = checkNonStrict(R"( +local x = 5 +local function y() lower(x) end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 25), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "fn_expr_produces_error") +{ + CheckResult result = checkNonStrict(R"( +local x = 5 +local y = function() lower(x) end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 27), "lower", result); +} + + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_if_warns_never") +{ + CheckResult result = checkNonStrict(R"( +function f(x) + if cond() then + abs(x) + else + lower(x) + end +end +)"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 12), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 14), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_if_no_else") +{ + CheckResult result = checkNonStrict(R"( +function f(x) + if cond() then + abs(x) + end +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_if_assignment_errors") +{ + CheckResult result = checkNonStrict(R"( +function f(x) + if cond() then + x = 5 + else + x = nil + end + lower(x) +end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(7, 10), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_if_assignment_no_errors") +{ + CheckResult result = checkNonStrict(R"( +function f(x : string | number) + if cond() then + x = 5 + else + x = "hi" + end + abs(x) +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "local_only_one_warning") +{ + CheckResult result = checkNonStrict(R"( +local x = 5 +lower(x) +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 6), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "phi_node_assignment") +{ + CheckResult result = checkNonStrict(R"( +local x = "a" -- x1 +if cond() then + x = 3 -- x2 +end +lower(x) -- phi {x1, x2} +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} // + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "phi_node_assignment_err") +{ + CheckResult result = checkNonStrict(R"( +local x = nil +if cond() then + if cond() then + x = 5 + end + abs(x) +else + lower(x) +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(8, 10), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "tblprop_is_checked") +{ + CheckResult result = checkNonStrict(R"( +foo.bar("hi") +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(1, 8), "foo.bar", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "incorrect_arg_count") +{ + CheckResult result = checkNonStrict(R"( +foo.bar(1,2,3) +abs(3, "hi"); +)"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + auto r1 = get(result.errors[0]); + auto r2 = get(result.errors[1]); + LUAU_ASSERT(r1); + LUAU_ASSERT(r2); + CHECK_EQ("abs", r1->functionName); + CHECK_EQ("foo.bar", r2->functionName); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "optionals_in_checked_function_can_be_omitted") +{ + CheckResult result = checkNonStrict(R"( +optionalArgsAtTheEnd1("a") +optionalArgsAtTheEnd1("a", 3) +optionalArgsAtTheEnd1("a", nil, 3) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "optionals_in_checked_function_in_middle_cannot_be_omitted") +{ + CheckResult result = checkNonStrict(R"( +optionalArgsAtTheEnd2("a", "a") -- error +optionalArgsAtTheEnd2("a", nil, "b") +optionalArgsAtTheEnd2("a", 3, "b") +optionalArgsAtTheEnd2("a", "b", "c") -- error +)"); + LUAU_REQUIRE_ERROR_COUNT(3, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(1, 27), "optionalArgsAtTheEnd2", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(4, 27), "optionalArgsAtTheEnd2", result); + auto r1 = get(result.errors[2]); + LUAU_ASSERT(r1); + CHECK_EQ(3, r1->expected); + CHECK_EQ(2, r1->actual); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "non_testable_type_throws_ice") +{ + CHECK_THROWS_AS( + checkNonStrict(R"( +os.time({year = 0, month = 0, day = 0, min = 0, isdst = nil}) +)"), + Luau::InternalCompilerError + ); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "non_strict_shouldnt_warn_on_require_module") +{ + fileResolver.source["Modules/A"] = R"( +--!strict +type t = {x : number} +local e : t = {x = 3} +return e +)"; + fileResolver.sourceTypes["Modules/A"] = SourceCode::Module; + + fileResolver.source["Modules/B"] = R"( +--!nonstrict +local E = require(script.Parent.A) +)"; + + CheckResult result = checkNonStrictModule("Modules/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use") +{ + loadDefinition(R"( +declare buffer: { + create: @checked (size: number) -> buffer, + readi8: @checked (b: buffer, offset: number) -> number, + writef64: @checked (b: buffer, offset: number, value: number) -> (), +} +)"); + + CheckResult result = checkNonStrict(R"( +local b = buffer.create(100) +buffer.writef64(b, 0, 5) +buffer.readi8(b, 0) +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index a84e26381..3acd39098 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -5,6 +5,7 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" #include @@ -15,6 +16,7 @@ TEST_SUITE_BEGIN("NonstrictModeTests"); TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict function foo(x, y) end @@ -37,6 +39,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_could_return") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict function getMinCardCountForWidth(width) @@ -64,7 +67,7 @@ TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); + REQUIRE_NE(*builtinTypes->anyType, *requireType("foo")); } #endif @@ -100,6 +103,7 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types_are_ok") TEST_CASE_FIXTURE(Fixture, "locals_are_any_by_default") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict local m = 55 @@ -107,7 +111,7 @@ TEST_CASE_FIXTURE(Fixture, "locals_are_any_by_default") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("m")); + CHECK_EQ(*builtinTypes->anyType, *requireType("m")); } TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") @@ -126,6 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict local T = {} @@ -143,6 +148,7 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict local T = {} @@ -157,6 +163,7 @@ TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") TEST_CASE_FIXTURE(Fixture, "table_props_are_any") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict local T = {} @@ -170,14 +177,15 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") REQUIRE(ttv != nullptr); REQUIRE(ttv->props.count("foo")); - TypeId fooProp = ttv->props["foo"].type; + TypeId fooProp = ttv->props["foo"].type(); REQUIRE(fooProp != nullptr); - CHECK_EQ(*fooProp, *typeChecker.anyType); + CHECK_EQ(*fooProp, *builtinTypes->anyType); } TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict local T = { @@ -192,9 +200,9 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") TableType* ttv = getMutable(requireType("T")); REQUIRE_MESSAGE(ttv, "Should be a table: " << toString(requireType("T"))); - CHECK_EQ(*typeChecker.anyType, *ttv->props["one"].type); - CHECK_EQ(*typeChecker.anyType, *ttv->props["two"].type); - CHECK_MESSAGE(get(follow(ttv->props["three"].type)), "Should be a function: " << *ttv->props["three"].type); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["one"].type()); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["two"].type()); + CHECK_MESSAGE(get(follow(ttv->props["three"].type())), "Should be a function: " << *ttv->props["three"].type()); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") @@ -253,6 +261,7 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( --!nonstrict diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index c45932c6f..24186c0a0 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -10,8 +10,9 @@ #include "Luau/Normalize.h" #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) - +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauNormalizationTracksCyclicPairsThroughInhabitance) using namespace Luau; namespace @@ -74,22 +75,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "functions") CHECK(isSubtype(a, d)); } -TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any") -{ - check(R"( - function a(n: number) return "string" end - function b(q: any) return 5 :: any end - )"); - - TypeId a = requireType("a"); - TypeId b = requireType("b"); - - // any makes things work even when it makes no sense. - - CHECK(isSubtype(b, a)); - CHECK(isSubtype(a, b)); -} - TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head") { check(R"( @@ -157,7 +142,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_union_prop") TypeId a = requireType("a"); TypeId b = requireType("b"); - CHECK(isSubtype(a, b)); + if (FFlag::LuauSolverV2) + CHECK(!isSubtype(a, b)); // table properties are invariant + else + CHECK(isSubtype(a, b)); CHECK(!isSubtype(b, a)); } @@ -171,8 +159,11 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") TypeId a = requireType("a"); TypeId b = requireType("b"); - CHECK(isSubtype(a, b)); - CHECK(isSubtype(b, a)); + if (FFlag::LuauSolverV2) + CHECK(!isSubtype(a, b)); // table properties are invariant + else + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); } TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") @@ -228,8 +219,11 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") TypeId c = requireType("c"); TypeId d = requireType("d"); - CHECK(isSubtype(a, b)); - CHECK(isSubtype(b, a)); + if (FFlag::LuauSolverV2) + CHECK(!isSubtype(a, b)); // table properties are invariant + else + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); CHECK(!isSubtype(c, a)); CHECK(!isSubtype(a, c)); @@ -237,7 +231,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") CHECK(isSubtype(d, a)); CHECK(!isSubtype(a, d)); - CHECK(isSubtype(d, b)); + if (FFlag::LuauSolverV2) + CHECK(!isSubtype(d, b)); // table properties are invariant + else + CHECK(isSubtype(d, b)); CHECK(!isSubtype(b, d)); } @@ -325,9 +322,9 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "classes") check(""); // Ensure that we have a main Module. - TypeId p = typeChecker.globalScope->lookupType("Parent")->type; - TypeId c = typeChecker.globalScope->lookupType("Child")->type; - TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + TypeId p = frontend.globals.globalScope->lookupType("Parent")->type; + TypeId c = frontend.globals.globalScope->lookupType("Child")->type; + TypeId u = frontend.globals.globalScope->lookupType("Unrelated")->type; CHECK(isSubtype(c, p)); CHECK(!isSubtype(p, c)); @@ -358,30 +355,106 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "metatable" * doctest::expected_failures{1}) } #endif +TEST_CASE_FIXTURE(IsSubtypeFixture, "any_is_unknown_union_error") +{ + check(R"( + local err = 5.nope.nope -- err is now an error type + local a : any + local b : (unknown | typeof(err)) + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); + CHECK_EQ("*error-type*", toString(requireType("err"))); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "any_intersect_T_is_T") +{ + check(R"( + local a : (any & string) + local b : string + local c : number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(c, a)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "error_suppression") +{ + check(""); + + TypeId any = builtinTypes->anyType; + TypeId err = builtinTypes->errorType; + TypeId str = builtinTypes->stringType; + TypeId unk = builtinTypes->unknownType; + + CHECK(!isSubtype(any, err)); + CHECK(isSubtype(err, any)); + + CHECK(!isSubtype(any, str)); + CHECK(isSubtype(str, any)); + + // We have added this as an exception - the set of inhabitants of any is exactly the set of inhabitants of unknown (since error has no + // inhabitants). any = err | unknown, so under semantic subtyping, {} U unknown = unknown + if (FFlag::LuauSolverV2) + { + CHECK(isSubtype(any, unk)); + } + else + { + CHECK(!isSubtype(any, unk)); + } + + if (FFlag::LuauSolverV2) + { + CHECK(isSubtype(err, str)); + } + else + { + CHECK(!isSubtype(err, str)); + } + + CHECK(!isSubtype(str, err)); + + CHECK(!isSubtype(err, unk)); + CHECK(!isSubtype(unk, err)); + + CHECK(isSubtype(str, unk)); + CHECK(!isSubtype(unk, str)); +} + TEST_SUITE_END(); struct NormalizeFixture : Fixture { - ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true}; - ScopedFastFlag sff2{"LuauNegatedClassTypes", true}; - TypeArena arena; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}}; + Scope globalScope{builtinTypes->anyTypePack}; NormalizeFixture() { registerHiddenTypes(&frontend); } - const NormalizedType* toNormalizedType(const std::string& annotation) + std::shared_ptr toNormalizedType(const std::string& annotation, int expectedErrors = 0) { normalizer.clearCaches(); CheckResult result = check("type _Res = " + annotation); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(expectedErrors, result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { SourceModule* sourceModule = getMainSourceModule(); REQUIRE(sourceModule); @@ -389,7 +462,7 @@ struct NormalizeFixture : Fixture REQUIRE(node); AstStatTypeAlias* alias = node->as(); REQUIRE(alias); - TypeId* originalTy = getMainModule()->astOriginalResolvedTypes.find(alias->type); + TypeId* originalTy = getMainModule()->astResolvedTypes.find(alias->type); REQUIRE(originalTy); return normalizer.normalize(*originalTy); } @@ -403,7 +476,7 @@ struct NormalizeFixture : Fixture TypeId normal(const std::string& annotation) { - const NormalizedType* norm = toNormalizedType(annotation); + std::shared_ptr norm = toNormalizedType(annotation); REQUIRE(norm); return normalizer.typeFromNormal(*norm); } @@ -411,6 +484,71 @@ struct NormalizeFixture : Fixture TEST_SUITE_BEGIN("Normalize"); +TEST_CASE_FIXTURE(NormalizeFixture, "string_intersection_is_commutative") +{ + auto c4 = toString(normal(R"( + string & (string & Not<"a"> & Not<"b">) +)")); + auto c4Reverse = toString(normal(R"( + (string & Not<"a"> & Not<"b">) & string +)")); + CHECK(c4 == c4Reverse); + CHECK_EQ("string & ~\"a\" & ~\"b\"", c4); + + auto c5 = toString(normal(R"( + (string & Not<"a"> & Not<"b">) & (string & Not<"b"> & Not<"c">) +)")); + auto c5Reverse = toString(normal(R"( + (string & Not<"b"> & Not<"c">) & (string & Not<"a"> & Not<"c">) +)")); + CHECK(c5 == c5Reverse); + CHECK_EQ("string & ~\"a\" & ~\"b\" & ~\"c\"", c5); + + auto c6 = toString(normal(R"( + ("a" | "b") & (string & Not<"b"> & Not<"c">) +)")); + auto c6Reverse = toString(normal(R"( + (string & Not<"b"> & Not<"c">) & ("a" | "b") +)")); + CHECK(c6 == c6Reverse); + CHECK_EQ("\"a\"", c6); + + auto c7 = toString(normal(R"( + string & ("b" | "c") +)")); + auto c7Reverse = toString(normal(R"( + ("b" | "c") & string +)")); + CHECK(c7 == c7Reverse); + CHECK_EQ("\"b\" | \"c\"", c7); + + auto c8 = toString(normal(R"( +(string & Not<"a"> & Not<"b">) & ("b" | "c") +)")); + auto c8Reverse = toString(normal(R"( + ("b" | "c") & (string & Not<"a"> & Not<"b">) +)")); + CHECK(c8 == c8Reverse); + CHECK_EQ("\"c\"", c8); + auto c9 = toString(normal(R"( + ("a" | "b") & ("b" | "c") + )")); + auto c9Reverse = toString(normal(R"( + ("b" | "c") & ("a" | "b") + )")); + CHECK(c9 == c9Reverse); + CHECK_EQ("\"b\"", c9); + + auto l = toString(normal(R"( + (string | number) & ("a" | true) + )")); + auto r = toString(normal(R"( + ("a" | true) & (string | number) + )")); + CHECK(l == r); + CHECK_EQ("\"a\"", l); +} + TEST_CASE_FIXTURE(NormalizeFixture, "negate_string") { CHECK("number" == toString(normal(R"( @@ -525,30 +663,32 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function") TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; - - CHECK("(boolean | class | number | string | table | thread)?" == toString(normal(R"( + CHECK("(boolean | buffer | class | number | string | table | thread)?" == toString(normal(R"( Not )"))); } TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated") { - CHECK(nullptr == toNormalizedType("Not<(boolean) -> boolean>")); + CHECK(nullptr == toNormalizedType("Not<(boolean) -> boolean>", FFlag::LuauSolverV2 ? 1 : 0)); } -TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") +TEST_CASE_FIXTURE(NormalizeFixture, "trivial_intersection_inhabited") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; + // this test was used to fix a bug in normalization when working with intersections/unions of the same type. + + TypeId a = arena.addType(FunctionType{builtinTypes->emptyTypePack, builtinTypes->anyTypePack, std::nullopt, false}); + TypeId c = arena.addType(IntersectionType{{a, a}}); + + std::shared_ptr n = normalizer.normalize(c); + REQUIRE(n); + + CHECK(normalizer.isInhabited(n.get()) == NormalizationResult::True); +} - // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function - CHECK("(class | function | number | string | table | thread)?" == toString(normal(R"( +TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") +{ + CHECK("(buffer | class | function | number | string | table | thread)?" == toString(normal(R"( Not )"))); } @@ -569,6 +709,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation") { + // CLI-117088 - Inferring the type of a higher order function with an annotation sometimes doesn't fully constrain the type (there are free types + // left over). + if (FFlag::LuauSolverV2) + return; check(R"( function apply(f: (a) -> b, x) return f(x) @@ -616,8 +760,6 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Parent | Unrelated" == toString(normal("Parent | Unrelated"))); CHECK("Parent" == toString(normal("Parent | Child"))); @@ -626,8 +768,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Child" == toString(normal("Parent & Child"))); CHECK("never" == toString(normal("Child & Unrelated"))); @@ -635,61 +775,361 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); } -TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metatable_is_top_or_bottom") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; + if (FFlag::LuauSolverV2) + CHECK("{ @metatable *error-type*, { } }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); + else + CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); +} +TEST_CASE_FIXTURE(NormalizeFixture, "recurring_intersection") +{ + CheckResult result = check(R"( + type A = any? + type B = A & A + )"); + + std::optional t = lookupType("B"); + REQUIRE(t); + + std::shared_ptr nt = normalizer.normalize(*t); + REQUIRE(nt); + + CHECK("any" == toString(normalizer.typeFromNormal(*nt))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union") +{ + // T where T = any & (number | T) + TypeId t = arena.addType(BlockedType{}); + TypeId u = arena.addType(UnionType{{builtinTypes->numberType, t}}); + asMutable(t)->ty.emplace(IntersectionType{{builtinTypes->anyType, u}}); + + std::shared_ptr nt = normalizer.normalize(t); + REQUIRE(nt); + + CHECK("number" == toString(normalizer.typeFromNormal(*nt))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union_of_intersection") +{ + // t1 where t1 = (string & t1) | string + TypeId boundTy = arena.addType(BlockedType{}); + TypeId intersectTy = arena.addType(IntersectionType{{builtinTypes->stringType, boundTy}}); + TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, intersectTy}}); + asMutable(boundTy)->reassign(Type{BoundType{unionTy}}); + + std::shared_ptr nt = normalizer.normalize(unionTy); + + CHECK("string" == toString(normalizer.typeFromNormal(*nt))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_intersection_of_unions") +{ + // t1 where t1 = (string & t1) | string + TypeId boundTy = arena.addType(BlockedType{}); + TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, boundTy}}); + TypeId intersectionTy = arena.addType(IntersectionType{{builtinTypes->stringType, unionTy}}); + asMutable(boundTy)->reassign(Type{BoundType{intersectionTy}}); + + std::shared_ptr nt = normalizer.normalize(intersectionTy); + + CHECK("string" == toString(normalizer.typeFromNormal(*nt))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") +{ + CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") +{ createSomeClasses(&frontend); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); - CHECK("((class & ~Child) | boolean | function | number | string | table | thread)?" == toString(normal("Not"))); + CHECK("((class & ~Child) | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); CHECK("Child" == toString(normal("Not & Child"))); - CHECK("((class & ~Parent) | Child | boolean | function | number | string | table | thread)?" == toString(normal("Not | Child"))); - CHECK("(boolean | function | number | string | table | thread)?" == toString(normal("Not"))); - CHECK("(Parent | Unrelated | boolean | function | number | string | table | thread)?" == - toString(normal("Not & Not & Not>"))); + CHECK("((class & ~Parent) | Child | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not | Child"))); + CHECK("(boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); + CHECK( + "(Parent | Unrelated | boolean | buffer | function | number | string | table | thread)?" == + toString(normal("Not & Not & Not>")) + ); + + CHECK("Child" == toString(normal("(Child | Unrelated) & Not"))); } TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_unknown") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Parent" == toString(normal("Parent & unknown"))); } TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_never") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("never" == toString(normal("Parent & never"))); } TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type") { - ScopedFastFlag sff{"LuauNegatedTableTypes", true}; - CHECK("table" == toString(normal("{} | tbl"))); - CHECK("{| |}" == toString(normal("{} & tbl"))); + if (FFlag::LuauSolverV2) + CHECK("{ }" == toString(normal("{} & tbl"))); + else + CHECK("{| |}" == toString(normal("{} & tbl"))); CHECK("never" == toString(normal("number & tbl"))); } TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") { - ScopedFastFlag sff{"LuauNegatedTableTypes", true}; - - CHECK(nullptr == toNormalizedType("Not<{}>")); - CHECK("(boolean | class | function | number | string | thread)?" == toString(normal("Not"))); + CHECK(nullptr == toNormalizedType("Not<{}>", FFlag::LuauSolverV2 ? 1 : 0)); + CHECK("(boolean | buffer | class | function | number | string | thread)?" == toString(normal("Not"))); CHECK("table" == toString(normal("Not>"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") +{ + Type blocked{BlockedType{}}; + + std::shared_ptr norm = normalizer.normalize(&blocked); + + CHECK_EQ(normalizer.typeFromNormal(*norm), &blocked); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_is_exactly_number") +{ + std::shared_ptr number = normalizer.normalize(builtinTypes->numberType); + // 1. all types for which Types::number say true for, NormalizedType::isExactlyNumber should say true as well + CHECK(Luau::isNumber(builtinTypes->numberType) == number->isExactlyNumber()); + // 2. isExactlyNumber should handle cases like `number & number` + TypeId intersection = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->numberType}}); + std::shared_ptr normIntersection = normalizer.normalize(intersection); + CHECK(normIntersection->isExactlyNumber()); + + // 3. isExactlyNumber should reject things that are definitely not precisely numbers `number | any` + + TypeId yoonion = arena.addType(UnionType{{builtinTypes->anyType, builtinTypes->numberType}}); + std::shared_ptr unionIntersection = normalizer.normalize(yoonion); + CHECK(!unionIntersection->isExactlyNumber()); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_unknown") +{ + auto nt = toNormalizedType("Not | Not"); + CHECK(nt); + CHECK(nt->isUnknown()); + CHECK(toString(normalizer.typeFromNormal(*nt)) == "unknown"); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "read_only_props") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK("{ x: string }" == toString(normal("{ read x: string } & { x: string }"), {true})); + CHECK("{ x: string }" == toString(normal("{ x: string } & { read x: string }"), {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "read_only_props_2") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK(R"({ x: "hello" })" == toString(normal(R"({ x: "hello" } & { x: string })"), {true})); + CHECK(R"(never)" == toString(normal(R"({ x: "hello" } & { x: "world" })"), {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "read_only_props_3") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK(R"({ read x: "hello" })" == toString(normal(R"({ read x: "hello" } & { read x: string })"), {true})); + CHECK("never" == toString(normal(R"({ read x: "hello" } & { read x: "world" })"), {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "final_types_are_cached") +{ + std::shared_ptr na1 = normalizer.normalize(builtinTypes->numberType); + std::shared_ptr na2 = normalizer.normalize(builtinTypes->numberType); + + CHECK(na1 == na2); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "non_final_types_can_be_normalized_but_are_not_cached") +{ + TypeId a = arena.freshType(&globalScope); + + std::shared_ptr na1 = normalizer.normalize(a); + std::shared_ptr na2 = normalizer.normalize(a); + + CHECK(na1 != na2); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_with_not_unknown") +{ + TypeId notUnknown = arena.addType(NegationType{builtinTypes->unknownType}); + TypeId type = arena.addType(IntersectionType{{builtinTypes->numberType, notUnknown}}); + std::shared_ptr normalized = normalizer.normalize(type); + + CHECK("never" == toString(normalizer.typeFromNormal(*normalized.get()))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_stack_overflow_1") +{ + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 165}; + this->unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + TypeId t1 = arena.addType(TableType{}); + TypeId t2 = arena.addType(TableType{}); + TypeId t3 = arena.addType(IntersectionType{{t1, t2}}); + asMutable(t1)->ty.get_if()->props = {{"foo", Property::readonly(t2)}}; + asMutable(t2)->ty.get_if()->props = {{"foo", Property::readonly(t1)}}; + + std::shared_ptr normalized = normalizer.normalize(t3); + CHECK(normalized); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_stack_overflow_2") +{ + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 165}; + this->unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + TypeId t1 = arena.addType(TableType{}); + TypeId t2 = arena.addType(TableType{}); + TypeId t3 = arena.addType(IntersectionType{{t1, t2}}); + asMutable(t1)->ty.get_if()->props = {{"foo", Property::readonly(t3)}}; + asMutable(t2)->ty.get_if()->props = {{"foo", Property::readonly(t1)}}; + + std::shared_ptr normalized = normalizer.normalize(t3); + CHECK(normalized); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "truthy_table_property_and_optional_table_with_optional_prop") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + // { x: ~(false?) } + TypeId t1 = arena.addType(TableType{TableType::Props{{"x", builtinTypes->truthyType}}, std::nullopt, TypeLevel{}, TableState::Sealed}); + + // { x: number? }? + TypeId t2 = arena.addType(UnionType{ + {arena.addType(TableType{TableType::Props{{"x", builtinTypes->optionalNumberType}}, std::nullopt, TypeLevel{}, TableState::Sealed}), + builtinTypes->nilType} + }); + + TypeId intersection = arena.addType(IntersectionType{{t2, t1}}); + + auto norm = normalizer.normalize(intersection); + REQUIRE(norm); + + TypeId ty = normalizer.typeFromNormal(*norm); + CHECK("{ x: number }" == toString(ty)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "normalizer_should_be_able_to_detect_cyclic_tables_and_not_stack_overflow") +{ + if (!FFlag::LuauSolverV2) + return; + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0}; + ScopedFastFlag sff{FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance, true}; + CheckResult result = check(R"( +--!strict + +type Array = { [number] : T} +type Object = { [number] : any} + +type Set = typeof(setmetatable( + {} :: { + size: number, + -- method definitions + add: (self: Set, T) -> Set, + clear: (self: Set) -> (), + delete: (self: Set, T) -> boolean, + has: (self: Set, T) -> boolean, + ipairs: (self: Set) -> any, + }, + {} :: { + __index: Set, + __iter: (self: Set) -> (({ [K]: V }, K?) -> (K, V), T), + } +)) + +type Map = typeof(setmetatable( + {} :: { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + [K]: V, + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + _map: { [K]: V }, + _array: { [number]: K }, + __index: (self: Map, key: K) -> V, + __iter: (self: Map) -> (({ [K]: V }, K?) -> (K?, V), V), + __newindex: (self: Map, key: K, value: V) -> (), + }, + {} :: { + __index: Map, + __iter: (self: Map) -> (({ [K]: V }, K?) -> (K, V), V), + __newindex: (self: Map, key: K, value: V) -> (), + } +)) +type mapFn = (element: T, index: number) -> U +type mapFnWithThisArg = (thisArg: any, element: T, index: number) -> U + +function fromSet( + value: Set, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + + local array : { [number] : string} = {"foo"} + return array +end + +function instanceof(tbl: any, class: any): boolean + return true +end + +function fromArray( + value: Array, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + local array : {[number] : string} = {} + return array +end + +return function( + value: string | Array | Set | Map, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + if value == nil then + error("cannot create array from a nil value") + end + local array: Array | Array | Array + + if instanceof(value, Set) then + array = fromSet(value :: Set, mapFn, thisArg) + else + array = {} + end + + + return array +end +)"); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 9ff16d16b..dea628594 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -3,6 +3,7 @@ #include "AstQueryDsl.h" #include "Fixture.h" +#include "Luau/Common.h" #include "ScopedFlags.h" #include "doctest.h" @@ -11,6 +12,13 @@ using namespace Luau; +LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTINT(LuauTypeLengthLimit) +LUAU_FASTINT(LuauParseErrorLimit) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) + namespace { @@ -54,7 +62,8 @@ TEST_SUITE_BEGIN("AllocatorTests"); TEST_CASE("allocator_can_be_moved") { Counter* c = nullptr; - auto inner = [&]() { + auto inner = [&]() + { Luau::Allocator allocator; c = allocator.alloc(); Luau::Allocator moved{std::move(allocator)}; @@ -112,14 +121,6 @@ TEST_CASE_FIXTURE(Fixture, "can_haz_annotations") REQUIRE(block != nullptr); } -TEST_CASE_FIXTURE(Fixture, "local_cannot_have_annotation_with_extensions_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("local foo: string = \"Hello Types!\"", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "local_with_annotation") { AstStatBlock* block = parse(R"( @@ -150,14 +151,6 @@ TEST_CASE_FIXTURE(Fixture, "type_names_can_contain_dots") REQUIRE(block != nullptr); } -TEST_CASE_FIXTURE(Fixture, "functions_cannot_have_return_annotations_if_extensions_are_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("function foo(): number return 55 end", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") { AstStatBlock* block = parse(R"( @@ -395,14 +388,6 @@ TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_ CHECK(returnAnnotation->types.data[1]->as()); } -TEST_CASE_FIXTURE(Fixture, "illegal_type_alias_if_extensions_are_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("type A = number", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "type_alias_to_a_typeof") { AstStatBlock* block = parse(R"( @@ -478,46 +463,62 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_span_is_correct") TEST_CASE_FIXTURE(Fixture, "parse_error_messages") { - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: (number, number) -> (string )"), - "Expected ')' (to close '(' at line 2), got "); + "Expected ')' (to close '(' at line 2), got " + ); - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: (number, number) -> ( string )"), - "Expected ')' (to close '(' at line 2), got "); + "Expected ')' (to close '(' at line 2), got " + ); - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: (number, number) )"), - "Expected '->' when parsing function type, got "); + "Expected '->' when parsing function type, got " + ); - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: (number, number )"), - "Expected ')' (to close '(' at line 2), got "); + "Expected ')' (to close '(' at line 2), got " + ); - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: {foo: string, )"), - "Expected identifier when parsing table field, got "); + "Expected identifier when parsing table field, got " + ); - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: {foo: string )"), - "Expected '}' (to close '{' at line 2), got "); + "Expected '}' (to close '{' at line 2), got " + ); - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: { [string]: number, [number]: string } )"), - "Cannot have more than one table indexer"); + "Cannot have more than one table indexer" + ); - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( type T = foo )"), - "Expected '(' when parsing function parameters, got 'foo'"); + "Expected '(' when parsing function parameters, got 'foo'" + ); } TEST_CASE_FIXTURE(Fixture, "mixed_intersection_and_union_not_allowed") @@ -652,10 +653,12 @@ TEST_CASE_FIXTURE(Fixture, "vertical_space") TEST_CASE_FIXTURE(Fixture, "parse_error_type_name") { - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local a: Foo.= )"), - "Expected identifier when parsing field name, got '='"); + "Expected identifier when parsing field name, got '='" + ); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") @@ -717,10 +720,12 @@ TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") TEST_CASE_FIXTURE(Fixture, "error_on_unicode") { - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local ☃ = 10 )"), - "Expected identifier when parsing variable name, got Unicode character U+2603"); + "Expected identifier when parsing variable name, got Unicode character U+2603" + ); } TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") @@ -731,10 +736,12 @@ TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") TEST_CASE_FIXTURE(Fixture, "error_on_confusable") { - CHECK_EQ(getParseError(R"( + CHECK_EQ( + getParseError(R"( local pi = 3․13 )"), - "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)"); + "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)" + ); } TEST_CASE_FIXTURE(Fixture, "error_on_non_utf8_sequence") @@ -924,7 +931,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_begin") } catch (const ParseErrors& e) { - CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + CHECK_EQ("Double braces are not permitted within interpolated strings; did you mean '\\{'?", e.getErrors().front().getMessage()); } } @@ -939,13 +946,14 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") } catch (const ParseErrors& e) { - CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + CHECK_EQ("Double braces are not permitted within interpolated strings; did you mean '\\{'?", e.getErrors().front().getMessage()); } } TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") { - auto columnOfEndBraceError = [this](const char* code) { + auto columnOfEndBraceError = [this](const char* code) + { try { parse(code); @@ -957,7 +965,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") CHECK_EQ(e.getErrors().size(), 1); auto error = e.getErrors().front(); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", error.getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", error.getMessage()); return error.getLocation().begin.column; } }; @@ -980,7 +988,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace_in_table { CHECK_EQ(e.getErrors().size(), 2); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", e.getErrors().front().getMessage()); CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); } } @@ -998,7 +1006,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_mid_without_end_brace_in_t { CHECK_EQ(e.getErrors().size(), 2); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", e.getErrors().front().getMessage()); CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); } } @@ -1040,8 +1048,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_call_without_parens") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") { - ScopedFastFlag sff("LuauFixInterpStringMid", true); - try { parse(R"( @@ -1067,6 +1073,36 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") } } +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_malformed_escape") +{ + try + { + parse(R"( + local a = `???\xQQ {1}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Interpolated string literal contains malformed escape sequence", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_weird_token") +{ + try + { + parse(R"( + local a = `??? {42 !!}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Malformed interpolated string, got '!'", e.getErrors().front().getMessage()); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") { try @@ -1094,8 +1130,9 @@ end } catch (const ParseErrors& e) { - CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", - e.getErrors().front().getMessage()); + CHECK_EQ( + "Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", e.getErrors().front().getMessage() + ); } } @@ -1123,8 +1160,9 @@ end } catch (const ParseErrors& e) { - CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 3?", - e.getErrors().front().getMessage()); + CHECK_EQ( + "Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 3?", e.getErrors().front().getMessage() + ); } } @@ -1144,13 +1182,19 @@ until false } catch (const ParseErrors& e) { - CHECK_EQ("Expected 'until' (to close 'repeat' at line 2), got ; did you forget to close 'repeat' at line 4?", - e.getErrors().front().getMessage()); + CHECK_EQ( + "Expected 'until' (to close 'repeat' at line 2), got ; did you forget to close 'repeat' at line 4?", + e.getErrors().front().getMessage() + ); } } TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_local_function") { + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + try { parse(R"(-- i am line 1 @@ -1176,13 +1220,18 @@ end } catch (const ParseErrors& e) { - CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", - e.getErrors().front().getMessage()); + CHECK_EQ( + "Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", e.getErrors().front().getMessage() + ); } } TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_failsafe_earlier") { + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + try { parse(R"(-- i am line 1 @@ -1241,8 +1290,9 @@ end } catch (const ParseErrors& e) { - CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", - e.getErrors().front().getMessage()); + CHECK_EQ( + "Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", e.getErrors().front().getMessage() + ); } } @@ -1261,7 +1311,8 @@ end catch (const ParseErrors& e) { CHECK_EQ( - "Expected ')' (to close '(' at column 17), got '='; did you mean to use '{' when defining a table?", e.getErrors().front().getMessage()); + "Expected ')' (to close '(' at column 17), got '='; did you mean to use '{' when defining a table?", e.getErrors().front().getMessage() + ); } } @@ -1305,64 +1356,115 @@ end TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") { - ScopedFastInt sfis{"LuauRecursionLimit", 20}; + ScopedFastInt sfis{FInt::LuauRecursionLimit, 10}; + + matchParseError( + "function f(): ((((((((((Fail)))))))))) end", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile" + ); + + matchParseError( + "function f(): () -> () -> () -> () -> () -> () -> () -> () -> () -> () -> () end", + "Exceeded allowed recursion depth; simplify your type annotation to make the code compile" + ); + + matchParseError( + "local t: {a: {b: {c: {d: {e: {f: {g: {h: {i: {j: {}}}}}}}}}}}", + "Exceeded allowed recursion depth; simplify your type annotation to make the code compile" + ); + + matchParseError("local f: ((((((((((Fail))))))))))", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); matchParseError( - "function f(): (((((((((Fail))))))))) end", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + "local t: a & (b & (c & (d & (e & (f & (g & (h & (i & (j & nil)))))))))", + "Exceeded allowed recursion depth; simplify your type annotation to make the code compile" + ); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_complex_unions_successfully") +{ + ScopedFastInt sfis[] = {{FInt::LuauRecursionLimit, 10}, {FInt::LuauTypeLengthLimit, 10}}; - matchParseError("function f(): () -> () -> () -> () -> () -> () -> () -> () -> () -> () -> () end", - "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + parse(R"( +local f: +() -> () +| +() -> () +| +{a: number} +| +{b: number} +| +((number)) +| +((number)) +| +(a & (b & nil)) +| +(a & (b & nil)) +)"); + + parse(R"( +local f: a? | b? | c? | d? | e? | f? | g? | h? +)"); matchParseError( - "local t: {a: {b: {c: {d: {e: {f: {}}}}}}}", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + "local t: a & b & c & d & e & f & g & h & i & j & nil", "Exceeded allowed type length; simplify your type annotation to make the code compile" + ); } TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") { - ScopedFastInt sfis{"LuauRecursionLimit", 10}; + ScopedFastInt sfis{FInt::LuauRecursionLimit, 10}; matchParseErrorPrefix( "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " "then if true then if true then end end end end end end end end end end end end", - "Exceeded allowed recursion depth;"); + "Exceeded allowed recursion depth;" + ); } TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") { - ScopedFastInt sfis{"LuauRecursionLimit", 10}; + ScopedFastInt sfis{FInt::LuauRecursionLimit, 10}; matchParseErrorPrefix( "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " "false then elseif false then elseif false then elseif false then elseif false then end end", - "Exceeded allowed recursion depth;"); + "Exceeded allowed recursion depth;" + ); } TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") { - ScopedFastInt sfis{"LuauRecursionLimit", 10}; + ScopedFastInt sfis{FInt::LuauRecursionLimit, 10}; - matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " - "6 elseif true then 7 elseif true then 8 elseif true then 9 elseif true then 10 else 11 end", - "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + matchParseError( + "function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " + "6 elseif true then 7 elseif true then 8 elseif true then 9 elseif true then 10 else 11 end", + "Exceeded allowed recursion depth; simplify your expression to make the code compile" + ); } TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") { - ScopedFastInt sfis{"LuauRecursionLimit", 10}; + ScopedFastInt sfis{FInt::LuauRecursionLimit, 10}; matchParseError( "function f() return if if if if if if if if if if true then false else true then false else true then false else true then false else true " "then false else true then false else true then false else true then false else true then false else true then 1 else 2 end", - "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + "Exceeded allowed recursion depth; simplify your expression to make the code compile" + ); } TEST_CASE_FIXTURE(Fixture, "unparenthesized_function_return_type_list") { matchParseError( - "function foo(): string, number end", "Expected a statement, got ','; did you forget to wrap the list of return types in parentheses?"); + "function foo(): string, number end", "Expected a statement, got ','; did you forget to wrap the list of return types in parentheses?" + ); - matchParseError("function foo(): (number) -> string, string", - "Expected a statement, got ','; did you forget to wrap the list of return types in parentheses?"); + matchParseError( + "function foo(): (number) -> string, string", "Expected a statement, got ','; did you forget to wrap the list of return types in parentheses?" + ); // Will throw if the parse fails parse(R"( @@ -1587,9 +1689,9 @@ TEST_CASE_FIXTURE(Fixture, "string_literals_escapes_broken") TEST_CASE_FIXTURE(Fixture, "string_literals_broken") { - matchParseError("return \"", "Malformed string"); - matchParseError("return \"\\", "Malformed string"); - matchParseError("return \"\r\r", "Malformed string"); + matchParseError("return \"", "Malformed string; did you forget to finish it?"); + matchParseError("return \"\\", "Malformed string; did you forget to finish it?"); + matchParseError("return \"\r\r", "Malformed string; did you forget to finish it?"); } TEST_CASE_FIXTURE(Fixture, "number_literals") @@ -1670,12 +1772,14 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture ParseOptions opts; opts.captureComments = true; - AstStatBlock* block = parse(R"( + AstStatBlock* block = parse( + R"( type F = number --comment print('hello') )", - opts); + opts + ); REQUIRE_EQ(2, block->body.size); CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); @@ -1691,45 +1795,53 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") { - auto result1 = matchParseError(R"( + auto result1 = matchParseError( + R"( function add(x, y) return x + y end add (4, 7) )", "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " - "statements"); + "statements" + ); CHECK(result1.errors.size() == 1); - auto result2 = matchParseError(R"( + auto result2 = matchParseError( + R"( function add(x, y) return x + y end local f = add (f :: any)['x'] = 2 )", "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " - "statements"); + "statements" + ); CHECK(result2.errors.size() == 1); - auto result3 = matchParseError(R"( + auto result3 = matchParseError( + R"( local x = {} function x:add(a, b) return a + b end x:add (1, 2) )", "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " - "statements"); + "statements" + ); CHECK(result3.errors.size() == 1); - auto result4 = matchParseError(R"( + auto result4 = matchParseError( + R"( local t = {} function f() return t end t.x, (f) ().y = 5, 6 )", "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " - "statements"); + "statements" + ); CHECK(result4.errors.size() == 1); } @@ -1741,17 +1853,21 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") TEST_CASE_FIXTURE(Fixture, "parse_error_assignment_lvalue") { - matchParseError(R"( + matchParseError( + R"( local a, b (2), b = b, a )", - "Assigned expression must be a variable or a field"); + "Assigned expression must be a variable or a field" + ); - matchParseError(R"( + matchParseError( + R"( local a, b a, (3) = b, a )", - "Assigned expression must be a variable or a field"); + "Assigned expression must be a variable or a field" + ); } TEST_CASE_FIXTURE(Fixture, "parse_error_type_annotation") @@ -1814,18 +1930,23 @@ TEST_CASE_FIXTURE(Fixture, "parse_declarations") AstStatDeclareGlobal* global = stat->body.data[0]->as(); REQUIRE(global); CHECK(global->name == "foo"); + CHECK(global->nameLocation == Location({1, 16}, {1, 19})); CHECK(global->type); AstStatDeclareFunction* func = stat->body.data[1]->as(); REQUIRE(func); CHECK(func->name == "bar"); + CHECK(func->nameLocation == Location({2, 25}, {2, 28})); REQUIRE_EQ(func->params.types.size, 1); REQUIRE_EQ(func->retTypes.types.size, 1); AstStatDeclareFunction* varFunc = stat->body.data[2]->as(); REQUIRE(varFunc); CHECK(varFunc->name == "var"); + CHECK(varFunc->nameLocation == Location({3, 25}, {3, 28})); CHECK(varFunc->params.tailType); + CHECK(varFunc->vararg); + CHECK(varFunc->varargLocation == Location({3, 29}, {3, 32})); matchParseError("declare function foo(x)", "All declaration parameters must be annotated"); matchParseError("declare foo", "Expected ':' when parsing global variable declaration, got "); @@ -1856,11 +1977,16 @@ TEST_CASE_FIXTURE(Fixture, "parse_class_declarations") AstDeclaredClassProp& prop = declaredClass->props.data[0]; CHECK(prop.name == "prop"); + CHECK(prop.nameLocation == Location({2, 12}, {2, 16})); CHECK(prop.ty->is()); + CHECK(prop.location == Location({2, 12}, {2, 24})); AstDeclaredClassProp& method = declaredClass->props.data[1]; CHECK(method.name == "method"); + CHECK(method.nameLocation == Location({3, 21}, {3, 27})); CHECK(method.ty->is()); + CHECK(method.location == Location({3, 12}, {3, 54})); + CHECK(method.isMethod); AstStatDeclareClass* subclass = stat->body.data[1]->as(); REQUIRE(subclass); @@ -1871,19 +1997,23 @@ TEST_CASE_FIXTURE(Fixture, "parse_class_declarations") REQUIRE_EQ(subclass->props.size, 1); AstDeclaredClassProp& prop2 = subclass->props.data[0]; CHECK(prop2.name == "prop2"); + CHECK(prop2.nameLocation == Location({7, 12}, {7, 17})); CHECK(prop2.ty->is()); + CHECK(prop2.location == Location({7, 12}, {7, 25})); } TEST_CASE_FIXTURE(Fixture, "class_method_properties") { - const ParseResult p1 = matchParseError(R"( + const ParseResult p1 = matchParseError( + R"( declare class Foo -- method's first parameter must be 'self' function method(foo: number) function method2(self) end )", - "'self' must be present as the unannotated first parameter"); + "'self' must be present as the unannotated first parameter" + ); REQUIRE_EQ(1, p1.root->body.size); @@ -1892,13 +2022,15 @@ TEST_CASE_FIXTURE(Fixture, "class_method_properties") CHECK_EQ(2, klass->props.size); - const ParseResult p2 = matchParseError(R"( + const ParseResult p2 = matchParseError( + R"( declare class Foo function method(self, foo) function method2() end )", - "All declaration parameters aside from 'self' must be annotated"); + "All declaration parameters aside from 'self' must be annotated" + ); REQUIRE_EQ(1, p2.root->body.size); @@ -1908,9 +2040,46 @@ TEST_CASE_FIXTURE(Fixture, "class_method_properties") CHECK_EQ(2, klass2->props.size); } +TEST_CASE_FIXTURE(Fixture, "class_indexer") +{ + AstStatBlock* stat = parseEx(R"( + declare class Foo + prop: boolean + [string]: number + end + )") + .root; + + REQUIRE_EQ(stat->body.size, 1); + + AstStatDeclareClass* declaredClass = stat->body.data[0]->as(); + REQUIRE(declaredClass); + REQUIRE(declaredClass->indexer); + REQUIRE(declaredClass->indexer->indexType->is()); + CHECK(declaredClass->indexer->indexType->as()->name == "string"); + REQUIRE(declaredClass->indexer->resultType->is()); + CHECK(declaredClass->indexer->resultType->as()->name == "number"); + + const ParseResult p1 = matchParseError( + R"( + declare class Foo + [string]: number + -- can only have one indexer + [number]: number + end + )", + "Cannot have more than one class indexer" + ); + + REQUIRE_EQ(1, p1.root->body.size); + + AstStatDeclareClass* klass = p1.root->body.data[0]->as(); + REQUIRE(klass != nullptr); + CHECK(klass->indexer); +} + TEST_CASE_FIXTURE(Fixture, "parse_variadics") { - //clang-format off AstStatBlock* stat = parseEx(R"( function foo(bar, ...: number): ...string end @@ -1919,7 +2088,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_variadics") type Bar = () -> (number, ...boolean) )") .root; - //clang-format on REQUIRE(stat); REQUIRE_EQ(stat->body.size, 3); @@ -2092,8 +2260,9 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") CHECK_EQ(funcRet->argNames.data[2]->first, "f"); } - matchParseError("type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number)", - "Expected '->' when parsing function type, got "); + matchParseError( + "type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number)", "Expected '->' when parsing function type, got " + ); matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } @@ -2127,8 +2296,11 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") { - matchParseError("type Y = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", - Location{{0, 20}, {0, 23}}); + matchParseError( + "type Y = {a: T..., b: number}", + "Unexpected '...' after type name; type pack is not allowed in this context", + Location{{0, 20}, {0, 23}} + ); matchParseError("type Y = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); } @@ -2206,6 +2378,54 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") matchParseError("type F = (T...) -> ()", "Expected '->' when parsing function type, got '>'"); } +TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") +{ + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + + AstStat* stat = parse(R"( + type function foo() + return + end + )"); + + REQUIRE(stat != nullptr); + AstStatTypeFunction* f = stat->as()->body.data[0]->as(); + REQUIRE(f != nullptr); + REQUIRE(f->name == "foo"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_nested_type_function") +{ + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + + AstStat* stat = parse(R"( + local v1 = 1 + type function foo() + local v2 = 2 + local function bar() + v2 += 1 + type function inner() end + v2 += 2 + end + local function bar2() + v2 += 3 + end + end + local function bar() v1 += 1 end + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions") +{ + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + + matchParseError("export type function foo() end", "Type function cannot be exported"); + matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); + matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2327,7 +2547,7 @@ local a : { [string] : number, [number] : string, count: number } TEST_CASE_FIXTURE(Fixture, "recovery_error_limit_1") { - ScopedFastInt luauParseErrorLimit("LuauParseErrorLimit", 1); + ScopedFastInt luauParseErrorLimit(FInt::LuauParseErrorLimit, 1); try { @@ -2343,7 +2563,7 @@ TEST_CASE_FIXTURE(Fixture, "recovery_error_limit_1") TEST_CASE_FIXTURE(Fixture, "recovery_error_limit_2") { - ScopedFastInt luauParseErrorLimit("LuauParseErrorLimit", 2); + ScopedFastInt luauParseErrorLimit(FInt::LuauParseErrorLimit, 2); try { @@ -2373,7 +2593,8 @@ class CountAstNodes : public AstVisitor TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") { - auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) { + auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) + { try { parse(codeWithErrors); @@ -2393,7 +2614,8 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") CHECK_EQ(counterWithErrors.count, counter.count); }; - auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) { + auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) + { try { parse(codeWithErrors); @@ -2406,9 +2628,14 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") } }; + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + checkRecovery("function foo(a, b. c) return a + b end", "function foo(a, b) return a + b end", 1); - checkRecovery("function foo(a, b: { a: number, b: number. c:number }) return a + b end", - "function foo(a, b: { a: number, b: number }) return a + b end", 1); + checkRecovery( + "function foo(a, b: { a: number, b: number. c:number }) return a + b end", "function foo(a, b: { a: number, b: number }) return a + b end", 1 + ); checkRecovery("function foo(a, b): (number -> number return a + b end", "function foo(a, b): (number) -> number return a + b end", 1); checkRecovery("function foo(a, b): (number, number -> number return a + b end", "function foo(a, b): (number) -> number return a + b end", 1); @@ -2425,12 +2652,15 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") checkRecovery("local n: (string | number = 2", "local n: (string | number) = 2", 1); // Check that we correctly stop at the end of a line - checkRecovery(R"( + checkRecovery( + R"( function foo(a, b return a + b end )", - "function foo(a, b) return a + b end", 1); + "function foo(a, b) return a + b end", + 1 + ); } TEST_CASE_FIXTURE(Fixture, "incomplete_method_call") @@ -2506,12 +2736,12 @@ TEST_CASE_FIXTURE(Fixture, "incomplete_method_call_still_yields_an_AstExprIndexN TEST_CASE_FIXTURE(Fixture, "recover_confusables") { // Binary - matchParseError("local a = 4 != 10", "Unexpected '!=', did you mean '~='?"); - matchParseError("local a = true && false", "Unexpected '&&', did you mean 'and'?"); - matchParseError("local a = false || true", "Unexpected '||', did you mean 'or'?"); + matchParseError("local a = 4 != 10", "Unexpected '!='; did you mean '~='?"); + matchParseError("local a = true && false", "Unexpected '&&'; did you mean 'and'?"); + matchParseError("local a = false || true", "Unexpected '||'; did you mean 'or'?"); // Unary - matchParseError("local a = !false", "Unexpected '!', did you mean 'not'?"); + matchParseError("local a = !false", "Unexpected '!'; did you mean 'not'?"); // Check that separate tokens are not considered as a single one matchParseError("local a = 4 ! = 10", "Expected identifier when parsing expression, got '!'"); @@ -2524,7 +2754,8 @@ TEST_CASE_FIXTURE(Fixture, "capture_comments") ParseOptions options; options.captureComments = true; - ParseResult result = parseEx(R"( + ParseResult result = parseEx( + R"( --!strict local a = 5 -- comment one @@ -2534,7 +2765,8 @@ TEST_CASE_FIXTURE(Fixture, "capture_comments") ]] local c = 'see' )", - options); + options + ); CHECK(result.errors.empty()); @@ -2550,10 +2782,12 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") ParseOptions options; options.captureComments = true; - ParseResult result = tryParse(R"( + ParseResult result = tryParse( + R"( --[[ )", - options); + options + ); CHECK_EQ(1, result.commentLocations.size()); CHECK_EQ((Location{{1, 8}, {2, 4}}), result.commentLocations[0].location); @@ -2564,12 +2798,14 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment") ParseOptions options; options.captureComments = true; - ParseResult result = tryParse(R"( + ParseResult result = tryParse( + R"( local a = "test" --[[broken! )", - options); + options + ); CHECK_EQ(1, result.commentLocations.size()); CHECK_EQ((Location{{3, 8}, {4, 4}}), result.commentLocations[0].location); @@ -2636,6 +2872,10 @@ TEST_CASE_FIXTURE(Fixture, "AstName_comparison") TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") { + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + try { parse(R"( @@ -2739,8 +2979,10 @@ TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation" type Foo = function )"); REQUIRE_EQ(1, result.errors.size()); - CHECK_EQ("Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> ...any'", - result.errors[0].getMessage()); + CHECK_EQ( + "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> ...any'", + result.errors[0].getMessage() + ); } TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_argument_list") @@ -2839,8 +3081,6 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_t TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter") { - ScopedFastFlag sff{"LuauParserErrorsOnMissingDefaultTypePackArgument", true}; - ParseResult result = tryParse(R"( type Foo = nil )"); @@ -2854,4 +3094,592 @@ TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_ty CHECK_EQ("Expected type pack after '=', got type", result.errors[1].getMessage()); } +TEST_CASE_FIXTURE(Fixture, "table_type_keys_cant_contain_nul") +{ + ParseResult result = tryParse(R"( + type Foo = { ["\0"]: number } + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 21}, {1, 22}}, result.errors[0].getLocation()); + CHECK_EQ("String literal contains malformed escape sequence or \\0", result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "invalid_escape_literals_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + local foo = "\xQQ" + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 20}, {1, 26}}, result.errors[0].getLocation()); + CHECK_EQ("String literal contains malformed escape sequence", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "unfinished_string_literals_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + local foo = "hi + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 20}, {1, 23}}, result.errors[0].getLocation()); + CHECK_EQ("Malformed string; did you forget to finish it?", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "unfinished_string_literal_types_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + type Foo = "hi + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 19}, {1, 22}}, result.errors[0].getLocation()); + CHECK_EQ("Malformed string; did you forget to finish it?", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "do_block_with_no_end") +{ + ParseResult result = tryParse(R"( + do + )"); + + REQUIRE_EQ(1, result.errors.size()); + + AstStatBlock* stat0 = result.root->body.data[0]->as(); + REQUIRE(stat0); + + CHECK(!stat0->hasEnd); +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") +{ + ParseResult result = tryParse(R"( + local x = `{ {y} }` + )"); + + REQUIRE_MESSAGE(result.errors.empty(), result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2") +{ + ParseResult result = tryParse(R"( + local x = `{ { y{} } }` + )"); + + REQUIRE_MESSAGE(result.errors.empty(), result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"BUILTIN_SRC( +@checked declare function abs(n: number): number +)BUILTIN_SRC"; + + ParseResult pr = tryParse(src, opts); + LUAU_ASSERT(pr.errors.size() == 0); + + LUAU_ASSERT(pr.root->body.size == 1); + AstStat* root = *(pr.root->body.data); + auto func = root->as(); + LUAU_ASSERT(func); + LUAU_ASSERT(func->isCheckedFunction()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + const std::string src = R"BUILTIN_SRC( + declare math : { + abs : @checked (number) -> number +} +)BUILTIN_SRC"; + + ParseResult pr = tryParse(src, opts); + LUAU_ASSERT(pr.errors.size() == 0); + + LUAU_ASSERT(pr.root->body.size == 1); + AstStat* root = *(pr.root->body.data); + auto glob = root->as(); + LUAU_ASSERT(glob); + auto tbl = glob->type->as(); + LUAU_ASSERT(tbl); + LUAU_ASSERT(tbl->props.size == 1); + auto prop = *tbl->props.data; + auto func = prop.type->as(); + LUAU_ASSERT(func); + LUAU_ASSERT(func->isCheckedFunction()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + ParseResult pr = tryParse( + R"( + local @checked = 3 +)", + opts + ); + LUAU_ASSERT(pr.errors.size() > 0); + auto ts = pr.errors[1].getMessage(); +} + +TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + auto pr = tryParse( + R"( + local @checked = 3 + @checked declare function abs(n: number): number +)", + opts + ); + LUAU_ASSERT(pr.errors.size() == 2); + LUAU_ASSERT(pr.errors[0].getLocation().begin.line == 1); + LUAU_ASSERT(pr.errors[1].getLocation().begin.line == 1); +} + +TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + auto pr = tryParse( + R"( + @checked function(x: number) : number + end +)", + opts + ); + LUAU_ASSERT(pr.errors.size() > 0); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + auto pr = tryParse( + R"( + local @blah = 3 +)", + opts + ); + + LUAU_ASSERT(pr.errors.size() > 0); +} + +TEST_CASE_FIXTURE(Fixture, "read_write_table_properties") +{ + auto pr = tryParse(R"( + type A = {read x: number} + type B = {write x: number} + type C = {read x: number, write x: number} + type D = {read: () -> string} + type E = {write: (string) -> ()} + type F = {read read: () -> string} + type G = {read write: (string) -> ()} + + type H = {read ["A"]: number} + type I = {write ["A"]: string} + + type J = {read [number]: number} + type K = {write [number]: string} + )"); + + LUAU_ASSERT(pr.errors.size() == 0); +} + +void checkAttribute(const AstAttr* attr, const AstAttr::Type type, const Location& location) +{ + CHECK_EQ(attr->type, type); + CHECK_EQ(attr->location, location); +} + +void checkFirstErrorForAttributes(const std::vector& errors, const size_t minSize, const Location& location, const std::string& message) +{ + LUAU_ASSERT(minSize >= 1); + + CHECK_GE(errors.size(), minSize); + CHECK_EQ(errors[0].getLocation(), location); + CHECK_EQ(errors[0].getMessage(), message); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_stat") +{ + + AstStatBlock* stat = parse(R"( +@checked +function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_for_function_expression") +{ + ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntaxFunExpr, true}}; + + AstStatBlock* stat1 = parse(R"( +local function invoker(f) + return f(1) +end + +invoker(@checked function(x) return (x + 2) end) +)"); + + LUAU_ASSERT(stat1 != nullptr); + + AstExprFunction* func1 = stat1->body.data[1]->as()->expr->as()->args.data[0]->as(); + LUAU_ASSERT(func1 != nullptr); + + AstArray attributes1 = func1->attributes; + + CHECK_EQ(attributes1.size, 1); + + checkAttribute(attributes1.data[0], AstAttr::Type::Checked, Location(Position(5, 8), Position(5, 16))); + + AstStatBlock* stat2 = parse(R"( +local f = @checked function(x) return (x + 2) end +)"); + + LUAU_ASSERT(stat2 != nullptr); + + AstExprFunction* func2 = stat2->body.data[0]->as()->values.data[0]->as(); + LUAU_ASSERT(func2 != nullptr); + + AstArray attributes2 = func2->attributes; + + CHECK_EQ(attributes2.size, 1); + + checkAttribute(attributes2.data[0], AstAttr::Type::Checked, Location(Position(1, 10), Position(1, 18))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat") +{ + AstStatBlock* stat = parse(R"( + @checked +local function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatLocalFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 4), Position(1, 12))); +} + +TEST_CASE_FIXTURE(Fixture, "empty_attribute_name_is_not_allowed") +{ + ParseResult result = tryParse(R"( +@ +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(1, 0), Position(1, 1)), "Attribute name is missing"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_stat") +{ + ParseResult pr1 = tryParse(R"( +@checked +if a<0 then a = 0 end)"); + checkFirstErrorForAttributes( + pr1.errors, + 1, + Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'if' instead" + ); + + ParseResult pr2 = tryParse(R"( +local i = 1 +@checked +while a[i] do + print(a[i]) + i = i + 1 +end)"); + checkFirstErrorForAttributes( + pr2.errors, + 1, + Location(Position(3, 0), Position(3, 5)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'while' instead" + ); + + ParseResult pr3 = tryParse(R"( +@checked +do + local a2 = 2*a + local d = sqrt(b^2 - 4*a*c) + x1 = (-b + d)/a2 + x2 = (-b - d)/a2 +end)"); + checkFirstErrorForAttributes( + pr3.errors, + 1, + Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'do' instead" + ); + + ParseResult pr4 = tryParse(R"( +@checked +for i=1,10 do print(i) end +)"); + checkFirstErrorForAttributes( + pr4.errors, + 1, + Location(Position(2, 0), Position(2, 3)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'for' instead" + ); + + ParseResult pr5 = tryParse(R"( +@checked +repeat + line = io.read() +until line ~= "" +)"); + checkFirstErrorForAttributes( + pr5.errors, + 1, + Location(Position(2, 0), Position(2, 6)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'repeat' instead" + ); + + + ParseResult pr6 = tryParse(R"( +@checked +local x = 10 +)"); + checkFirstErrorForAttributes( + pr6.errors, 1, Location(Position(2, 6), Position(2, 7)), "Expected 'function' after local declaration with attribute, but got 'x' instead" + ); + + ParseResult pr7 = tryParse(R"( +local i = 1 +while a[i] do + if a[i] == v then @checked break end + i = i + 1 +end +)"); + checkFirstErrorForAttributes( + pr7.errors, + 1, + Location(Position(3, 31), Position(3, 36)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'break' instead" + ); + + + ParseResult pr8 = tryParse(R"( +function foo1 () @checked return 'a' end +)"); + checkFirstErrorForAttributes( + pr8.errors, + 1, + Location(Position(1, 26), Position(1, 32)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' instead" + ); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attribute_on_argument_non_function") +{ + ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntaxFunExpr, true}}; + + ParseResult pr = tryParse(R"( +local function invoker(f, y) + return f(y) +end + +invoker(function(x) return (x + 2) end, @checked 1) +)"); + + checkFirstErrorForAttributes( + pr.errors, 1, Location(Position(5, 40), Position(5, 48)), "Expected 'function' declaration after attribute, but got '1' instead" + ); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +@checked declare function abs(n: number): number +)"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + auto func = root->as(); + LUAU_ASSERT(func != nullptr); + + CHECK(func->isCheckedFunction()); + + AstArray attributes = func->attributes; + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attributes_on_function_type_declaration_in_table") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +declare bit32: { + band: @checked (...number) -> number +})"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + AstStatDeclareGlobal* glob = root->as(); + LUAU_ASSERT(glob); + + auto tbl = glob->type->as(); + LUAU_ASSERT(tbl); + + LUAU_ASSERT(tbl->props.size == 1); + AstTableProp prop = tbl->props.data[0]; + + AstTypeFunction* func = prop.type->as(); + LUAU_ASSERT(func); + + AstArray attributes = func->attributes; + + CHECK_EQ(attributes.size, 1); + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(2, 10), Position(2, 18))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_type_declarations") +{ + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + ParseResult pr1 = tryParse( + R"( +@checked declare foo: number + )", + opts + ); + + checkFirstErrorForAttributes( + pr1.errors, 1, Location(Position(1, 17), Position(1, 20)), "Expected a function type declaration after attribute, but got 'foo' instead" + ); + + ParseResult pr2 = tryParse( + R"( +@checked declare class Foo + prop: number + function method(self, foo: number): string +end)", + opts + ); + + checkFirstErrorForAttributes( + pr2.errors, 1, Location(Position(1, 17), Position(1, 22)), "Expected a function type declaration after attribute, but got 'class' instead" + ); + + ParseResult pr3 = tryParse( + R"( +declare bit32: { + band: @checked number +})", + opts + ); + + checkFirstErrorForAttributes( + pr3.errors, 1, Location(Position(2, 19), Position(2, 25)), "Expected '(' when parsing function parameters, got 'number'" + ); +} + +TEST_CASE_FIXTURE(Fixture, "attributes_cannot_be_duplicated") +{ + ParseResult result = tryParse(R"( +@checked + @checked +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 12)), "Cannot duplicate attribute '@checked'"); +} + +TEST_CASE_FIXTURE(Fixture, "unsupported_attributes_are_not_allowed") +{ + ParseResult result = tryParse(R"( +@checked + @cool_attribute +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 19)), "Invalid attribute '@cool_attribute'"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") +{ + parse(R"(type A = | "Hello" | "World")"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") +{ + parse(R"(type A = & { string } & { number })"); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") +{ + matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); +} + + TEST_SUITE_END(); diff --git a/tests/RegisterCallbacks.cpp b/tests/RegisterCallbacks.cpp new file mode 100644 index 000000000..9f4719335 --- /dev/null +++ b/tests/RegisterCallbacks.cpp @@ -0,0 +1,20 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "RegisterCallbacks.h" + +namespace Luau +{ + +std::unordered_set& getRegisterCallbacks() +{ + static std::unordered_set cbs; + return cbs; +} + +int addTestCallback(RegisterCallback cb) +{ + getRegisterCallbacks().insert(cb); + return 0; +} + +} // namespace Luau diff --git a/tests/RegisterCallbacks.h b/tests/RegisterCallbacks.h new file mode 100644 index 000000000..f62ac0e7c --- /dev/null +++ b/tests/RegisterCallbacks.h @@ -0,0 +1,22 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau +{ + +using RegisterCallback = void (*)(); + +/// Gets a set of callbacks to run immediately before running tests, intended +/// for registering new tests at runtime. +std::unordered_set& getRegisterCallbacks(); + +/// Adds a new callback to be ran immediately before running tests. +/// +/// @param cb the callback to add. +/// @returns a dummy integer to satisfy a doctest internal contract. +int addTestCallback(RegisterCallback cb); + +} // namespace Luau diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index c22d464ee..71a46878d 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -3,6 +3,7 @@ #include "lualib.h" #include "Repl.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -12,6 +13,8 @@ #include #include +LUAU_FASTFLAG(LuauMathMap) + struct Completion { std::string completion; @@ -52,9 +55,14 @@ class ReplFixture { CompletionSet result; int top = lua_gettop(L); - getCompletions(L, inputPrefix, [&result](const std::string& completion, const std::string& display) { - result.insert(Completion{completion, display}); - }); + getCompletions( + L, + inputPrefix, + [&result](const std::string& completion, const std::string& display) + { + result.insert(Completion{completion, display}); + } + ); // Ensure that generating completions doesn't change the position of luau's stack top. CHECK(top == lua_gettop(L)); @@ -167,15 +175,17 @@ TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") CHECK(checkCompletion(completions, prefix, "myvariable1")); CHECK(checkCompletion(completions, prefix, "myvariable2")); } + if (FFlag::LuauMathMap) { // Try completing some builtin functions CompletionSet completions = getCompletionSet("math.m"); std::string prefix = "math."; - CHECK(completions.size() == 3); + CHECK(completions.size() == 4); CHECK(checkCompletion(completions, prefix, "max(")); CHECK(checkCompletion(completions, prefix, "min(")); CHECK(checkCompletion(completions, prefix, "modf(")); + CHECK(checkCompletion(completions, prefix, "map(")); } } @@ -420,4 +430,22 @@ print(NewProxyOne.HelloICauseACrash) )"); } +TEST_CASE_FIXTURE(ReplFixture, "InteractiveStackReserve1") +{ + // Reset stack reservation + lua_resume(L, nullptr, 0); + + runCode(L, R"( +local t = {} +)"); +} + +TEST_CASE_FIXTURE(ReplFixture, "InteractiveStackReserve2") +{ + // Reset stack reservation + lua_resume(L, nullptr, 0); + + getCompletionSet("a"); +} + TEST_SUITE_END(); diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp new file mode 100644 index 000000000..641323c2d --- /dev/null +++ b/tests/RequireByString.test.cpp @@ -0,0 +1,480 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" +#include "ScopedFlags.h" +#include "lua.h" +#include "lualib.h" + +#include "Repl.h" +#include "FileUtils.h" + +#include "doctest.h" + +#include +#include +#include + +#if __APPLE__ +#include +#if TARGET_OS_IPHONE +#include + +std::optional getResourcePath0() +{ + CFBundleRef mainBundle = CFBundleGetMainBundle(); + if (mainBundle == NULL) + { + return std::nullopt; + } + CFURLRef mainBundleURL = CFBundleCopyBundleURL(mainBundle); + if (mainBundleURL == NULL) + { + CFRelease(mainBundle); + return std::nullopt; + } + + char pathBuffer[PATH_MAX]; + if (!CFURLGetFileSystemRepresentation(mainBundleURL, true, (UInt8*)pathBuffer, PATH_MAX)) + { + CFRelease(mainBundleURL); + CFRelease(mainBundle); + return std::nullopt; + } + + CFRelease(mainBundleURL); + CFRelease(mainBundle); + return std::string(pathBuffer); +} + +std::optional getResourcePath() +{ + static std::optional path0 = getResourcePath0(); + return path0; +} +#endif +#endif + +class ReplWithPathFixture +{ +public: + ReplWithPathFixture() + : luaState(luaL_newstate(), lua_close) + { + L = luaState.get(); + setupState(L); + luaL_sandboxthread(L); + + runCode(L, prettyPrintSource); + } + + // Returns all of the output captured from the pretty printer + std::string getCapturedOutput() + { + lua_getglobal(L, "capturedoutput"); + const char* str = lua_tolstring(L, -1, nullptr); + std::string result(str); + lua_pop(L, 1); + return result; + } + + enum class PathType + { + Absolute, + Relative + }; + + std::string getLuauDirectory(PathType type) + { + std::string luauDirRel = "."; + std::string luauDirAbs; + +#if TARGET_OS_IPHONE + std::optional cwd0 = getCurrentWorkingDirectory(); + std::optional cwd = getResourcePath(); + if (cwd && cwd0) + { + // when running in xcode cwd0 is "/", however that is not always the case + const auto& _res = *cwd; + const auto& _cwd = *cwd0; + if (_res.find(_cwd) == 0) + { + // we need relative path so we subtract cwd0 from cwd + luauDirRel = "./" + _res.substr(_cwd.length()); + } + } +#else + std::optional cwd = getCurrentWorkingDirectory(); +#endif + + REQUIRE_MESSAGE(cwd, "Error getting Luau path"); + std::replace((*cwd).begin(), (*cwd).end(), '\\', '/'); + luauDirAbs = *cwd; + + for (int i = 0; i < 20; ++i) + { + bool engineTestDir = isDirectory(luauDirAbs + "/Client/Luau/tests"); + bool luauTestDir = isDirectory(luauDirAbs + "/luau/tests/require"); + + if (engineTestDir || luauTestDir) + { + if (engineTestDir) + { + luauDirRel += "/Client/Luau"; + luauDirAbs += "/Client/Luau"; + } + else + { + luauDirRel += "/luau"; + luauDirAbs += "/luau"; + } + + + if (type == PathType::Relative) + return luauDirRel; + if (type == PathType::Absolute) + return luauDirAbs; + } + + luauDirRel += "/.."; + std::optional parentPath = getParentPath(luauDirAbs); + REQUIRE_MESSAGE(parentPath, "Error getting Luau path"); + luauDirAbs = *parentPath; + } + + // Could not find the directory + REQUIRE_MESSAGE(false, "Error getting Luau path"); + return {}; + } + + void runProtectedRequire(const std::string& path) + { + runCode(L, "return pcall(function() return require(\"" + path + "\") end)"); + } + + void assertOutputContainsAll(const std::initializer_list& list) + { + const std::string capturedOutput = getCapturedOutput(); + for (const std::string& elem : list) + { + CHECK_MESSAGE(capturedOutput.find(elem) != std::string::npos, "Captured output: ", capturedOutput); + } + } + + lua_State* L; + +private: + std::unique_ptr luaState; + + // This is a simplistic and incomplete pretty printer. + // It is included here to test that the pretty printer hook is being called. + // More elaborate tests to ensure correct output can be added if we introduce + // a more feature rich pretty printer. + std::string prettyPrintSource = R"( +-- Accumulate pretty printer output in `capturedoutput` +capturedoutput = "" + +function arraytostring(arr) + local strings = {} + table.foreachi(arr, function(k,v) table.insert(strings, pptostring(v)) end ) + return "{" .. table.concat(strings, ", ") .. "}" +end + +function pptostring(x) + if type(x) == "table" then + -- Just assume array-like tables for now. + return arraytostring(x) + elseif type(x) == "string" then + return '"' .. x .. '"' + else + return tostring(x) + end +end + +-- Note: Instead of calling print, the pretty printer just stores the output +-- in `capturedoutput` so we can check for the correct results. +function _PRETTYPRINT(...) + local args = table.pack(...) + local strings = {} + for i=1, args.n do + local item = args[i] + local str = pptostring(item, customoptions) + if i == 1 then + capturedoutput = capturedoutput .. str + else + capturedoutput = capturedoutput .. "\t" .. str + end + end +end +)"; +}; + +TEST_SUITE_BEGIN("RequireByStringTests"); + +TEST_CASE("PathResolution") +{ +#ifdef _WIN32 + std::string prefix = "C:/"; +#else + std::string prefix = "/"; +#endif + + CHECK(resolvePath(prefix + "Users/modules/module.luau", "") == prefix + "Users/modules/module.luau"); + CHECK(resolvePath(prefix + "Users/modules/module.luau", "a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); + CHECK(resolvePath(prefix + "Users/modules/module.luau", "./a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); + CHECK(resolvePath(prefix + "Users/modules/module.luau", "/a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); + CHECK(resolvePath(prefix + "Users/modules/module.luau", "/Users/modules") == prefix + "Users/modules/module.luau"); + + CHECK(resolvePath("../module", "") == "../module"); + CHECK(resolvePath("../../module", "") == "../../module"); + CHECK(resolvePath("../module/..", "") == ".."); + CHECK(resolvePath("../module/../..", "") == "../.."); + + CHECK(resolvePath("../dependency", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); + CHECK(resolvePath("../dependency/", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); + CHECK(resolvePath("../../../../../Users/dependency", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); + CHECK(resolvePath("../..", prefix + "Users/modules/module.luau") == prefix); +} + +TEST_CASE("PathNormalization") +{ +#ifdef _WIN32 + std::string prefix = "C:/"; +#else + std::string prefix = "/"; +#endif + + // Relative path + std::optional result = normalizePath("../../modules/module"); + CHECK(result); + std::string normalized = *result; + std::vector variants = { + "./.././.././modules/./module/", "placeholder/../../../modules/module", "../placeholder/placeholder2/../../../modules/module" + }; + for (const std::string& variant : variants) + { + result = normalizePath(variant); + CHECK(result); + CHECK(normalized == *result); + } + + // Absolute path + result = normalizePath(prefix + "Users/modules/module"); + CHECK(result); + normalized = *result; + variants = { + "Users/Users/Users/.././.././modules/./module/", + "placeholder/../Users/..//Users/modules/module", + "Users/../placeholder/placeholder2/../../Users/modules/module" + }; + for (const std::string& variant : variants) + { + result = normalizePath(prefix + variant); + CHECK(result); + CHECK(normalized == *result); + } +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireSimpleRelativePath") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from dependency"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireRelativeToRequiringFile") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/module"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from dependency", "required into module"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireLua") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua_dependency"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from lua_dependency"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLuau") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/luau"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from init.luau"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLua") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from init.lua"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireWithFileAmbiguity") +{ + std::string ambiguousPath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/ambiguous_file_requirer"; + + runProtectedRequire(ambiguousPath); + assertOutputContainsAll({"false", "require path could not be resolved to a unique file"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireWithDirectoryAmbiguity") +{ + std::string ambiguousPath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/ambiguous_directory_requirer"; + + runProtectedRequire(ambiguousPath); + assertOutputContainsAll({"false", "require path could not be resolved to a unique file"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireLuau") +{ + std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/module"; + std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/module"; + + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + ".luau").c_str()); + REQUIRE_MESSAGE(lua_isnil(L, -1), "Cache already contained module result"); + + runProtectedRequire(relativePath); + + assertOutputContainsAll({"true", "result from dependency", "required into module"}); + + // Check cache for the absolute path as a cache key + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + ".luau").c_str()); + REQUIRE_FALSE_MESSAGE(lua_isnil(L, -1), "Cache did not contain module result"); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireLua") +{ + std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua_dependency"; + std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/lua_dependency"; + + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + ".luau").c_str()); + REQUIRE_MESSAGE(lua_isnil(L, -1), "Cache already contained module result"); + + runProtectedRequire(relativePath); + + assertOutputContainsAll({"true", "result from lua_dependency"}); + + // Check cache for the absolute path as a cache key + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + ".lua").c_str()); + REQUIRE_FALSE_MESSAGE(lua_isnil(L, -1), "Cache did not contain module result"); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLuau") +{ + std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/luau"; + std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/luau"; + + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + "/init.luau").c_str()); + REQUIRE_MESSAGE(lua_isnil(L, -1), "Cache already contained module result"); + + runProtectedRequire(relativePath); + + assertOutputContainsAll({"true", "result from init.luau"}); + + // Check cache for the absolute path as a cache key + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + "/init.luau").c_str()); + REQUIRE_FALSE_MESSAGE(lua_isnil(L, -1), "Cache did not contain module result"); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLua") +{ + std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua"; + std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/lua"; + + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + "/init.lua").c_str()); + REQUIRE_MESSAGE(lua_isnil(L, -1), "Cache already contained module result"); + + runProtectedRequire(relativePath); + + assertOutputContainsAll({"true", "result from init.lua"}); + + // Check cache for the absolute path as a cache key + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, (absolutePath + "/init.lua").c_str()); + REQUIRE_FALSE_MESSAGE(lua_isnil(L, -1), "Cache did not contain module result"); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "LoadStringRelative") +{ + runCode(L, "return pcall(function() return loadstring(\"require('a/relative/path')\")() end)"); + assertOutputContainsAll({"false", "require is not supported in this context"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAbsolutePath") +{ +#ifdef _WIN32 + std::string absolutePath = "C:/an/absolute/path"; +#else + std::string absolutePath = "/an/absolute/path"; +#endif + runProtectedRequire(absolutePath); + assertOutputContainsAll({"false", "cannot require an absolute path"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath") +{ + std::string path = "an/unprefixed/path"; + runProtectedRequire(path); + assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from dependency"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithParentAlias") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/parent_alias_requirer"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from other_dependency"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAliasPointingToDirectory") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/directory_alias_requirer"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from subdirectory_dependency"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAliasThatDoesNotExist") +{ + std::string nonExistentAlias = "@this.alias.does.not.exist"; + + runProtectedRequire(nonExistentAlias); + assertOutputContainsAll({"false", "@this.alias.does.not.exist is not a valid alias"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "AliasHasIllegalFormat") +{ + std::string illegalCharacter = "@@"; + + runProtectedRequire(illegalCharacter); + assertOutputContainsAll({"false", "@@ is not a valid alias"}); + + std::string pathAlias1 = "@."; + + runProtectedRequire(pathAlias1); + assertOutputContainsAll({"false", ". is not a valid alias"}); + + + std::string pathAlias2 = "@.."; + + runProtectedRequire(pathAlias2); + assertOutputContainsAll({"false", ".. is not a valid alias"}); + + std::string emptyAlias = "@"; + + runProtectedRequire(emptyAlias); + assertOutputContainsAll({"false", " is not a valid alias"}); +} + +TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 7e50d5b64..b4acf1384 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -13,21 +13,31 @@ #include "doctest.h" +#include + using namespace Luau; +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauSolverV2) + struct LimitFixture : BuiltinsFixture { #if defined(_NOOPT) || defined(_DEBUG) - ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; + ScopedFastInt LuauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 100}; #endif }; template bool hasError(const CheckResult& result, T* = nullptr) { - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& a) { - return nullptr != get(a); - }); + auto it = std::find_if( + result.errors.begin(), + result.errors.end(), + [](const TypeError& a) + { + return nullptr != get(a); + } + ); return it != result.errors.end(); } @@ -35,6 +45,10 @@ TEST_SUITE_BEGIN("RuntimeLimits"); TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") { + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + constexpr const char* src = R"LUA( --!strict @@ -263,9 +277,8 @@ TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") )LUA"; CheckResult result = check(src); - CodeTooComplex ctc; - CHECK(hasError(result, &ctc)); + CHECK(hasError(result)); } TEST_SUITE_END(); diff --git a/tests/ScopedFlags.h b/tests/ScopedFlags.h index 9454307e0..beb3cc066 100644 --- a/tests/ScopedFlags.h +++ b/tests/ScopedFlags.h @@ -6,25 +6,18 @@ #include template -struct ScopedFValue +struct [[nodiscard]] ScopedFValue { private: Luau::FValue* value = nullptr; T oldValue = T(); public: - ScopedFValue(const char* name, T newValue) + ScopedFValue(Luau::FValue& fvalue, T newValue) { - for (Luau::FValue* v = Luau::FValue::list; v; v = v->next) - if (strcmp(v->name, name) == 0) - { - value = v; - oldValue = v->value; - v->value = newValue; - break; - } - - LUAU_ASSERT(value); + value = &fvalue; + oldValue = fvalue.value; + fvalue.value = newValue; } ScopedFValue(const ScopedFValue&) = delete; diff --git a/tests/Set.test.cpp b/tests/Set.test.cpp new file mode 100644 index 000000000..b3824bf1c --- /dev/null +++ b/tests/Set.test.cpp @@ -0,0 +1,146 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ScopedFlags.h" +#include "Luau/Set.h" + +#include "doctest.h" + +#include +#include + +TEST_SUITE_BEGIN("SetTests"); + +TEST_CASE("empty_set_size_0") +{ + Luau::Set s1{0}; + CHECK(s1.size() == 0); + CHECK(s1.empty()); +} + +TEST_CASE("insertion_works_and_increases_size") +{ + Luau::Set s1{0}; + CHECK(s1.size() == 0); + CHECK(s1.empty()); + + s1.insert(1); + CHECK(s1.contains(1)); + CHECK(s1.size() == 1); + + s1.insert(2); + CHECK(s1.contains(2)); + CHECK(s1.size() == 2); +} + +TEST_CASE("clear_resets_size") +{ + Luau::Set s1{0}; + s1.insert(1); + s1.insert(2); + REQUIRE(s1.size() == 2); + + s1.clear(); + CHECK(s1.size() == 0); + CHECK(s1.empty()); +} + +TEST_CASE("erase_works_and_decreases_size") +{ + Luau::Set s1{0}; + s1.insert(1); + s1.insert(2); + CHECK(s1.size() == 2); + CHECK(s1.contains(1)); + CHECK(s1.contains(2)); + + s1.erase(1); + CHECK(s1.size() == 1); + CHECK(!s1.contains(1)); + CHECK(s1.contains(2)); + + s1.erase(2); + CHECK(s1.size() == 0); + CHECK(s1.empty()); + CHECK(!s1.contains(1)); + CHECK(!s1.contains(2)); +} + +TEST_CASE("iterate_over_set") +{ + Luau::Set s1{0}; + s1.insert(1); + s1.insert(2); + s1.insert(3); + REQUIRE(s1.size() == 3); + + int sum = 0; + + for (int e : s1) + sum += e; + + CHECK(sum == 6); +} + +TEST_CASE("iterate_over_set_skips_erased_elements") +{ + Luau::Set s1{0}; + s1.insert(1); + s1.insert(2); + s1.insert(3); + s1.insert(4); + s1.insert(5); + s1.insert(6); + REQUIRE(s1.size() == 6); + + s1.erase(2); + s1.erase(4); + s1.erase(6); + + int sum = 0; + + for (int e : s1) + sum += e; + + CHECK(sum == 9); +} + +TEST_CASE("iterate_over_set_skips_first_element_if_it_is_erased") +{ + /* + * As of this writing, in the following set, the key "y" happens to occur + * before "x" in the underlying DenseHashSet. This is important because it + * surfaces something that Set::const_iterator needs to do: If the + * underlying iterator happens to start at a deleted element, we need to + * advance until we find the first live element (or the end of the set). + */ + Luau::Set s1{{}}; + s1.insert("x"); + s1.insert("y"); + s1.erase("y"); + + std::vector out; + auto it = s1.begin(); + auto endIt = s1.end(); + while (it != endIt) + { + out.push_back(*it); + ++it; + } + + CHECK(1 == out.size()); +} + +TEST_CASE("erase_using_const_ref_argument") +{ + Luau::Set s1{{}}; + + s1.insert("x"); + s1.insert("y"); + + std::string key = "y"; + s1.erase(key); + + CHECK(s1.count("x")); + CHECK(!s1.count("y")); +} + +TEST_SUITE_END(); diff --git a/tests/SharedCodeAllocator.test.cpp b/tests/SharedCodeAllocator.test.cpp new file mode 100644 index 000000000..13bf9f988 --- /dev/null +++ b/tests/SharedCodeAllocator.test.cpp @@ -0,0 +1,454 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/SharedCodeAllocator.h" + +#include "Luau/CodeAllocator.h" + +#include "luacode.h" +#include "luacodegen.h" +#include "lualib.h" + +#include "doctest.h" +#include "ScopedFlags.h" + +// We explicitly test correctness of self-assignment for some types +#ifdef __clang__ +#pragma GCC diagnostic ignored "-Wself-assign-overloaded" +#endif + +using namespace Luau::CodeGen; + + +constexpr size_t kBlockSize = 1024 * 1024; +constexpr size_t kMaxTotalSize = 1024 * 1024; + +static const uint8_t fakeCode[1] = {0x00}; + +TEST_SUITE_BEGIN("SharedCodeAllocator"); + +TEST_CASE("NativeModuleRefRefcounting") +{ + if (!luau_codegen_supported()) + return; + + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; + SharedCodeAllocator allocator{&codeAllocator}; + + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).empty()); + + NativeModuleRef modRefA = allocator.getOrInsertNativeModule(ModuleId{0x0a}, {}, nullptr, 0, fakeCode, std::size(fakeCode)).first; + REQUIRE(!modRefA.empty()); + + // If we attempt to get the module again, we should get the same module back: + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).get() == modRefA.get()); + + // If we try to insert another instance of the module, we should get the + // existing module back: + REQUIRE(allocator.getOrInsertNativeModule(ModuleId{0x0a}, {}, nullptr, 0, fakeCode, std::size(fakeCode)).first.get() == modRefA.get()); + + // If we try to look up a different module, we should not get the existing + // module back: + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0b}).empty()); + + // (Insert a second module to help with validation below) + NativeModuleRef modRefB = allocator.getOrInsertNativeModule(ModuleId{0x0b}, {}, nullptr, 0, fakeCode, std::size(fakeCode)).first; + REQUIRE(!modRefB.empty()); + REQUIRE(modRefB.get() != modRefA.get()); + + // Verify NativeModuleRef refcounting: + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null copy construction: + { + NativeModuleRef modRef1{modRefA}; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null copy construction: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{modRef1}; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null move construction: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{std::move(modRef1)}; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null move construction: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{std::move(modRef1)}; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> non-null copy assignment: + { + NativeModuleRef modRef1{}; + modRef1 = modRefA; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> null copy assignment: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{}; + modRef2 = modRef1; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef self copy assignment: + { + NativeModuleRef modRef1{modRefA}; + modRef1 = modRef1; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null -> non-null copy assignment: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{modRefB}; + modRef2 = modRef1; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 3); + REQUIRE(modRefB->getRefcount() == 1); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> non-null move assignment: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{}; + modRef2 = std::move(modRef1); + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> null move assignment: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{}; + modRef2 = std::move(modRef1); + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + +#if defined(__linux__) && defined(__GNUC__) +#else + // NativeModuleRef self move assignment: + { + NativeModuleRef modRef1{modRefA}; + modRef1 = std::move(modRef1); + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + +#endif + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null -> non-null move assignment: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{modRefB}; + modRef2 = std::move(modRef1); + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + REQUIRE(modRefB->getRefcount() == 1); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null reset: + { + NativeModuleRef modRef1{}; + modRef1.reset(); + REQUIRE(modRef1.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null reset: + { + NativeModuleRef modRef1{modRefA}; + modRef1.reset(); + REQUIRE(modRef1.empty()); + REQUIRE(modRefA->getRefcount() == 1); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef swap: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{modRefB}; + modRef1.swap(modRef2); + REQUIRE(modRef1.get() == modRefB.get()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + REQUIRE(modRefB->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // If we release the last reference to a module, it should destroy the + // module: + modRefA.reset(); + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).empty()); +} + +TEST_CASE("NativeProtoRefcounting") +{ + if (!luau_codegen_supported()) + return; + + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; + SharedCodeAllocator allocator{&codeAllocator}; + + std::vector nativeProtos; + nativeProtos.reserve(1); + NativeProtoExecDataPtr nativeProto = createNativeProtoExecData(0); + getNativeProtoExecDataHeader(nativeProto.get()).bytecodeId = 0x01; + nativeProtos.push_back(std::move(nativeProto)); + + NativeModuleRef modRefA = + allocator.getOrInsertNativeModule(ModuleId{0x0a}, std::move(nativeProtos), nullptr, 0, fakeCode, std::size(fakeCode)).first; + REQUIRE(!modRefA.empty()); + REQUIRE(modRefA->getRefcount() == 1); + + // Verify behavior of addRef: + modRefA->addRef(); + REQUIRE(modRefA->getRefcount() == 2); + + // Verify behavior of addRefs: + modRefA->addRefs(2); + REQUIRE(modRefA->getRefcount() == 4); + + // Undo two of our addRef(s): + modRefA->release(); + REQUIRE(modRefA->getRefcount() == 3); + + modRefA->release(); + REQUIRE(modRefA->getRefcount() == 2); + + // If we release our NativeModuleRef, the module should be kept alive by + // the owning reference we acquired: + modRefA.reset(); + + modRefA = allocator.tryGetNativeModule(ModuleId{0x0a}); + REQUIRE(!modRefA.empty()); + REQUIRE(modRefA->getRefcount() == 2); + + // If the last "release" comes via releaseOwningPointerToInstructionOffsets, + // the module should be successfully destroyed: + const NativeModule* rawModA = modRefA.get(); + + modRefA.reset(); + rawModA->release(); + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).empty()); +} + +TEST_CASE("NativeProtoState") +{ + if (!luau_codegen_supported()) + return; + + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; + SharedCodeAllocator allocator{&codeAllocator}; + + const std::vector data(16); + const std::vector code(16); + + std::vector nativeProtos; + nativeProtos.reserve(2); + + { + NativeProtoExecDataPtr nativeProto = createNativeProtoExecData(2); + getNativeProtoExecDataHeader(nativeProto.get()).bytecodeId = 1; + getNativeProtoExecDataHeader(nativeProto.get()).entryOffsetOrAddress = reinterpret_cast(0x00); + nativeProto[0] = 0; + nativeProto[1] = 4; + + nativeProtos.push_back(std::move(nativeProto)); + } + + { + NativeProtoExecDataPtr nativeProto = createNativeProtoExecData(2); + getNativeProtoExecDataHeader(nativeProto.get()).bytecodeId = 3; + getNativeProtoExecDataHeader(nativeProto.get()).entryOffsetOrAddress = reinterpret_cast(0x08); + nativeProto[0] = 8; + nativeProto[1] = 12; + + nativeProtos.push_back(std::move(nativeProto)); + } + + NativeModuleRef modRefA = + allocator.getOrInsertNativeModule(ModuleId{0x0a}, std::move(nativeProtos), data.data(), data.size(), code.data(), code.size()).first; + REQUIRE(!modRefA.empty()); + REQUIRE(modRefA->getModuleBaseAddress() != nullptr); + + const uint32_t* proto1 = modRefA->tryGetNativeProto(1); + REQUIRE(proto1 != nullptr); + REQUIRE(getNativeProtoExecDataHeader(proto1).bytecodeId == 1); + REQUIRE(getNativeProtoExecDataHeader(proto1).entryOffsetOrAddress == modRefA->getModuleBaseAddress() + 0x00); + REQUIRE(proto1[0] == 0); + REQUIRE(proto1[1] == 4); + + const uint32_t* proto3 = modRefA->tryGetNativeProto(3); + REQUIRE(proto3 != nullptr); + REQUIRE(getNativeProtoExecDataHeader(proto3).bytecodeId == 3); + REQUIRE(getNativeProtoExecDataHeader(proto3).entryOffsetOrAddress == modRefA->getModuleBaseAddress() + 0x08); + REQUIRE(proto3[0] == 8); + REQUIRE(proto3[1] == 12); + + // Ensure that non-existent native protos cannot be found: + REQUIRE(modRefA->tryGetNativeProto(0) == nullptr); + REQUIRE(modRefA->tryGetNativeProto(2) == nullptr); + REQUIRE(modRefA->tryGetNativeProto(4) == nullptr); +} + +TEST_CASE("AnonymousModuleLifetime") +{ + if (!luau_codegen_supported()) + return; + + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; + SharedCodeAllocator allocator{&codeAllocator}; + + const std::vector data(8); + const std::vector code(8); + + std::vector nativeProtos; + nativeProtos.reserve(1); + + { + NativeProtoExecDataPtr nativeProto = createNativeProtoExecData(2); + getNativeProtoExecDataHeader(nativeProto.get()).bytecodeId = 1; + getNativeProtoExecDataHeader(nativeProto.get()).entryOffsetOrAddress = reinterpret_cast(0x00); + nativeProto[0] = 0; + nativeProto[1] = 4; + + nativeProtos.push_back(std::move(nativeProto)); + } + + NativeModuleRef modRef = allocator.insertAnonymousNativeModule(std::move(nativeProtos), data.data(), data.size(), code.data(), code.size()); + REQUIRE(!modRef.empty()); + REQUIRE(modRef->getModuleBaseAddress() != nullptr); + REQUIRE(modRef->tryGetNativeProto(1) != nullptr); + REQUIRE(modRef->getRefcount() == 1); + + const NativeModule* mod = modRef.get(); + + // Acquire a reference (as if we are binding it to a Luau VM Proto): + modRef->addRef(); + REQUIRE(mod->getRefcount() == 2); + + // Release our "owning" reference: + modRef.reset(); + REQUIRE(mod->getRefcount() == 1); + + // Release our added reference (as if the Luau VM Proto is being GC'ed): + mod->release(); + + // When we return and the sharedCodeAllocator is destroyed it will verify + // that there are no outstanding anonymous NativeModules. +} + +TEST_CASE("SharedAllocation") +{ + if (!luau_codegen_supported()) + return; + + UniqueSharedCodeGenContext sharedCodeGenContext = createSharedCodeGenContext(); + + std::unique_ptr L1{luaL_newstate(), lua_close}; + std::unique_ptr L2{luaL_newstate(), lua_close}; + + create(L1.get(), sharedCodeGenContext.get()); + create(L2.get(), sharedCodeGenContext.get()); + + std::string source = R"( + function add(x, y) return x + y end + function sub(x, y) return x - y end + )"; + + size_t bytecodeSize = 0; + std::unique_ptr bytecode{luau_compile(source.data(), source.size(), nullptr, &bytecodeSize), free}; + const int loadResult1 = luau_load(L1.get(), "=Functions", bytecode.get(), bytecodeSize, 0); + const int loadResult2 = luau_load(L2.get(), "=Functions", bytecode.get(), bytecodeSize, 0); + REQUIRE(loadResult1 == 0); + REQUIRE(loadResult2 == 0); + bytecode.reset(); + + const ModuleId moduleId = {0x01}; + + CompilationOptions options; + options.flags = CodeGen_ColdFunctions; + + CompilationStats nativeStats1 = {}; + CompilationStats nativeStats2 = {}; + const CompilationResult codeGenResult1 = Luau::CodeGen::compile(moduleId, L1.get(), -1, options, &nativeStats1); + const CompilationResult codeGenResult2 = Luau::CodeGen::compile(moduleId, L2.get(), -1, options, &nativeStats2); + REQUIRE(codeGenResult1.result == CodeGenCompilationResult::Success); + REQUIRE(codeGenResult2.result == CodeGenCompilationResult::Success); + + // We should have identified all three functions both times through: + REQUIRE(nativeStats1.functionsTotal == 3); + REQUIRE(nativeStats2.functionsTotal == 3); + + // We should have compiled the three functions only the first time: + REQUIRE(nativeStats1.functionsCompiled == 3); + REQUIRE(nativeStats2.functionsCompiled == 0); + + // We should have bound all three functions both times through: + REQUIRE(nativeStats1.functionsBound == 3); + REQUIRE(nativeStats2.functionsBound == 3); +} diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp new file mode 100644 index 000000000..92ac68dec --- /dev/null +++ b/tests/Simplify.test.cpp @@ -0,0 +1,622 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/Simplify.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_DYNAMIC_FASTINT(LuauSimplificationComplexityLimit) + +namespace +{ + +struct SimplifyFixture : Fixture +{ + TypeArena _arena; + const NotNull arena{&_arena}; + + ToStringOptions opts; + + Scope scope{builtinTypes->anyTypePack}; + + const TypeId anyTy = builtinTypes->anyType; + const TypeId unknownTy = builtinTypes->unknownType; + const TypeId neverTy = builtinTypes->neverType; + const TypeId errorTy = builtinTypes->errorType; + + const TypeId functionTy = builtinTypes->functionType; + const TypeId tableTy = builtinTypes->tableType; + + const TypeId numberTy = builtinTypes->numberType; + const TypeId stringTy = builtinTypes->stringType; + const TypeId booleanTy = builtinTypes->booleanType; + const TypeId nilTy = builtinTypes->nilType; + + const TypeId classTy = builtinTypes->classType; + + const TypeId trueTy = builtinTypes->trueType; + const TypeId falseTy = builtinTypes->falseType; + + const TypeId truthyTy = builtinTypes->truthyType; + const TypeId falsyTy = builtinTypes->falsyType; + + const TypeId freeTy = freshType(arena, builtinTypes, &scope); + const TypeId genericTy = arena->addType(GenericType{}); + const TypeId blockedTy = arena->addType(BlockedType{}); + const TypeId pendingTy = arena->addType(PendingExpansionType{{}, {}, {}, {}}); + + const TypeId helloTy = arena->addType(SingletonType{StringSingleton{"hello"}}); + const TypeId worldTy = arena->addType(SingletonType{StringSingleton{"world"}}); + + const TypePackId emptyTypePack = arena->addTypePack({}); + + const TypeId fn1Ty = arena->addType(FunctionType{emptyTypePack, emptyTypePack}); + const TypeId fn2Ty = arena->addType(FunctionType{builtinTypes->anyTypePack, emptyTypePack}); + + TypeId parentClassTy = nullptr; + TypeId childClassTy = nullptr; + TypeId anotherChildClassTy = nullptr; + TypeId unrelatedClassTy = nullptr; + + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + SimplifyFixture() + { + createSomeClasses(&frontend); + + parentClassTy = frontend.globals.globalScope->linearSearchForBinding("Parent")->typeId; + childClassTy = frontend.globals.globalScope->linearSearchForBinding("Child")->typeId; + anotherChildClassTy = frontend.globals.globalScope->linearSearchForBinding("AnotherChild")->typeId; + unrelatedClassTy = frontend.globals.globalScope->linearSearchForBinding("Unrelated")->typeId; + } + + TypeId intersect(TypeId a, TypeId b) + { + return simplifyIntersection(builtinTypes, arena, a, b).result; + } + + std::string intersectStr(TypeId a, TypeId b) + { + return toString(intersect(a, b), opts); + } + + bool isIntersection(TypeId a) + { + return bool(get(follow(a))); + } + + TypeId mkTable(std::map propTypes) + { + TableType::Props props; + for (const auto& [name, ty] : propTypes) + props[name] = Property{ty}; + + return arena->addType(TableType{props, {}, TypeLevel{}, TableState::Sealed}); + } + + TypeId mkNegation(TypeId ty) + { + return arena->addType(NegationType{ty}); + } + + TypeId mkFunction(TypeId arg, TypeId ret) + { + return arena->addType(FunctionType{arena->addTypePack({arg}), arena->addTypePack({ret})}); + } + + TypeId union_(TypeId a, TypeId b) + { + return simplifyUnion(builtinTypes, arena, a, b).result; + } +}; + +} // namespace + +TEST_SUITE_BEGIN("Simplify"); + +TEST_CASE_FIXTURE(SimplifyFixture, "overload_negation_refinement_is_never") +{ + TypeId f1 = mkFunction(stringTy, numberTy); + TypeId f2 = mkFunction(numberTy, stringTy); + TypeId intersection = arena->addType(IntersectionType{{f1, f2}}); + TypeId unionT = arena->addType(UnionType{{errorTy, functionTy}}); + TypeId negationT = mkNegation(unionT); + // The intersection of string -> number & number -> string, ~(error | function) + CHECK(neverTy == intersect(intersection, negationT)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_other_tops_and_bottom_types") +{ + + CHECK(unknownTy == intersect(unknownTy, unknownTy)); + + CHECK("any" == intersectStr(unknownTy, anyTy)); + CHECK("any" == intersectStr(anyTy, unknownTy)); + + CHECK(neverTy == intersect(unknownTy, neverTy)); + CHECK(neverTy == intersect(neverTy, unknownTy)); + + CHECK(errorTy == intersect(unknownTy, errorTy)); + CHECK(errorTy == intersect(errorTy, unknownTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "nil") +{ + CHECK(nilTy == intersect(nilTy, nilTy)); + CHECK(neverTy == intersect(nilTy, numberTy)); + CHECK(neverTy == intersect(nilTy, trueTy)); + CHECK(neverTy == intersect(nilTy, tableTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "boolean_singletons") +{ + CHECK(trueTy == intersect(trueTy, booleanTy)); + CHECK(trueTy == intersect(booleanTy, trueTy)); + + CHECK(falseTy == intersect(falseTy, booleanTy)); + CHECK(falseTy == intersect(booleanTy, falseTy)); + + CHECK(neverTy == intersect(falseTy, trueTy)); + CHECK(neverTy == intersect(trueTy, falseTy)); + + CHECK(booleanTy == union_(trueTy, booleanTy)); + CHECK(booleanTy == union_(booleanTy, trueTy)); + CHECK(booleanTy == union_(falseTy, booleanTy)); + CHECK(booleanTy == union_(booleanTy, falseTy)); + CHECK(booleanTy == union_(falseTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "boolean_and_truthy_and_falsy") +{ + TypeId optionalBooleanTy = arena->addType(UnionType{{booleanTy, nilTy}}); + + CHECK(trueTy == intersect(booleanTy, truthyTy)); + + CHECK(trueTy == intersect(optionalBooleanTy, truthyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types") +{ + CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); + CHECK("'a | *error-type*" == intersectStr(freeTy, anyTy)); + + CHECK("*error-type* | b" == intersectStr(anyTy, genericTy)); + CHECK("*error-type* | b" == intersectStr(genericTy, anyTy)); + + auto anyRhsBlocked = get(intersect(anyTy, blockedTy)); + auto anyLhsBlocked = get(intersect(blockedTy, anyTy)); + + REQUIRE(anyRhsBlocked); + REQUIRE(anyRhsBlocked->options.size() == 2); + CHECK(blockedTy == anyRhsBlocked->options[0]); + CHECK(errorTy == anyRhsBlocked->options[1]); + + REQUIRE(anyLhsBlocked); + REQUIRE(anyLhsBlocked->options.size() == 2); + CHECK(blockedTy == anyLhsBlocked->options[0]); + CHECK(errorTy == anyLhsBlocked->options[1]); + + auto anyRhsPending = get(intersect(anyTy, pendingTy)); + auto anyLhsPending = get(intersect(pendingTy, anyTy)); + + REQUIRE(anyRhsPending); + REQUIRE(anyRhsPending->options.size() == 2); + CHECK(pendingTy == anyRhsPending->options[0]); + CHECK(errorTy == anyRhsPending->options[1]); + + REQUIRE(anyLhsPending); + REQUIRE(anyLhsPending->options.size() == 2); + CHECK(pendingTy == anyLhsPending->options[0]); + CHECK(errorTy == anyLhsPending->options[1]); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "union_where_lhs_elements_are_a_subset_of_the_rhs") +{ + TypeId lhs = union_(numberTy, stringTy); + TypeId rhs = union_(stringTy, numberTy); + + CHECK("number | string" == toString(union_(lhs, rhs))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types") +{ + CHECK(freeTy == intersect(unknownTy, freeTy)); + CHECK(freeTy == intersect(freeTy, unknownTy)); + + CHECK(genericTy == intersect(unknownTy, genericTy)); + CHECK(genericTy == intersect(genericTy, unknownTy)); + + CHECK(blockedTy == intersect(unknownTy, blockedTy)); + CHECK(blockedTy == intersect(unknownTy, blockedTy)); + + CHECK(pendingTy == intersect(unknownTy, pendingTy)); + CHECK(pendingTy == intersect(unknownTy, pendingTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") +{ + CHECK(numberTy == intersect(numberTy, unknownTy)); + CHECK(numberTy == intersect(unknownTy, numberTy)); + CHECK(trueTy == intersect(trueTy, unknownTy)); + CHECK(trueTy == intersect(unknownTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "error_and_other_tops_and_bottom_types") +{ + CHECK(errorTy == intersect(errorTy, errorTy)); + + CHECK(errorTy == intersect(errorTy, anyTy)); + CHECK(errorTy == intersect(anyTy, errorTy)); + + CHECK(neverTy == intersect(errorTy, neverTy)); + CHECK(neverTy == intersect(neverTy, errorTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "error_and_indeterminate_types") +{ + CHECK("'a & *error-type*" == intersectStr(errorTy, freeTy)); + CHECK("'a & *error-type*" == intersectStr(freeTy, errorTy)); + + CHECK("*error-type* & b" == intersectStr(errorTy, genericTy)); + CHECK("*error-type* & b" == intersectStr(genericTy, errorTy)); + + CHECK(isIntersection(intersect(errorTy, blockedTy))); + CHECK(isIntersection(intersect(blockedTy, errorTy))); + + CHECK(isIntersection(intersect(errorTy, pendingTy))); + CHECK(isIntersection(intersect(pendingTy, errorTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") +{ + CHECK(neverTy == intersect(numberTy, errorTy)); + CHECK(neverTy == intersect(errorTy, numberTy)); + CHECK(neverTy == intersect(trueTy, errorTy)); + CHECK(neverTy == intersect(errorTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives") +{ + // This shouldn't be possible, but we'll make it work even if it is. + TypeId numberTyDuplicate = arena->addType(PrimitiveType{PrimitiveType::Number}); + + CHECK(numberTy == intersect(numberTy, numberTyDuplicate)); + CHECK(neverTy == intersect(numberTy, stringTy)); + + CHECK(neverTy == intersect(neverTy, numberTy)); + CHECK(neverTy == intersect(numberTy, neverTy)); + + CHECK(neverTy == intersect(neverTy, functionTy)); + CHECK(neverTy == intersect(functionTy, neverTy)); + + CHECK(neverTy == intersect(neverTy, tableTy)); + CHECK(neverTy == intersect(tableTy, neverTy)); + + CHECK("*error-type* | number" == intersectStr(anyTy, numberTy)); + CHECK("*error-type* | number" == intersectStr(numberTy, anyTy)); + + CHECK(neverTy == intersect(stringTy, nilTy)); + CHECK(neverTy == intersect(nilTy, stringTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives_and_falsy") +{ + CHECK(neverTy == intersect(numberTy, falsyTy)); + CHECK(neverTy == intersect(falsyTy, numberTy)); + + CHECK(nilTy == intersect(nilTy, falsyTy)); + CHECK(nilTy == intersect(falsyTy, nilTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives_and_singletons") +{ + CHECK(helloTy == intersect(helloTy, stringTy)); + CHECK(helloTy == intersect(stringTy, helloTy)); + + CHECK(neverTy == intersect(worldTy, helloTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "functions") +{ + CHECK(fn1Ty == intersect(fn1Ty, functionTy)); + CHECK(fn1Ty == intersect(functionTy, fn1Ty)); + + // Intersections of functions are super weird if you think about it. + CHECK("(() -> ()) & ((...any) -> ())" == intersectStr(fn1Ty, fn2Ty)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negated_top_function_type") +{ + TypeId negatedFunctionTy = mkNegation(functionTy); + + CHECK(numberTy == intersect(numberTy, negatedFunctionTy)); + CHECK(numberTy == intersect(negatedFunctionTy, numberTy)); + + CHECK(falsyTy == intersect(falsyTy, negatedFunctionTy)); + CHECK(falsyTy == intersect(negatedFunctionTy, falsyTy)); + + TypeId f = mkFunction(stringTy, numberTy); + + CHECK(neverTy == intersect(f, negatedFunctionTy)); + CHECK(neverTy == intersect(negatedFunctionTy, f)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "optional_overloaded_function_and_top_function") +{ + // (((number) -> string) & ((string) -> number))? & ~function + + TypeId f1 = mkFunction(numberTy, stringTy); + TypeId f2 = mkFunction(stringTy, numberTy); + + TypeId f12 = arena->addType(IntersectionType{{f1, f2}}); + + TypeId t = arena->addType(UnionType{{f12, nilTy}}); + + TypeId notFunctionTy = mkNegation(functionTy); + + CHECK(nilTy == intersect(t, notFunctionTy)); + CHECK(nilTy == intersect(notFunctionTy, t)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negated_function_does_not_intersect_cleanly_with_truthy") +{ + // ~function & ~(false?) + // ~function & ~(false | nil) + // ~function & ~false & ~nil + + TypeId negatedFunctionTy = mkNegation(functionTy); + CHECK(isIntersection(intersect(negatedFunctionTy, truthyTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables") +{ + TypeId t1 = mkTable({{"tag", stringTy}}); + + CHECK(t1 == intersect(t1, tableTy)); + CHECK(neverTy == intersect(t1, functionTy)); + + TypeId t2 = mkTable({{"tag", helloTy}}); + + CHECK(t2 == intersect(t1, t2)); + CHECK(t2 == intersect(t2, t1)); + + TypeId t3 = mkTable({}); + // {tag : string} intersect {} + CHECK(t1 == intersect(t1, t3)); + CHECK(t1 == intersect(t3, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "combine_disjoint_sealed_tables") +{ + TypeId t1 = mkTable({{"prop", stringTy}}); + TypeId t2 = mkTable({{"second_prop", numberTy}}); + + CHECK("{ prop: string, second_prop: number }" == toString(intersect(t1, t2))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "non_disjoint_tables_do_not_simplify") +{ + TypeId t1 = mkTable({{"prop", stringTy}}); + TypeId t2 = mkTable({{"prop", unknownTy}, {"second_prop", numberTy}}); + + CHECK("{ prop: string } & { prop: unknown, second_prop: number }" == toString(intersect(t1, t2))); +} + +// Simplification has an extra code path especially for intersections with +// single-property tables, so it's worthwhile to separately test the case where +// both tables have multiple properties. +TEST_CASE_FIXTURE(SimplifyFixture, "non_disjoint_tables_do_not_simplify_2") +{ + TypeId t1 = mkTable({{"prop", stringTy}, {"third_prop", numberTy}}); + TypeId t2 = mkTable({{"prop", unknownTy}, {"second_prop", numberTy}}); + + CHECK("{ prop: string, third_prop: number } & { prop: unknown, second_prop: number }" == toString(intersect(t1, t2))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables_and_top_table") +{ + TypeId notTableType = mkNegation(tableTy); + TypeId t1 = mkTable({{"prop", stringTy}, {"another", numberTy}}); + + CHECK(t1 == intersect(t1, tableTy)); + CHECK(t1 == intersect(tableTy, t1)); + + CHECK(neverTy == intersect(t1, notTableType)); + CHECK(neverTy == intersect(notTableType, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables_and_truthy") +{ + TypeId t1 = mkTable({{"prop", stringTy}, {"another", numberTy}}); + + CHECK(t1 == intersect(t1, truthyTy)); + CHECK(t1 == intersect(truthyTy, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "table_with_a_tag") +{ + // {tag: string, prop: number} & {tag: "hello"} + // I think we can decline to simplify this: + TypeId t1 = mkTable({{"tag", stringTy}, {"prop", numberTy}}); + TypeId t2 = mkTable({{"tag", helloTy}}); + + CHECK("{ prop: number, tag: string } & { tag: \"hello\" }" == intersectStr(t1, t2)); + CHECK("{ prop: number, tag: string } & { tag: \"hello\" }" == intersectStr(t2, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "nested_table_tag_test") +{ + TypeId t1 = mkTable({ + {"subtable", + mkTable({ + {"tag", helloTy}, + {"subprop", numberTy}, + })}, + {"prop", stringTy}, + }); + TypeId t2 = mkTable({ + {"subtable", + mkTable({ + {"tag", helloTy}, + })}, + }); + + CHECK(t1 == intersect(t1, t2)); + CHECK(t1 == intersect(t2, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "union") +{ + TypeId t1 = arena->addType(UnionType{{numberTy, stringTy, nilTy, tableTy}}); + + CHECK(nilTy == intersect(t1, nilTy)); + // CHECK(nilTy == intersect(nilTy, t1)); // TODO? + + CHECK(builtinTypes->stringType == intersect(builtinTypes->optionalStringType, truthyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "two_unions") +{ + ScopedFastInt sfi{DFInt::LuauSimplificationComplexityLimit, 10}; + TypeId t1 = arena->addType(UnionType{{numberTy, booleanTy, stringTy, nilTy, tableTy}}); + + CHECK("false?" == intersectStr(t1, falsyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "curious_union") +{ + // (a & false) | (a & nil) + TypeId curious = + arena->addType(UnionType{{arena->addType(IntersectionType{{freeTy, falseTy}}), arena->addType(IntersectionType{{freeTy, nilTy}})}}); + + CHECK("('a & false) | ('a & nil) | number" == toString(union_(curious, numberTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negations") +{ + TypeId notNumberTy = mkNegation(numberTy); + TypeId notStringTy = mkNegation(stringTy); + + CHECK(neverTy == intersect(numberTy, notNumberTy)); + + CHECK(numberTy == intersect(numberTy, notStringTy)); + CHECK(numberTy == intersect(notStringTy, numberTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "top_class_type") +{ + CHECK(neverTy == intersect(classTy, stringTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "classes") +{ + CHECK(childClassTy == intersect(childClassTy, parentClassTy)); + CHECK(childClassTy == intersect(parentClassTy, childClassTy)); + + CHECK(parentClassTy == union_(childClassTy, parentClassTy)); + CHECK(parentClassTy == union_(parentClassTy, childClassTy)); + + CHECK(neverTy == intersect(childClassTy, unrelatedClassTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negations_of_classes") +{ + TypeId notChildClassTy = mkNegation(childClassTy); + TypeId notParentClassTy = mkNegation(parentClassTy); + + CHECK(neverTy == intersect(childClassTy, notParentClassTy)); + CHECK(neverTy == intersect(notParentClassTy, childClassTy)); + + CHECK("Parent & ~Child" == intersectStr(notChildClassTy, parentClassTy)); + CHECK("Parent & ~Child" == intersectStr(parentClassTy, notChildClassTy)); + + CHECK(notParentClassTy == intersect(notChildClassTy, notParentClassTy)); + CHECK(notParentClassTy == intersect(notParentClassTy, notChildClassTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "intersection_of_intersection_of_a_free_type_can_result_in_removal_of_that_free_type") +{ + // a & string and number + // (a & number) & (string & number) + + TypeId t1 = arena->addType(IntersectionType{{freeTy, stringTy}}); + + CHECK(neverTy == intersect(t1, numberTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "some_tables_are_really_never") +{ + TypeId notAnyTy = mkNegation(anyTy); + + TypeId t1 = mkTable({{"someKey", notAnyTy}}); + + CHECK(neverTy == intersect(t1, numberTy)); + CHECK(neverTy == intersect(numberTy, t1)); + CHECK(t1 == intersect(t1, t1)); + + TypeId notUnknownTy = mkNegation(unknownTy); + + TypeId t2 = mkTable({{"someKey", notUnknownTy}}); + + CHECK(neverTy == intersect(t2, numberTy)); + CHECK(neverTy == intersect(numberTy, t2)); + CHECK(neverTy == intersect(t2, t2)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "simplify_stops_at_cycles") +{ + TypeId t = mkTable({}); + TableType* tt = getMutable(t); + REQUIRE(tt); + + TypeId t2 = mkTable({}); + TableType* t2t = getMutable(t2); + REQUIRE(t2t); + + tt->props["cyclic"] = Property{t2}; + t2t->props["cyclic"] = Property{t}; + + CHECK(t == intersect(t, unknownTy)); + CHECK(t == intersect(unknownTy, t)); + + CHECK(t2 == intersect(t2, unknownTy)); + CHECK(t2 == intersect(unknownTy, t2)); + + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(t, anyTy)); + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(anyTy, t)); + + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(t2, anyTy)); + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(anyTy, t2)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "free_type_bound_by_any_with_any") +{ + CHECK("'a | *error-type*" == intersectStr(freeTy, anyTy)); + CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); + + CHECK("'a | *error-type*" == intersectStr(freeTy, anyTy)); + CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "bound_intersected_by_itself_should_be_itself") +{ + TypeId blocked = arena->addType(BlockedType{}); + CHECK(toString(blocked) == intersectStr(blocked, blocked)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "cyclic_never_union_and_string") +{ + // t1 where t1 = never | t1 + TypeId leftType = arena->addType(UnionType{{builtinTypes->neverType, builtinTypes->neverType}}); + UnionType* leftUnion = getMutable(leftType); + REQUIRE(leftUnion); + leftUnion->options[0] = leftType; + + CHECK(builtinTypes->stringType == union_(leftType, builtinTypes->stringType)); +} + +TEST_SUITE_END(); diff --git a/tests/StringUtils.test.cpp b/tests/StringUtils.test.cpp index afef3b06e..cf65856d1 100644 --- a/tests/StringUtils.test.cpp +++ b/tests/StringUtils.test.cpp @@ -59,7 +59,7 @@ TEST_CASE("BenchmarkLevenshteinDistance") auto end = std::chrono::steady_clock::now(); auto time = std::chrono::duration_cast(end - start); - std::cout << "Running levenshtein distance " << count << " times took " << time.count() << "ms" << std::endl; + MESSAGE("Running levenshtein distance ", count, " times took ", time.count(), "ms"); } #endif @@ -106,4 +106,22 @@ TEST_CASE("AreWeUsingDistanceWithAdjacentTranspositionsAndNotOptimalStringAlignm CHECK_EQ(distance, 2); } +TEST_CASE("EditDistanceSupportsUnicode") +{ + // ASCII character + CHECK_EQ(Luau::editDistance("A block", "X block"), 1); + + // UTF-8 2 byte character + CHECK_EQ(Luau::editDistance("A block", "À block"), 2); + + // UTF-8 3 byte character + CHECK_EQ(Luau::editDistance("A block", "⪻ block"), 3); + + // UTF-8 4 byte character + CHECK_EQ(Luau::editDistance("A block", "ð’‹„ block"), 4); + + // UTF-8 extreme characters + CHECK_EQ(Luau::editDistance("A block", "R̴̨̢̟̚Å̶̳̳͚ÌÍ…b̶̡̻̞ÌÌ¿Í…l̸̼Íợ̷̜͓̒ÌÍœÍáºÌ´Ì̦̟̰ÌÌ’ÌÌŒ block"), 85); +} + TEST_SUITE_END(); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp new file mode 100644 index 000000000..27b2f6e7c --- /dev/null +++ b/tests/Subtyping.test.cpp @@ -0,0 +1,1609 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFwd.h" +#include "Luau/TypePath.h" + +#include "Luau/Normalize.h" +#include "Luau/Subtyping.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/TypeFunction.h" + +#include "doctest.h" +#include "Fixture.h" +#include "RegisterCallbacks.h" + +#include + +LUAU_FASTFLAG(LuauSolverV2); + +using namespace Luau; + +namespace Luau +{ + +std::ostream& operator<<(std::ostream& lhs, const SubtypingVariance& variance) +{ + switch (variance) + { + case SubtypingVariance::Covariant: + return lhs << "covariant"; + case SubtypingVariance::Contravariant: + return lhs << "contravariant"; + case SubtypingVariance::Invariant: + return lhs << "invariant"; + case SubtypingVariance::Invalid: + return lhs << "*invalid*"; + } + + return lhs; +} + +std::ostream& operator<<(std::ostream& lhs, const SubtypingReasoning& reasoning) +{ + return lhs << toString(reasoning.subPath) << " & set, const std::vector& items) +{ + if (items.size() != set.size()) + return false; + + for (const SubtypingReasoning& r : items) + { + if (!set.contains(r)) + return false; + } + + return true; +} + +}; // namespace Luau + +struct SubtypeFixture : Fixture +{ + TypeArena arena; + InternalErrorReporter iceReporter; + UnifierSharedState sharedState{&ice}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{NotNull{&iceReporter}, NotNull{&limits}}; + + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + ScopePtr rootScope{new Scope(builtinTypes->emptyTypePack)}; + ScopePtr moduleScope{new Scope(rootScope)}; + + Subtyping subtyping = mkSubtyping(); + BuiltinTypeFunctions builtinTypeFunctions{}; + + Subtyping mkSubtyping() + { + return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + } + + TypePackId pack(std::initializer_list tys) + { + return arena.addTypePack(tys); + } + + TypePackId pack(std::initializer_list tys, TypePackVariant tail) + { + return arena.addTypePack(tys, arena.addTypePack(std::move(tail))); + } + + TypeId fn(std::initializer_list args, std::initializer_list rets) + { + return arena.addType(FunctionType{pack(args), pack(rets)}); + } + + TypeId fn(std::initializer_list argHead, TypePackVariant argTail, std::initializer_list rets) + { + return arena.addType(FunctionType{pack(argHead, std::move(argTail)), pack(rets)}); + } + + TypeId fn(std::initializer_list args, std::initializer_list retHead, TypePackVariant retTail) + { + return arena.addType(FunctionType{pack(args), pack(retHead, std::move(retTail))}); + } + + TypeId fn(std::initializer_list argHead, TypePackVariant argTail, std::initializer_list retHead, TypePackVariant retTail) + { + return arena.addType(FunctionType{pack(argHead, std::move(argTail)), pack(retHead, std::move(retTail))}); + } + + TypeId tbl(TableType::Props&& props) + { + return arena.addType(TableType{std::move(props), std::nullopt, {}, TableState::Sealed}); + } + + TypeId idx(TypeId keyTy, TypeId valueTy) + { + return arena.addType(TableType{{}, TableIndexer{keyTy, valueTy}, {}, TableState::Sealed}); + } + + // `&` + TypeId meet(TypeId a, TypeId b) + { + return arena.addType(IntersectionType{{a, b}}); + } + + // `|` + TypeId join(TypeId a, TypeId b) + { + return arena.addType(UnionType{{a, b}}); + } + + // `~` + TypeId negate(TypeId ty) + { + return arena.addType(NegationType{ty}); + } + + // "literal" + TypeId str(const char* literal) + { + return arena.addType(SingletonType{StringSingleton{literal}}); + } + + TypeId cls(const std::string& name, std::optional parent = std::nullopt) + { + return arena.addType(ClassType{name, {}, parent.value_or(builtinTypes->classType), {}, {}, nullptr, "", {}}); + } + + TypeId cls(const std::string& name, ClassType::Props&& props) + { + TypeId ty = cls(name); + getMutable(ty)->props = std::move(props); + return ty; + } + + TypeId opt(TypeId ty) + { + return join(ty, builtinTypes->nilType); + } + + TypeId cyclicTable(std::function&& cb) + { + TypeId res = arena.addType(GenericType{}); + TableType tt{}; + cb(res, &tt); + emplaceType(asMutable(res), std::move(tt)); + return res; + } + + TypeId meta(TableType::Props&& metaProps, TableType::Props&& tableProps = {}) + { + return arena.addType(MetatableType{tbl(std::move(tableProps)), tbl(std::move(metaProps))}); + } + + TypeId genericT = arena.addType(GenericType{moduleScope.get(), "T"}); + TypeId genericU = arena.addType(GenericType{moduleScope.get(), "U"}); + + TypePackId genericAs = arena.addTypePack(GenericTypePack{"A"}); + TypePackId genericBs = arena.addTypePack(GenericTypePack{"B"}); + TypePackId genericCs = arena.addTypePack(GenericTypePack{"C"}); + + SubtypingResult isSubtype(TypeId subTy, TypeId superTy) + { + return subtyping.isSubtype(subTy, superTy, NotNull{rootScope.get()}); + } + + TypeId helloType = arena.addType(SingletonType{StringSingleton{"hello"}}); + TypeId helloType2 = arena.addType(SingletonType{StringSingleton{"hello"}}); + TypeId worldType = arena.addType(SingletonType{StringSingleton{"world"}}); + + TypeId aType = arena.addType(SingletonType{StringSingleton{"a"}}); + TypeId bType = arena.addType(SingletonType{StringSingleton{"b"}}); + TypeId trueSingleton = arena.addType(SingletonType{BooleanSingleton{true}}); + TypeId falseSingleton = arena.addType(SingletonType{BooleanSingleton{false}}); + TypeId helloOrWorldType = join(helloType, worldType); + TypeId trueOrFalseType = join(builtinTypes->trueType, builtinTypes->falseType); + + TypeId helloAndWorldType = meet(helloType, worldType); + TypeId booleanAndTrueType = meet(builtinTypes->booleanType, builtinTypes->trueType); + + /** + * class + * \- Root + * |- Child + * | |-GrandchildOne + * | \-GrandchildTwo + * \- AnotherChild + * |- AnotherGrandchildOne + * \- AnotherGrandchildTwo + */ + TypeId rootClass = cls("Root"); + TypeId childClass = cls("Child", rootClass); + TypeId grandchildOneClass = cls("GrandchildOne", childClass); + TypeId grandchildTwoClass = cls("GrandchildTwo", childClass); + TypeId anotherChildClass = cls("AnotherChild", rootClass); + TypeId anotherGrandchildOneClass = cls("AnotherGrandchildOne", anotherChildClass); + TypeId anotherGrandchildTwoClass = cls("AnotherGrandchildTwo", anotherChildClass); + + TypeId vec2Class = + cls("Vec2", + { + {"X", builtinTypes->numberType}, + {"Y", builtinTypes->numberType}, + }); + + TypeId readOnlyVec2Class = + cls("ReadOnlyVec2", + { + {"X", Property::readonly(builtinTypes->numberType)}, + {"Y", Property::readonly(builtinTypes->numberType)}, + }); + + // "hello" | "hello" + TypeId helloOrHelloType = arena.addType(UnionType{{helloType, helloType}}); + + // () -> () + const TypeId nothingToNothingType = fn({}, {}); + + // (number) -> string + const TypeId numberToStringType = fn({builtinTypes->numberType}, {builtinTypes->stringType}); + + // (unknown) -> string + const TypeId unknownToStringType = fn({builtinTypes->unknownType}, {builtinTypes->stringType}); + + // (number) -> () + const TypeId numberToNothingType = fn({builtinTypes->numberType}, {}); + + // () -> number + const TypeId nothingToNumberType = fn({}, {builtinTypes->numberType}); + + // (number) -> number + const TypeId numberToNumberType = fn({builtinTypes->numberType}, {builtinTypes->numberType}); + + // (number) -> unknown + const TypeId numberToUnknownType = fn({builtinTypes->numberType}, {builtinTypes->unknownType}); + + // (number) -> (string, string) + const TypeId numberToTwoStringsType = fn({builtinTypes->numberType}, {builtinTypes->stringType, builtinTypes->stringType}); + + // (number) -> (string, unknown) + const TypeId numberToStringAndUnknownType = fn({builtinTypes->numberType}, {builtinTypes->stringType, builtinTypes->unknownType}); + + // (number, number) -> string + const TypeId numberNumberToStringType = fn({builtinTypes->numberType, builtinTypes->numberType}, {builtinTypes->stringType}); + + // (unknown, number) -> string + const TypeId unknownNumberToStringType = fn({builtinTypes->unknownType, builtinTypes->numberType}, {builtinTypes->stringType}); + + // (number, string) -> string + const TypeId numberAndStringToStringType = fn({builtinTypes->numberType, builtinTypes->stringType}, {builtinTypes->stringType}); + + // (number, ...string) -> string + const TypeId numberAndStringsToStringType = + fn({builtinTypes->numberType}, VariadicTypePack{builtinTypes->stringType}, {builtinTypes->stringType}); + + // (number, ...string?) -> string + const TypeId numberAndOptionalStringsToStringType = + fn({builtinTypes->numberType}, VariadicTypePack{builtinTypes->optionalStringType}, {builtinTypes->stringType}); + + // (...number) -> number + const TypeId numbersToNumberType = + arena.addType(FunctionType{arena.addTypePack(VariadicTypePack{builtinTypes->numberType}), arena.addTypePack({builtinTypes->numberType})}); + + // (T) -> () + const TypeId genericTToNothingType = arena.addType(FunctionType{{genericT}, {}, arena.addTypePack({genericT}), builtinTypes->emptyTypePack}); + + // (T) -> T + const TypeId genericTToTType = arena.addType(FunctionType{{genericT}, {}, arena.addTypePack({genericT}), arena.addTypePack({genericT})}); + + // (U) -> () + const TypeId genericUToNothingType = arena.addType(FunctionType{{genericU}, {}, arena.addTypePack({genericU}), builtinTypes->emptyTypePack}); + + // () -> T + const TypeId genericNothingToTType = arena.addType(FunctionType{{genericT}, {}, builtinTypes->emptyTypePack, arena.addTypePack({genericT})}); + + // (A...) -> A... + const TypeId genericAsToAsType = arena.addType(FunctionType{{}, {genericAs}, genericAs, genericAs}); + + // (A...) -> number + const TypeId genericAsToNumberType = arena.addType(FunctionType{{}, {genericAs}, genericAs, arena.addTypePack({builtinTypes->numberType})}); + + // (B...) -> B... + const TypeId genericBsToBsType = arena.addType(FunctionType{{}, {genericBs}, genericBs, genericBs}); + + // (B...) -> C... + const TypeId genericBsToCsType = arena.addType(FunctionType{{}, {genericBs, genericCs}, genericBs, genericCs}); + + // () -> A... + const TypeId genericNothingToAsType = arena.addType(FunctionType{{}, {genericAs}, builtinTypes->emptyTypePack, genericAs}); + + // { lower : string -> string } + TypeId tableWithLower = tbl(TableType::Props{{"lower", fn({builtinTypes->stringType}, {builtinTypes->stringType})}}); + // { insaneThingNoScalarHas : () -> () } + TypeId tableWithoutScalarProp = tbl(TableType::Props{{"insaneThingNoScalarHas", fn({}, {})}}); +}; + +#define CHECK_IS_SUBTYPE(left, right) \ + do \ + { \ + const auto& leftTy = (left); \ + const auto& rightTy = (right); \ + SubtypingResult result = isSubtype(leftTy, rightTy); \ + CHECK_MESSAGE(result.isSubtype, "Expected " << leftTy << " <: " << rightTy); \ + } while (0) + +#define CHECK_IS_NOT_SUBTYPE(left, right) \ + do \ + { \ + const auto& leftTy = (left); \ + const auto& rightTy = (right); \ + SubtypingResult result = isSubtype(leftTy, rightTy); \ + CHECK_MESSAGE(!result.isSubtype, "Expected " << leftTy << " numberType, builtinTypes->anyType); +TEST_IS_NOT_SUBTYPE(builtinTypes->numberType, builtinTypes->stringType); + +TEST_CASE_FIXTURE(SubtypeFixture, "basic_reducible_sub_type_function") +{ + // add <: number + TypeId typeFunctionNum = + arena.addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions.addFunc}, {builtinTypes->numberType, builtinTypes->numberType}, {}}); + TypeId superTy = builtinTypes->numberType; + SubtypingResult result = isSubtype(typeFunctionNum, superTy); + CHECK(result.isSubtype); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "basic_reducible_super_type_function") +{ + // number <: add ~ number + TypeId typeFunctionNum = + arena.addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions.addFunc}, {builtinTypes->numberType, builtinTypes->numberType}, {}}); + TypeId subTy = builtinTypes->numberType; + SubtypingResult result = isSubtype(subTy, typeFunctionNum); + CHECK(result.isSubtype); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "basic_irreducible_sub_type_function") +{ + // add ~ never <: number + TypeId typeFunctionNum = + arena.addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions.addFunc}, {builtinTypes->stringType, builtinTypes->booleanType}, {}}); + TypeId superTy = builtinTypes->numberType; + SubtypingResult result = isSubtype(typeFunctionNum, superTy); + CHECK(result.isSubtype); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "basic_irreducible_super_type_function") +{ + // number <\: add ~ irreducible/never + TypeId typeFunctionNum = + arena.addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions.addFunc}, {builtinTypes->stringType, builtinTypes->booleanType}, {}}); + TypeId subTy = builtinTypes->numberType; + SubtypingResult result = isSubtype(subTy, typeFunctionNum); + CHECK(!result.isSubtype); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "basic_type_function_with_generics") +{ + // (x: T, x: U) -> add <: (number, number) -> number + TypeId addTypeFunction = arena.addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions.addFunc}, {genericT, genericU}, {}}); + FunctionType ft{{genericT, genericU}, {}, arena.addTypePack({genericT, genericU}), arena.addTypePack({addTypeFunction})}; + TypeId functionType = arena.addType(std::move(ft)); + FunctionType superFt{arena.addTypePack({builtinTypes->numberType, builtinTypes->numberType}), arena.addTypePack({builtinTypes->numberType})}; + TypeId superFunction = arena.addType(std::move(superFt)); + SubtypingResult result = isSubtype(functionType, superFunction); + CHECK(result.isSubtype); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "variadic_subpath_in_pack") +{ + TypePackId subTArgs = arena.addTypePack(TypePack{{builtinTypes->stringType, builtinTypes->stringType}, builtinTypes->anyTypePack}); + TypePackId superTArgs = arena.addTypePack(TypePack{{builtinTypes->numberType}, builtinTypes->anyTypePack}); + // (string, string, ...any) -> number + TypeId functionSub = arena.addType(FunctionType{subTArgs, arena.addTypePack({builtinTypes->numberType})}); + // (number, ...any) -> string + TypeId functionSuper = arena.addType(FunctionType{superTArgs, arena.addTypePack({builtinTypes->stringType})}); + + + SubtypingResult result = isSubtype(functionSub, functionSuper); + CHECK( + result.reasoning == + std::vector{ + SubtypingReasoning{ + TypePath::PathBuilder().rets().index(0).build(), TypePath::PathBuilder().rets().index(0).build(), SubtypingVariance::Covariant + }, + SubtypingReasoning{ + TypePath::PathBuilder().args().index(0).build(), TypePath::PathBuilder().args().index(0).build(), SubtypingVariance::Contravariant + }, + SubtypingReasoning{ + TypePath::PathBuilder().args().index(1).build(), + TypePath::PathBuilder().args().tail().variadic().build(), + SubtypingVariance::Contravariant + } + } + ); + CHECK(!result.isSubtype); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "any <: unknown") +{ + // We have added this as an exception - the set of inhabitants of any is exactly the set of inhabitants of unknown (since error has no + // inhabitants). any = err | unknown, so under semantic subtyping, {} U unknown = unknown + CHECK_IS_SUBTYPE(builtinTypes->anyType, builtinTypes->unknownType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number? <: unknown") +{ + CHECK_IS_SUBTYPE(builtinTypes->optionalNumberType, builtinTypes->unknownType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number <: unknown") +{ + CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->unknownType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number <: number") +{ + CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number <: number?") +{ + CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->optionalNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" <: string") +{ + CHECK_IS_SUBTYPE(helloType, builtinTypes->stringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string stringType, helloType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" <: \"hello\"") +{ + CHECK_IS_SUBTYPE(helloType, helloType2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true <: boolean") +{ + CHECK_IS_SUBTYPE(builtinTypes->trueType, builtinTypes->booleanType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true <: true | false") +{ + CHECK_IS_SUBTYPE(builtinTypes->trueType, trueOrFalseType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true | false trueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true | false <: boolean") +{ + CHECK_IS_SUBTYPE(trueOrFalseType, builtinTypes->booleanType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true | false <: true | false") +{ + CHECK_IS_SUBTYPE(trueOrFalseType, trueOrFalseType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" | \"world\" <: number") +{ + CHECK_IS_NOT_SUBTYPE(helloOrWorldType, builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string stringType, helloOrHelloType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true <: boolean & true") +{ + CHECK_IS_SUBTYPE(builtinTypes->trueType, booleanAndTrueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "boolean & true <: true") +{ + CHECK_IS_SUBTYPE(booleanAndTrueType, builtinTypes->trueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "boolean & true <: boolean & true") +{ + CHECK_IS_SUBTYPE(booleanAndTrueType, booleanAndTrueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" & \"world\" <: number") +{ + CHECK_IS_SUBTYPE(helloAndWorldType, builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "false falseType, booleanAndTrueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(unknown) -> string <: (number) -> string") +{ + CHECK_IS_SUBTYPE(unknownToStringType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, unknownToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberNumberToStringType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, numberNumberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberNumberToStringType, unknownNumberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(unknown, number) -> string <: (number, number) -> string") +{ + CHECK_IS_SUBTYPE(unknownNumberToStringType, numberNumberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> (string, unknown) (string, string)") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringAndUnknownType, numberToTwoStringsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> (string, string) <: (number) -> (string, unknown)") +{ + CHECK_IS_SUBTYPE(numberToTwoStringsType, numberToStringAndUnknownType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> (string, string) string") +{ + CHECK_IS_NOT_SUBTYPE(numberToTwoStringsType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string (string, string)") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, numberToTwoStringsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string) -> string <: (number) -> string") +{ + CHECK_IS_SUBTYPE(numberAndStringsToStringType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, numberAndStringsToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string?) -> string <: (number, ...string) -> string") +{ + CHECK_IS_SUBTYPE(numberAndOptionalStringsToStringType, numberAndStringsToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberAndStringsToStringType, numberAndOptionalStringsToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string) -> string <: (number, string) -> string") +{ + CHECK_IS_SUBTYPE(numberAndStringsToStringType, numberAndStringToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, string) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberAndStringToStringType, numberAndStringsToStringType); +} + +/* + * (A) -> A <: (X) -> X + * A can be bound to X. + * + * (A) -> A (X) -> number + * A can be bound to X, but A number (A) -> A + * Only generics on the left side can be bound. + * number (A, B) -> boolean <: (X, X) -> boolean + * It is ok to bind both A and B to X. + * + * (A, A) -> boolean (X, Y) -> boolean + * A cannot be bound to both X and Y. + */ +TEST_CASE_FIXTURE(SubtypeFixture, "() -> T <: () -> number") +{ + CHECK_IS_SUBTYPE(genericNothingToTType, nothingToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (U) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, genericUToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> number () -> T") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNumberType, genericNothingToTType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (number) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, numberToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> T <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericTToTType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> T string") +{ + CHECK_IS_NOT_SUBTYPE(genericTToTType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (U) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, genericUToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> () (T) -> ()") +{ + CHECK_IS_NOT_SUBTYPE(numberToNothingType, genericTToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> (T, T) (string, number)") +{ + TypeId nothingToTwoTs = arena.addType(FunctionType{{genericT}, {}, builtinTypes->emptyTypePack, arena.addTypePack({genericT, genericT})}); + + TypeId nothingToStringAndNumber = fn({}, {builtinTypes->stringType, builtinTypes->numberType}); + + CHECK_IS_NOT_SUBTYPE(nothingToTwoTs, nothingToStringAndNumber); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> number (A...) -> A...") +{ + CHECK_IS_NOT_SUBTYPE(numberToNumberType, genericAsToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: (B...) -> B...") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, genericBsToBsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(B...) -> C... <: (A...) -> A...") +{ + CHECK_IS_SUBTYPE(genericBsToCsType, genericAsToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... (B...) -> C...") +{ + CHECK_IS_NOT_SUBTYPE(genericAsToAsType, genericBsToCsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> number <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToNumberType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> number (A...) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numberToNumberType, genericAsToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> number <: (...number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToNumberType, numbersToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(...number) -> number (A...) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numbersToNumberType, genericAsToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> A... <: () -> ()") +{ + CHECK_IS_SUBTYPE(genericNothingToAsType, nothingToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> () () -> A...") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNothingType, genericNothingToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: () -> ()") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, nothingToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> () (A...) -> A...") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNothingType, genericAsToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{} <: {}") +{ + CHECK_IS_SUBTYPE(tbl({}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} <: {}") +{ + CHECK_IS_SUBTYPE(tbl({{"x", builtinTypes->numberType}}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{} numberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} numberType}}), tbl({{"x", builtinTypes->stringType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} numberType}}), tbl({{"x", builtinTypes->optionalNumberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number?} optionalNumberType}}), tbl({{"x", builtinTypes->numberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: (T) -> ()} <: {x: (U) -> ()}") +{ + CHECK_IS_SUBTYPE(tbl({{"x", genericTToNothingType}}), tbl({{"x", genericUToNothingType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ x: number } <: { read x: number }") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK_IS_SUBTYPE(tbl({{"x", builtinTypes->numberType}}), tbl({{"x", Property::readonly(builtinTypes->numberType)}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ x: number } <: { write x: number }") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK_IS_SUBTYPE(tbl({{"x", builtinTypes->numberType}}), tbl({{"x", Property::writeonly(builtinTypes->numberType)}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ x: \"hello\" } <: { read x: string }") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK_IS_SUBTYPE(tbl({{"x", helloType}}), tbl({{"x", Property::readonly(builtinTypes->stringType)}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ x: string } <: { write x: string }") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK_IS_SUBTYPE(tbl({{"x", builtinTypes->stringType}}), tbl({{"x", Property::writeonly(builtinTypes->stringType)}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } <: { @metatable {} }") +{ + CHECK_IS_SUBTYPE(meta({{"x", builtinTypes->numberType}}), meta({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } numberType}}), meta({{"x", builtinTypes->booleanType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable {} } booleanType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable {} } <: {}") +{ + CHECK_IS_SUBTYPE(meta({}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { u: boolean }, x: number } <: { x: number }") +{ + CHECK_IS_SUBTYPE(meta({{"u", builtinTypes->booleanType}}, {{"x", builtinTypes->numberType}}), tbl({{"x", builtinTypes->numberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } numberType}}), tbl({{"x", builtinTypes->numberType}})); +} + +TEST_IS_SUBTYPE(builtinTypes->tableType, tbl({})); +TEST_IS_SUBTYPE(tbl({}), builtinTypes->tableType); + +// Negated subtypes +TEST_IS_NOT_SUBTYPE(negate(builtinTypes->neverType), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(builtinTypes->unknownType), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(builtinTypes->anyType), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(meet(builtinTypes->neverType, builtinTypes->unknownType)), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(join(builtinTypes->neverType, builtinTypes->unknownType)), builtinTypes->stringType); + +// Negated supertypes: never/unknown/any/error +TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->neverType)); +TEST_IS_SUBTYPE(builtinTypes->neverType, negate(builtinTypes->unknownType)); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->unknownType)); +TEST_IS_SUBTYPE(builtinTypes->numberType, negate(builtinTypes->anyType)); +TEST_IS_SUBTYPE(builtinTypes->unknownType, negate(builtinTypes->anyType)); + +// Negated supertypes: unions +TEST_IS_SUBTYPE(builtinTypes->booleanType, negate(join(builtinTypes->stringType, builtinTypes->numberType))); +TEST_IS_SUBTYPE(rootClass, negate(join(childClass, builtinTypes->numberType))); +TEST_IS_SUBTYPE(str("foo"), negate(join(builtinTypes->numberType, builtinTypes->booleanType))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(join(builtinTypes->stringType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(childClass, negate(join(rootClass, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(numbersToNumberType, negate(join(builtinTypes->functionType, rootClass))); + +// Negated supertypes: intersections +TEST_IS_SUBTYPE(builtinTypes->booleanType, negate(meet(builtinTypes->stringType, str("foo")))); +TEST_IS_SUBTYPE(builtinTypes->trueType, negate(meet(builtinTypes->booleanType, builtinTypes->numberType))); +TEST_IS_SUBTYPE(rootClass, negate(meet(builtinTypes->classType, childClass))); +TEST_IS_SUBTYPE(childClass, negate(meet(builtinTypes->classType, builtinTypes->numberType))); +TEST_IS_SUBTYPE(builtinTypes->unknownType, negate(meet(builtinTypes->classType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(meet(builtinTypes->stringType, negate(str("bar"))))); + +// Negated supertypes: tables and metatables +TEST_IS_SUBTYPE(tbl({}), negate(builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(tbl({}), negate(builtinTypes->tableType)); +TEST_IS_SUBTYPE(meta({}), negate(builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(meta({}), negate(builtinTypes->tableType)); + +// Negated supertypes: Functions +TEST_IS_SUBTYPE(numberToNumberType, negate(builtinTypes->classType)); +TEST_IS_NOT_SUBTYPE(numberToNumberType, negate(builtinTypes->functionType)); + +// Negated supertypes: Primitives and singletons +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->stringType)); +TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->numberType)); +TEST_IS_SUBTYPE(str("foo"), meet(builtinTypes->stringType, negate(str("bar")))); +TEST_IS_NOT_SUBTYPE(builtinTypes->trueType, negate(builtinTypes->booleanType)); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(str("foo"))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(builtinTypes->stringType)); +TEST_IS_SUBTYPE(builtinTypes->falseType, negate(builtinTypes->trueType)); +TEST_IS_SUBTYPE(builtinTypes->falseType, meet(builtinTypes->booleanType, negate(builtinTypes->trueType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, meet(builtinTypes->booleanType, negate(builtinTypes->trueType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(str("foo"))); +TEST_IS_NOT_SUBTYPE(builtinTypes->booleanType, negate(builtinTypes->falseType)); + +// Negated supertypes: Classes +TEST_IS_SUBTYPE(rootClass, negate(builtinTypes->tableType)); +TEST_IS_NOT_SUBTYPE(rootClass, negate(builtinTypes->classType)); +TEST_IS_NOT_SUBTYPE(childClass, negate(rootClass)); +TEST_IS_NOT_SUBTYPE(childClass, meet(builtinTypes->classType, negate(rootClass))); +TEST_IS_SUBTYPE(anotherChildClass, meet(builtinTypes->classType, negate(childClass))); + +TEST_CASE_FIXTURE(SubtypeFixture, "Root <: class") +{ + CHECK_IS_SUBTYPE(rootClass, builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child | AnotherChild <: class") +{ + CHECK_IS_SUBTYPE(join(childClass, anotherChildClass), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child | AnotherChild <: Child | AnotherChild") +{ + CHECK_IS_SUBTYPE(join(childClass, anotherChildClass), join(childClass, anotherChildClass)); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child | Root <: Root") +{ + CHECK_IS_SUBTYPE(join(childClass, rootClass), rootClass); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & AnotherChild <: class") +{ + CHECK_IS_SUBTYPE(meet(childClass, anotherChildClass), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & Root <: class") +{ + CHECK_IS_SUBTYPE(meet(childClass, rootClass), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & ~Root <: class") +{ + CHECK_IS_SUBTYPE(meet(childClass, negate(rootClass)), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & AnotherChild <: number") +{ + CHECK_IS_SUBTYPE(meet(childClass, anotherChildClass), builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & ~GrandchildOne numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "semantic_subtyping_disj") +{ + TypeId subTy = builtinTypes->unknownType; + TypeId superTy = join(negate(builtinTypes->numberType), negate(builtinTypes->stringType)); + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(result.isSubtype); +} + + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} <: t2 where t2 = {trim: (t2) -> string}") +{ + TypeId t1 = cyclicTable( + [&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + } + ); + + TypeId t2 = cyclicTable( + [&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + } + ); + + CHECK_IS_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} t2}") +{ + TypeId t1 = cyclicTable( + [&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + } + ); + + TypeId t2 = cyclicTable( + [&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {ty}); + } + ); + + CHECK_IS_NOT_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> t1} string}") +{ + TypeId t1 = cyclicTable( + [&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {ty}); + } + ); + + TypeId t2 = cyclicTable( + [&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + } + ); + + CHECK_IS_NOT_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Vec2 <: { X: number, Y: number }") +{ + TypeId xy = tbl({ + {"X", builtinTypes->numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_SUBTYPE(vec2Class, xy); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Vec2 <: { X: number }") +{ + TypeId x = tbl({ + {"X", builtinTypes->numberType}, + }); + + CHECK_IS_SUBTYPE(vec2Class, x); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ X: number, Y: number } numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(xy, vec2Class); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ X: number } numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(x, vec2Class); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "table & { X: number, Y: number } numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(meet(builtinTypes->tableType, x), vec2Class); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Vec2 numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(vec2Class, meet(builtinTypes->tableType, xy)); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "ReadOnlyVec2 numberType}, {"Y", builtinTypes->numberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "ReadOnlyVec2 <: { read X: number, read Y: number}") +{ + CHECK_IS_SUBTYPE( + readOnlyVec2Class, tbl({{"X", Property::readonly(builtinTypes->numberType)}, {"Y", Property::readonly(builtinTypes->numberType)}}) + ); +} + +TEST_IS_SUBTYPE(vec2Class, tbl({{"X", Property::readonly(builtinTypes->numberType)}, {"Y", Property::readonly(builtinTypes->numberType)}})); + +TEST_IS_NOT_SUBTYPE(tbl({{"P", grandchildOneClass}}), tbl({{"P", Property::rw(rootClass)}})); +TEST_IS_SUBTYPE(tbl({{"P", grandchildOneClass}}), tbl({{"P", Property::readonly(rootClass)}})); +TEST_IS_SUBTYPE(tbl({{"P", rootClass}}), tbl({{"P", Property::writeonly(grandchildOneClass)}})); + +TEST_IS_NOT_SUBTYPE(cls("HasChild", {{"P", childClass}}), tbl({{"P", rootClass}})); +TEST_IS_SUBTYPE(cls("HasChild", {{"P", childClass}}), tbl({{"P", Property::readonly(rootClass)}})); +TEST_IS_NOT_SUBTYPE(cls("HasChild", {{"P", childClass}}), tbl({{"P", grandchildOneClass}})); +TEST_IS_SUBTYPE(cls("HasChild", {{"P", childClass}}), tbl({{"P", Property::writeonly(grandchildOneClass)}})); + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" <: { lower : (string) -> string }") +{ + CHECK_IS_SUBTYPE(helloType, tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" () }") +{ + CHECK_IS_NOT_SUBTYPE(helloType, tableWithoutScalarProp); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string <: { lower : (string) -> string }") +{ + CHECK_IS_SUBTYPE(builtinTypes->stringType, tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string () }") +{ + CHECK_IS_NOT_SUBTYPE(builtinTypes->stringType, tableWithoutScalarProp); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "~fun & (string) -> number <: (string) -> number") +{ + CHECK_IS_SUBTYPE(meet(negate(builtinTypes->functionType), numberToStringType), numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(string) -> number <: ~fun & (string) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, meet(negate(builtinTypes->functionType), numberToStringType)); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "~\"a\" & ~\"b\" & string <: { lower : (string) -> ()}") +{ + CHECK_IS_SUBTYPE(meet(meet(negate(aType), negate(bType)), builtinTypes->stringType), tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"a\" | (~\"b\" & string) <: { lower : (string) -> ()}") +{ + CHECK_IS_SUBTYPE(join(aType, meet(negate(bType), builtinTypes->stringType)), tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(string | number) & (\"a\" | true) <: { lower: (string) -> string }") +{ + auto base = meet(join(builtinTypes->stringType, builtinTypes->numberType), join(aType, trueSingleton)); + CHECK_IS_SUBTYPE(base, tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number <: ~~number") +{ + CHECK_IS_SUBTYPE(builtinTypes->numberType, negate(negate(builtinTypes->numberType))); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "~~number <: number") +{ + CHECK_IS_SUBTYPE(negate(negate(builtinTypes->numberType)), builtinTypes->numberType); +} + +// See https://github.com/luau-lang/luau/issues/767 +TEST_CASE_FIXTURE(SubtypeFixture, "(...any) -> () <: (T...) -> ()") +{ + TypeId anysToNothing = arena.addType(FunctionType{builtinTypes->anyTypePack, builtinTypes->emptyTypePack}); + TypeId genericTToAnys = arena.addType(FunctionType{genericAs, builtinTypes->emptyTypePack}); + + CHECK_MESSAGE(isSubtype(anysToNothing, genericTToAnys).isSubtype, "(...any) -> () <: (T...) -> ()"); +} + +// See https://github.com/luau-lang/luau/issues/767 +TEST_CASE_FIXTURE(SubtypeFixture, "(...unknown) -> () <: (T...) -> ()") +{ + TypeId unknownsToNothing = + arena.addType(FunctionType{arena.addTypePack(VariadicTypePack{builtinTypes->unknownType}), builtinTypes->emptyTypePack}); + TypeId genericTToAnys = arena.addType(FunctionType{genericAs, builtinTypes->emptyTypePack}); + + CHECK_MESSAGE(isSubtype(unknownsToNothing, genericTToAnys).isSubtype, "(...unknown) -> () <: (T...) -> ()"); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "bill") +{ + TypeId a = arena.addType(TableType{ + {{"a", builtinTypes->stringType}}, TableIndexer{builtinTypes->stringType, builtinTypes->numberType}, TypeLevel{}, nullptr, TableState::Sealed + }); + + TypeId b = arena.addType(TableType{ + {{"a", builtinTypes->stringType}}, TableIndexer{builtinTypes->stringType, builtinTypes->numberType}, TypeLevel{}, nullptr, TableState::Sealed + }); + + CHECK(isSubtype(a, b).isSubtype); + CHECK(isSubtype(b, a).isSubtype); +} + +// TEST_CASE_FIXTURE(SubtypeFixture, "({[string]: number, a: string}) -> () <: ({[string]: number, a: string}) -> ()") +TEST_CASE_FIXTURE(SubtypeFixture, "fred") +{ + auto makeTheType = [&]() + { + TypeId argType = arena.addType(TableType{ + {{"a", builtinTypes->stringType}}, + TableIndexer{builtinTypes->stringType, builtinTypes->numberType}, + TypeLevel{}, + nullptr, + TableState::Sealed + }); + + return arena.addType(FunctionType{arena.addTypePack({argType}), builtinTypes->emptyTypePack}); + }; + + TypeId a = makeTheType(); + TypeId b = makeTheType(); + + CHECK_MESSAGE(isSubtype(a, b).isSubtype, "({[string]: number, a: string}) -> () <: ({[string]: number, a: string}) -> ()"); +} + +/* + * Within the scope to which a generic belongs, that generic ought to be treated + * as its bounds. + * + * We do not yet support bounded generics, so all generics are considered to be + * bounded by unknown. + */ +TEST_CASE_FIXTURE(SubtypeFixture, "unknown <: X") +{ + ScopePtr childScope{new Scope(rootScope)}; + ScopePtr grandChildScope{new Scope(childScope)}; + + TypeId genericX = arena.addType(GenericType(childScope.get(), "X")); + + SubtypingResult usingGlobalScope = isSubtype(builtinTypes->unknownType, genericX); + CHECK_MESSAGE(!usingGlobalScope.isSubtype, "Expected " << builtinTypes->unknownType << " unknownType, genericX, NotNull{childScope.get()}); + CHECK_MESSAGE(usingChildScope.isSubtype, "Expected " << builtinTypes->unknownType << " <: " << genericX); + + Subtyping grandChildSubtyping{mkSubtyping()}; + + SubtypingResult usingGrandChildScope = grandChildSubtyping.isSubtype(builtinTypes->unknownType, genericX, NotNull{grandChildScope.get()}); + CHECK_MESSAGE(usingGrandChildScope.isSubtype, "Expected " << builtinTypes->unknownType << " <: " << genericX); +} + +TEST_IS_SUBTYPE(idx(builtinTypes->numberType, builtinTypes->numberType), tbl({})); +TEST_IS_NOT_SUBTYPE(tbl({}), idx(builtinTypes->numberType, builtinTypes->numberType)); + +TEST_IS_NOT_SUBTYPE(tbl({{"X", builtinTypes->numberType}}), idx(builtinTypes->numberType, builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(idx(builtinTypes->numberType, builtinTypes->numberType), tbl({{"X", builtinTypes->numberType}})); + +TEST_IS_NOT_SUBTYPE( + idx(join(builtinTypes->numberType, builtinTypes->stringType), builtinTypes->numberType), + idx(builtinTypes->numberType, builtinTypes->numberType) +); +TEST_IS_NOT_SUBTYPE( + idx(builtinTypes->numberType, builtinTypes->numberType), + idx(join(builtinTypes->numberType, builtinTypes->stringType), builtinTypes->numberType) +); + +TEST_IS_NOT_SUBTYPE( + idx(builtinTypes->numberType, join(builtinTypes->stringType, builtinTypes->numberType)), + idx(builtinTypes->numberType, builtinTypes->numberType) +); +TEST_IS_NOT_SUBTYPE( + idx(builtinTypes->numberType, builtinTypes->numberType), + idx(builtinTypes->numberType, join(builtinTypes->stringType, builtinTypes->numberType)) +); + +TEST_IS_NOT_SUBTYPE(tbl({{"X", builtinTypes->numberType}}), idx(builtinTypes->stringType, builtinTypes->numberType)); +TEST_IS_SUBTYPE(idx(builtinTypes->stringType, builtinTypes->numberType), tbl({{"X", builtinTypes->numberType}})); + +TEST_IS_NOT_SUBTYPE(tbl({{"X", opt(builtinTypes->numberType)}}), idx(builtinTypes->stringType, builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(idx(builtinTypes->stringType, builtinTypes->numberType), tbl({{"X", opt(builtinTypes->numberType)}})); + +TEST_IS_SUBTYPE(tbl({{"X", builtinTypes->numberType}, {"Y", builtinTypes->numberType}}), tbl({{"X", builtinTypes->numberType}})); +TEST_IS_NOT_SUBTYPE(tbl({{"X", builtinTypes->numberType}}), tbl({{"X", builtinTypes->numberType}, {"Y", builtinTypes->numberType}})); + +TEST_CASE_FIXTURE(SubtypeFixture, "interior_tests_are_cached") +{ + TypeId tableA = tbl({{"X", builtinTypes->numberType}, {"Y", builtinTypes->numberType}}); + TypeId tableB = tbl({{"X", builtinTypes->optionalNumberType}, {"Y", builtinTypes->optionalNumberType}}); + + CHECK_IS_NOT_SUBTYPE(tableA, tableB); + + const SubtypingResult* cachedResult = subtyping.peekCache().find({builtinTypes->numberType, builtinTypes->optionalNumberType}); + REQUIRE(cachedResult); + + CHECK(cachedResult->isSubtype); + + cachedResult = subtyping.peekCache().find({tableA, tableB}); + REQUIRE(cachedResult); + + CHECK(!cachedResult->isSubtype); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "results_that_are_contingent_on_generics_are_not_cached") +{ + // (T) -> T <: (number) -> number + CHECK_IS_SUBTYPE(genericTToTType, numberToNumberType); + + CHECK(subtyping.peekCache().empty()); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "dont_cache_tests_involving_cycles") +{ + TypeId tableA = arena.addType(BlockedType{}); + TypeId tableA2 = tbl({{"self", tableA}}); + asMutable(tableA)->ty.emplace(tableA2); + + TypeId tableB = arena.addType(BlockedType{}); + TypeId tableB2 = tbl({{"self", tableB}}); + asMutable(tableB)->ty.emplace(tableB2); + + CHECK_IS_SUBTYPE(tableA, tableB); + + CHECK(!subtyping.peekCache().find({tableA, tableB})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "({ x: T }) -> T <: ({ method: ({ x: T }) -> T, x: number }) -> number") +{ + // ({ x: T }) -> T + TypeId tableToPropType = arena.addType(FunctionType{{genericT}, {}, arena.addTypePack({tbl({{"x", genericT}})}), arena.addTypePack({genericT})}); + + // ({ method: ({ x: T }) -> T, x: number }) -> number + TypeId otherType = fn({tbl({{"method", tableToPropType}, {"x", builtinTypes->numberType}})}, {builtinTypes->numberType}); + + CHECK_IS_SUBTYPE(tableToPropType, otherType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "subtyping_reasonings_to_follow_a_reduced_type_function_instance") +{ + TypeId longTy = arena.addType(UnionType{ + {builtinTypes->booleanType, + builtinTypes->bufferType, + builtinTypes->classType, + builtinTypes->functionType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->tableType, + builtinTypes->threadType} + }); + TypeId tblTy = tbl({{"depth", builtinTypes->unknownType}}); + TypeId combined = meet(longTy, tblTy); + TypeId subTy = arena.addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions.unionFunc}, {combined, builtinTypes->neverType}, {}}); + TypeId superTy = builtinTypes->neverType; + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + + for (const SubtypingReasoning& reasoning : result.reasoning) + { + if (reasoning.subPath.empty() && reasoning.superPath.empty()) + continue; + + std::optional optSubLeaf = traverse(subTy, reasoning.subPath, builtinTypes); + std::optional optSuperLeaf = traverse(superTy, reasoning.superPath, builtinTypes); + + if (!optSubLeaf || !optSuperLeaf) + CHECK(false); + } +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Subtyping.Subpaths"); + +TEST_CASE_FIXTURE(SubtypeFixture, "table_property") +{ + TypeId subTy = tbl({{"X", builtinTypes->numberType}}); + TypeId superTy = tbl({{"X", builtinTypes->booleanType}}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + REQUIRE(result.reasoning.size() == 1); + CHECK( + *result.reasoning.begin() == SubtypingReasoning{ + /* subPath */ Path(TypePath::Property::read("X")), + /* superPath */ Path(TypePath::Property::read("X")), + /* variance */ SubtypingVariance::Invariant + } + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "table_indexers") +{ + TypeId subTy = idx(builtinTypes->numberType, builtinTypes->stringType); + TypeId superTy = idx(builtinTypes->stringType, builtinTypes->numberType); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + CHECK( + result.reasoning == + std::vector{ + SubtypingReasoning{ + /* subPath */ Path(TypePath::TypeField::IndexLookup), + /* superPath */ Path(TypePath::TypeField::IndexLookup), + /* variance */ SubtypingVariance::Invariant, + }, + SubtypingReasoning{ + /* subPath */ Path(TypePath::TypeField::IndexResult), + /* superPath */ Path(TypePath::TypeField::IndexResult), + /* variance */ SubtypingVariance::Invariant, + } + } + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "fn_arguments") +{ + TypeId subTy = fn({builtinTypes->numberType}, {}); + TypeId superTy = fn({builtinTypes->stringType}, {}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + CHECK( + result.reasoning == std::vector{SubtypingReasoning{ + /* subPath */ TypePath::PathBuilder().args().index(0).build(), + /* superPath */ TypePath::PathBuilder().args().index(0).build(), + /* variance */ SubtypingVariance::Contravariant, + }} + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "arity_mismatch") +{ + TypeId subTy = fn({builtinTypes->numberType}, {}); + TypeId superTy = fn({}, {}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + CHECK( + result.reasoning == std::vector{SubtypingReasoning{ + /* subPath */ TypePath::PathBuilder().args().build(), + /* superPath */ TypePath::PathBuilder().args().build(), + /* variance */ SubtypingVariance::Contravariant, + }} + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "fn_arguments_tail") +{ + TypeId subTy = fn({}, VariadicTypePack{builtinTypes->numberType}, {}); + TypeId superTy = fn({}, VariadicTypePack{builtinTypes->stringType}, {}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + CHECK( + result.reasoning == std::vector{SubtypingReasoning{ + /* subPath */ TypePath::PathBuilder().args().tail().variadic().build(), + /* superPath */ TypePath::PathBuilder().args().tail().variadic().build(), + /* variance */ SubtypingVariance::Contravariant, + }} + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "fn_rets") +{ + TypeId subTy = fn({}, {builtinTypes->numberType}); + TypeId superTy = fn({}, {builtinTypes->stringType}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + REQUIRE(result.reasoning.size() == 1); + CHECK( + *result.reasoning.begin() == SubtypingReasoning{ + /* subPath */ TypePath::PathBuilder().rets().index(0).build(), + /* superPath */ TypePath::PathBuilder().rets().index(0).build(), + } + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "fn_rets_tail") +{ + TypeId subTy = fn({}, {}, VariadicTypePack{builtinTypes->numberType}); + TypeId superTy = fn({}, {}, VariadicTypePack{builtinTypes->stringType}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + REQUIRE(result.reasoning.size() == 1); + CHECK( + *result.reasoning.begin() == SubtypingReasoning{ + /* subPath */ TypePath::PathBuilder().rets().tail().variadic().build(), + /* superPath */ TypePath::PathBuilder().rets().tail().variadic().build(), + } + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "nested_table_properties") +{ + TypeId subTy = tbl({{"X", tbl({{"Y", tbl({{"Z", builtinTypes->numberType}})}})}}); + TypeId superTy = tbl({{"X", tbl({{"Y", tbl({{"Z", builtinTypes->stringType}})}})}}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + REQUIRE(result.reasoning.size() == 1); + CHECK( + *result.reasoning.begin() == SubtypingReasoning{ + /* subPath */ TypePath::PathBuilder().readProp("X").readProp("Y").readProp("Z").build(), + /* superPath */ TypePath::PathBuilder().readProp("X").readProp("Y").readProp("Z").build(), + /* variance */ SubtypingVariance::Invariant, + } + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string_table_mt") +{ + TypeId subTy = builtinTypes->stringType; + TypeId superTy = tbl({{"X", builtinTypes->numberType}}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + // This check is weird. Because we don't have built-in types, we don't have + // the string metatable. That means subtyping will see that the entire + // metatable is empty, and abort there, without looking at the metatable + // properties (because there aren't any). + CHECK( + result.reasoning == std::vector{SubtypingReasoning{ + /* subPath */ TypePath::PathBuilder().mt().readProp("__index").build(), + /* superPath */ TypePath::kEmpty, + }} + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "negation") +{ + TypeId subTy = builtinTypes->numberType; + TypeId superTy = negate(builtinTypes->numberType); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + CHECK( + result.reasoning == std::vector{SubtypingReasoning{ + /* subPath */ TypePath::kEmpty, + /* superPath */ Path(TypePath::TypeField::Negated), + }} + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "multiple_reasonings") +{ + TypeId subTy = tbl({{"X", builtinTypes->stringType}, {"Y", builtinTypes->numberType}}); + TypeId superTy = tbl({{"X", builtinTypes->numberType}, {"Y", builtinTypes->stringType}}); + + SubtypingResult result = isSubtype(subTy, superTy); + CHECK(!result.isSubtype); + CHECK( + result.reasoning == + std::vector{ + SubtypingReasoning{ + /* subPath */ Path(TypePath::Property::read("X")), + /* superPath */ Path(TypePath::Property::read("X")), + /* variance */ SubtypingVariance::Invariant + }, + SubtypingReasoning{ + /* subPath */ Path(TypePath::Property::read("Y")), + /* superPath */ Path(TypePath::Property::read("Y")), + /* variance */ SubtypingVariance::Invariant + }, + } + ); +} + +TEST_SUITE_END(); diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp index 278c6ce2b..83482c03f 100644 --- a/tests/Symbol.test.cpp +++ b/tests/Symbol.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2) + TEST_SUITE_BEGIN("SymbolTests"); TEST_CASE("equality_and_hashing_of_globals") @@ -66,7 +68,7 @@ TEST_CASE("equality_and_hashing_of_locals") TEST_CASE("equality_of_empty_symbols") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; std::string s1 = "name"; std::string s2 = "name"; diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 11dca1106..fd72579b2 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -9,31 +9,31 @@ using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauSolverV2); struct ToDotClassFixture : Fixture { ToDotClassFixture() { - TypeArena& arena = typeChecker.globalTypes; + TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); TypeId baseClassMetaType = arena.addType(TableType{}); - TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); + TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test", {}}); getMutable(baseClassInstanceType)->props = { - {"BaseField", {typeChecker.numberType}}, + {"BaseField", {builtinTypes->numberType}}, }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + frontend.globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); + TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test", {}}); getMutable(childClassInstanceType)->props = { - {"ChildField", {typeChecker.stringType}}, + {"ChildField", {builtinTypes->stringType}}, }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + frontend.globals.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, ty] : frontend.globals.globalScope->exportedTypeBindings) persist(ty.type); freeze(arena); @@ -44,38 +44,75 @@ TEST_SUITE_BEGIN("ToDot"); TEST_CASE_FIXTURE(Fixture, "primitive") { - CheckResult result = check(R"( -local a: nil -local b: number -local c: any -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_NE("nil", toDot(requireType("a"))); + CHECK_EQ( + R"(digraph graphname { +n1 [label="nil"]; +})", + toDot(builtinTypes->nilType) + ); - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="number"]; })", - toDot(requireType("b"))); + toDot(builtinTypes->numberType) + ); - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="any"]; })", - toDot(requireType("c"))); + toDot(builtinTypes->anyType) + ); + + CHECK_EQ( + R"(digraph graphname { +n1 [label="unknown"]; +})", + toDot(builtinTypes->unknownType) + ); + CHECK_EQ( + R"(digraph graphname { +n1 [label="never"]; +})", + toDot(builtinTypes->neverType) + ); +} + +TEST_CASE_FIXTURE(Fixture, "no_duplicatePrimitives") +{ ToDotOptions opts; opts.showPointers = false; opts.duplicatePrimitives = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="PrimitiveType number"]; })", - toDot(requireType("b"), opts)); + toDot(builtinTypes->numberType, opts) + ); - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="AnyType 1"]; })", - toDot(requireType("c"), opts)); + toDot(builtinTypes->anyType, opts) + ); + + CHECK_EQ( + R"(digraph graphname { +n1 [label="UnknownType 1"]; +})", + toDot(builtinTypes->unknownType, opts) + ); + + CHECK_EQ( + R"(digraph graphname { +n1 [label="NeverType 1"]; +})", + toDot(builtinTypes->neverType, opts) + ); } TEST_CASE_FIXTURE(Fixture, "bound") @@ -86,12 +123,14 @@ TEST_CASE_FIXTURE(Fixture, "bound") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="BoundType 1"]; n1 -> n2; n2 [label="number"]; })", - toDot(ty, opts)); + toDot(ty, opts) + ); } TEST_CASE_FIXTURE(Fixture, "function") @@ -106,9 +145,10 @@ local function f(a, ...: string) return a end ToDotOptions opts; opts.showPointers = false; - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="FunctionType 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -119,14 +159,18 @@ n4 [label="VariadicTypePack 4"]; n4 -> n5; n5 [label="string"]; n1 -> n6 [label="ret"]; -n6 [label="TypePack 6"]; -n6 -> n3; +n6 [label="BoundTypePack 6"]; +n6 -> n7; +n7 [label="TypePack 7"]; +n7 -> n3; })", - toDot(requireType("f"), opts)); + toDot(requireType("f"), opts) + ); } else { - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="FunctionType 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -142,7 +186,8 @@ n6 -> n7; n7 [label="TypePack 7"]; n7 -> n3; })", - toDot(requireType("f"), opts)); + toDot(requireType("f"), opts) + ); } } @@ -155,14 +200,16 @@ local a: string | number ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="UnionType 1"]; n1 -> n2; n2 [label="string"]; n1 -> n3; n3 [label="number"]; })", - toDot(requireType("a"), opts)); + toDot(requireType("a"), opts) + ); } TEST_CASE_FIXTURE(Fixture, "intersection") @@ -173,14 +220,16 @@ TEST_CASE_FIXTURE(Fixture, "intersection") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="IntersectionType 1"]; n1 -> n2; n2 [label="string"]; n1 -> n3; n3 [label="number"]; })", - toDot(ty, opts)); + toDot(ty, opts) + ); } TEST_CASE_FIXTURE(Fixture, "table") @@ -193,35 +242,36 @@ local a: A ToDotOptions opts; opts.showPointers = false; - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="TableType A"]; n1 -> n2 [label="x"]; n2 [label="number"]; n1 -> n3 [label="y"]; n3 [label="FunctionType 3"]; n3 -> n4 [label="arg"]; -n4 [label="TypePack 4"]; -n4 -> n5 [label="tail"]; -n5 [label="VariadicTypePack 5"]; -n5 -> n6; -n6 [label="string"]; -n3 -> n7 [label="ret"]; -n7 [label="TypePack 7"]; -n1 -> n8 [label="[index]"]; -n8 [label="string"]; -n1 -> n9 [label="[value]"]; -n9 [label="any"]; -n1 -> n10 [label="typeParam"]; -n10 [label="number"]; -n1 -> n5 [label="typePackParam"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n3 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n1 -> n7 [label="[index]"]; +n7 [label="string"]; +n1 -> n8 [label="[value]"]; +n8 [label="any"]; +n1 -> n9 [label="typeParam"]; +n9 [label="number"]; +n1 -> n4 [label="typePackParam"]; })", - toDot(requireType("a"), opts)); + toDot(requireType("a"), opts) + ); } else { - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="TableType A"]; n1 -> n2 [label="x"]; n2 [label="number"]; @@ -241,7 +291,8 @@ n1 -> n9 [label="typeParam"]; n9 [label="number"]; n1 -> n4 [label="typePackParam"]; })", - toDot(requireType("a"), opts)); + toDot(requireType("a"), opts) + ); } // Extra coverage with pointers (unstable values) @@ -257,26 +308,60 @@ local a: typeof(setmetatable({}, {})) ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="MetatableType 1"]; n1 -> n2 [label="table"]; n2 [label="TableType 2"]; n1 -> n3 [label="metatable"]; n3 [label="TableType 3"]; })", - toDot(requireType("a"), opts)); + toDot(requireType("a"), opts) + ); } TEST_CASE_FIXTURE(Fixture, "free") { + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + Type type{TypeVariant{FreeType{TypeLevel{0, 0}}}}; ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="FreeType 1"]; })", - toDot(&type, opts)); + toDot(&type, opts) + ); +} + +TEST_CASE_FIXTURE(Fixture, "free_with_constraints") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + }; + + Type type{TypeVariant{FreeType{nullptr, builtinTypes->numberType, builtinTypes->optionalNumberType}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ( + R"(digraph graphname { +n1 [label="FreeType 1"]; +n1 -> n2 [label="[lowerBound]"]; +n2 [label="number"]; +n1 -> n3 [label="[upperBound]"]; +n3 [label="UnionType 3"]; +n3 -> n4; +n4 [label="number"]; +n3 -> n5; +n5 [label="nil"]; +})", + toDot(&type, opts) + ); } TEST_CASE_FIXTURE(Fixture, "error") @@ -285,10 +370,12 @@ TEST_CASE_FIXTURE(Fixture, "error") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="ErrorType 1"]; })", - toDot(&type, opts)); + toDot(&type, opts) + ); } TEST_CASE_FIXTURE(Fixture, "generic") @@ -297,10 +384,12 @@ TEST_CASE_FIXTURE(Fixture, "generic") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="GenericType T"]; })", - toDot(&type, opts)); + toDot(&type, opts) + ); } TEST_CASE_FIXTURE(ToDotClassFixture, "class") @@ -312,7 +401,8 @@ local a: ChildClass ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="ClassType ChildClass"]; n1 -> n2 [label="ChildField"]; n2 [label="string"]; @@ -323,7 +413,8 @@ n4 [label="number"]; n3 -> n5 [label="[metatable]"]; n5 [label="TableType 5"]; })", - toDot(requireType("a"), opts)); + toDot(requireType("a"), opts) + ); } TEST_CASE_FIXTURE(Fixture, "free_pack") @@ -332,10 +423,12 @@ TEST_CASE_FIXTURE(Fixture, "free_pack") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="FreeTypePack 1"]; })", - toDot(&pack, opts)); + toDot(&pack, opts) + ); } TEST_CASE_FIXTURE(Fixture, "error_pack") @@ -344,10 +437,12 @@ TEST_CASE_FIXTURE(Fixture, "error_pack") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="ErrorTypePack 1"]; })", - toDot(&pack, opts)); + toDot(&pack, opts) + ); // Extra coverage with pointers (unstable values) (void)toDot(&pack); @@ -360,32 +455,38 @@ TEST_CASE_FIXTURE(Fixture, "generic_pack") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="GenericTypePack 1"]; })", - toDot(&pack1, opts)); + toDot(&pack1, opts) + ); - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="GenericTypePack T"]; })", - toDot(&pack2, opts)); + toDot(&pack2, opts) + ); } TEST_CASE_FIXTURE(Fixture, "bound_pack") { - TypePackVar pack{TypePackVariant{TypePack{{typeChecker.numberType}, {}}}}; + TypePackVar pack{TypePackVariant{TypePack{{builtinTypes->numberType}, {}}}}; TypePackVar bound{TypePackVariant{BoundTypePack{&pack}}}; ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="BoundTypePack 1"]; n1 -> n2; n2 [label="TypePack 2"]; n2 -> n3; n3 [label="number"]; })", - toDot(&bound, opts)); + toDot(&bound, opts) + ); } TEST_CASE_FIXTURE(Fixture, "bound_table") @@ -401,14 +502,16 @@ TEST_CASE_FIXTURE(Fixture, "bound_table") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="TableType 1"]; n1 -> n2 [label="boundTo"]; n2 [label="TableType 2"]; n2 -> n3 [label="x"]; n3 [label="number"]; })", - toDot(boundTy, opts)); + toDot(boundTy, opts) + ); } TEST_CASE_FIXTURE(Fixture, "builtintypes") @@ -420,20 +523,37 @@ TEST_CASE_FIXTURE(Fixture, "builtintypes") ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + CHECK_EQ( + R"(digraph graphname { n1 [label="UnionType 1"]; n1 -> n2; n2 [label="SingletonType string: hi"]; n1 -> n3; )" - "n3 [label=\"SingletonType string: \\\"hello\\\"\"];" - R"( + "n3 [label=\"SingletonType string: \\\"hello\\\"\"];" + R"( n1 -> n4; n4 [label="SingletonType boolean: true"]; n1 -> n5; n5 [label="SingletonType boolean: false"]; })", - toDot(requireType("x"), opts)); + toDot(requireType("x"), opts) + ); +} + +TEST_CASE_FIXTURE(Fixture, "negation") +{ + TypeArena arena; + TypeId t = arena.addType(NegationType{builtinTypes->stringType}); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK(R"(digraph graphname { +n1 [label="NegationType 1"]; +n1 -> n2 [label="[negated]"]; +n2 [label="string"]; +})" == toDot(t, opts)); } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 2fc5187b8..fe87b6a70 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -11,6 +11,9 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauAttributeSyntax); +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2) TEST_SUITE_BEGIN("ToString"); @@ -19,8 +22,13 @@ TEST_CASE_FIXTURE(Fixture, "primitive") CheckResult result = check("local a = nil local b = 44 local c = 'lalala' local d = true"); LUAU_REQUIRE_NO_ERRORS(result); - // A variable without an annotation and with a nil literal should infer as 'free', not 'nil' - CHECK_NE("nil", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK("nil" == toString(requireType("a"))); + else + { + // A variable without an annotation and with a nil literal should infer as 'free', not 'nil' + CHECK_NE("nil", toString(requireType("a"))); + } CHECK_EQ("number", toString(requireType("b"))); CHECK_EQ("string", toString(requireType("c"))); @@ -37,6 +45,8 @@ TEST_CASE_FIXTURE(Fixture, "bound_types") TEST_CASE_FIXTURE(Fixture, "free_types") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check("local a"); LUAU_REQUIRE_NO_ERRORS(result); @@ -49,7 +59,10 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table") TableType* tableOne = getMutable(&cyclicTable); tableOne->props["self"] = {&cyclicTable}; - CHECK_EQ("t1 where t1 = { self: t1 }", toString(&cyclicTable)); + if (FFlag::LuauSolverV2) + CHECK_EQ("t1 where t1 = {| self: t1 |}", toString(&cyclicTable)); + else + CHECK_EQ("t1 where t1 = { self: t1 }", toString(&cyclicTable)); } TEST_CASE_FIXTURE(Fixture, "named_table") @@ -67,12 +80,18 @@ TEST_CASE_FIXTURE(Fixture, "empty_table") local a: {} )"); - CHECK_EQ("{| |}", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ }", toString(requireType("a"))); + else + CHECK_EQ("{| |}", toString(requireType("a"))); // Should stay the same with useLineBreaks enabled ToStringOptions opts; opts.useLineBreaks = true; - CHECK_EQ("{| |}", toString(requireType("a"), opts)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ }", toString(requireType("a"), opts)); + else + CHECK_EQ("{| |}", toString(requireType("a"), opts)); } TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") @@ -84,14 +103,24 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") ToStringOptions opts; opts.useLineBreaks = true; - //clang-format off - CHECK_EQ("{|\n" - " anotherProp: number,\n" - " prop: string,\n" - " thirdProp: boolean\n" - "|}", - toString(requireType("a"), opts)); - //clang-format on + if (FFlag::LuauSolverV2) + CHECK_EQ( + "{\n" + " anotherProp: number,\n" + " prop: string,\n" + " thirdProp: boolean\n" + "}", + toString(requireType("a"), opts) + ); + else + CHECK_EQ( + "{|\n" + " anotherProp: number,\n" + " prop: string,\n" + " thirdProp: boolean\n" + "|}", + toString(requireType("a"), opts) + ); } TEST_CASE_FIXTURE(Fixture, "nil_or_nil_is_nil_not_question_mark") @@ -121,7 +150,10 @@ TEST_CASE_FIXTURE(Fixture, "metatable") Type table{TypeVariant(TableType())}; Type metatable{TypeVariant(TableType())}; Type mtv{TypeVariant(MetatableType{&table, &metatable})}; - CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ @metatable {| |}, {| |} }", toString(&mtv)); + else + CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); } TEST_CASE_FIXTURE(Fixture, "named_metatable") @@ -134,6 +166,8 @@ TEST_CASE_FIXTURE(Fixture, "named_metatable") TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function createTbl(): NamedMetatable return setmetatable({}, {}) @@ -173,38 +207,50 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") CHECK_EQ(std::string::npos, a.find("CYCLE")); CHECK_EQ(std::string::npos, a.find("TRUNCATED")); - //clang-format off - CHECK_EQ("t2 where " - "t1 = { __index: t1, __mul: ((t2, number) -> t2) & ((t2, t2) -> t2), new: () -> t2 } ; " - "t2 = { @metatable t1, {| x: number, y: number, z: number |} }", - a); - //clang-format on + if (FFlag::LuauSolverV2) + { + CHECK( + "t2 where " + "t1 = { __index: t1, __mul: ((t2, number) -> t2) & ((t2, t2) -> t2), new: () -> t2 } ; " + "t2 = { @metatable t1, { x: number, y: number, z: number } }" == + a + ); + } + else + { + CHECK_EQ( + "t2 where " + "t1 = { __index: t1, __mul: ((t2, number) -> t2) & ((t2, t2) -> t2), new: () -> t2 } ; " + "t2 = { @metatable t1, {| x: number, y: number, z: number |} }", + a + ); + } } TEST_CASE_FIXTURE(Fixture, "intersection_parenthesized_only_if_needed") { - auto utv = Type{UnionType{{typeChecker.numberType, typeChecker.stringType}}}; - auto itv = Type{IntersectionType{{&utv, typeChecker.booleanType}}}; + auto utv = Type{UnionType{{builtinTypes->numberType, builtinTypes->stringType}}}; + auto itv = Type{IntersectionType{{&utv, builtinTypes->booleanType}}}; CHECK_EQ(toString(&itv), "(number | string) & boolean"); } TEST_CASE_FIXTURE(Fixture, "union_parenthesized_only_if_needed") { - auto itv = Type{IntersectionType{{typeChecker.numberType, typeChecker.stringType}}}; - auto utv = Type{UnionType{{&itv, typeChecker.booleanType}}}; + auto itv = Type{IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}}; + auto utv = Type{UnionType{{&itv, builtinTypes->booleanType}}}; CHECK_EQ(toString(&utv), "(number & string) | boolean"); } TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_intersections") { - auto stringAndNumberPack = TypePackVar{TypePack{{typeChecker.stringType, typeChecker.numberType}}}; - auto numberAndStringPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.stringType}}}; + auto stringAndNumberPack = TypePackVar{TypePack{{builtinTypes->stringType, builtinTypes->numberType}}}; + auto numberAndStringPack = TypePackVar{TypePack{{builtinTypes->numberType, builtinTypes->stringType}}}; auto sn2ns = Type{FunctionType{&stringAndNumberPack, &numberAndStringPack}}; - auto ns2sn = Type{FunctionType(typeChecker.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; + auto ns2sn = Type{FunctionType(frontend.globals.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; auto utv = Type{UnionType{{&ns2sn, &sn2ns}}}; auto itv = Type{IntersectionType{{&ns2sn, &sn2ns}}}; @@ -213,7 +259,37 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_inte CHECK_EQ(toString(&itv), "((number, string) -> (string, number)) & ((string, number) -> (number, string))"); } -TEST_CASE_FIXTURE(Fixture, "intersections_respects_use_line_breaks") +TEST_CASE_FIXTURE(Fixture, "simple_intersections_printed_on_one_line") +{ + CheckResult result = check(R"( + local a: string & number + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + CHECK_EQ("number & string", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "complex_intersections_printed_on_multiple_lines") +{ + CheckResult result = check(R"( + local a: string & number & boolean + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + opts.compositeTypesSingleLineLimit = 2; + + CHECK_EQ( + "boolean\n" + "& number\n" + "& string", + toString(requireType("a"), opts) + ); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_always_printed_on_multiple_lines") { CheckResult result = check(R"( local a: ((string) -> string) & ((number) -> number) @@ -222,56 +298,75 @@ TEST_CASE_FIXTURE(Fixture, "intersections_respects_use_line_breaks") ToStringOptions opts; opts.useLineBreaks = true; - //clang-format off - CHECK_EQ("((number) -> number)\n" - "& ((string) -> string)", - toString(requireType("a"), opts)); - //clang-format on + CHECK_EQ( + "((number) -> number)\n" + "& ((string) -> string)", + toString(requireType("a"), opts) + ); } -TEST_CASE_FIXTURE(Fixture, "unions_respects_use_line_breaks") +TEST_CASE_FIXTURE(Fixture, "simple_unions_printed_on_one_line") +{ + CheckResult result = check(R"( + local a: number | boolean + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + CHECK_EQ("boolean | number", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "complex_unions_printed_on_multiple_lines") { CheckResult result = check(R"( local a: string | number | boolean )"); ToStringOptions opts; + opts.compositeTypesSingleLineLimit = 2; opts.useLineBreaks = true; - //clang-format off - CHECK_EQ("boolean\n" - "| number\n" - "| string", - toString(requireType("a"), opts)); - //clang-format on + CHECK_EQ( + "boolean\n" + "| number\n" + "| string", + toString(requireType("a"), opts) + ); } TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded") { TableType ttv{}; for (char c : std::string("abcdefghijklmno")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; ToStringOptions o; o.exhaustive = false; o.maxTableLength = 40; - CHECK_EQ(toString(&tv, o), "{ a: number, b: number, c: number, d: number, e: number, ... 10 more ... }"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 10 more ... |}"); + else + CHECK_EQ(toString(&tv, o), "{ a: number, b: number, c: number, d: number, e: number, ... 10 more ... }"); } TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_is_still_capped_when_exhaustive") { TableType ttv{}; for (char c : std::string("abcdefg")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; ToStringOptions o; o.exhaustive = true; o.maxTableLength = 40; - CHECK_EQ(toString(&tv, o), "{ a: number, b: number, c: number, d: number, e: number, ... 2 more ... }"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 2 more ... |}"); + else + CHECK_EQ(toString(&tv, o), "{ a: number, b: number, c: number, d: number, e: number, ... 2 more ... }"); } TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") @@ -282,21 +377,24 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") function f2(f) return f or f1 end function f3(f) return f or f2 end )"); - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions o; - o.exhaustive = false; - - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - o.maxTypeLength = 30; + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = false; + o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... *TRUNCATED*"); } else { + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = false; o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); @@ -313,20 +411,25 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") function f2(f) return f or f1 end function f3(f) return f or f2 end )"); - LUAU_REQUIRE_NO_ERRORS(result); - ToStringOptions o; - o.exhaustive = true; - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - o.maxTypeLength = 30; + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; + o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... *TRUNCATED*"); } else { + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); @@ -339,18 +442,21 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table { TableType ttv{TableState::Sealed, TypeLevel{}}; for (char c : std::string("abcdefghij")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; ToStringOptions o; o.maxTableLength = 40; - CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(&tv, o), "{ a: number, b: number, c: number, d: number, e: number, ... 5 more ... }"); + else + CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); } TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_union_type_bails_early") { - Type tv{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; + Type tv{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; UnionType* utv = getMutable(&tv); utv->options.push_back(&tv); utv->options.push_back(&tv); @@ -371,12 +477,15 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_intersection_type_bails_early") TEST_CASE_FIXTURE(Fixture, "stringifying_array_uses_array_syntax") { TableType ttv{TableState::Sealed, TypeLevel{}}; - ttv.indexer = TableIndexer{typeChecker.numberType, typeChecker.stringType}; + ttv.indexer = TableIndexer{builtinTypes->numberType, builtinTypes->stringType}; CHECK_EQ("{string}", toString(Type{ttv})); - ttv.props["A"] = {typeChecker.numberType}; - CHECK_EQ("{| [number]: string, A: number |}", toString(Type{ttv})); + ttv.props["A"] = {builtinTypes->numberType}; + if (FFlag::LuauSolverV2) + CHECK_EQ("{ [number]: string, A: number }", toString(Type{ttv})); + else + CHECK_EQ("{| [number]: string, A: number |}", toString(Type{ttv})); ttv.props.clear(); ttv.state = TableState::Unsealed; @@ -445,10 +554,12 @@ TEST_CASE_FIXTURE(Fixture, "generate_friendly_names_for_inferred_generics") CHECK_EQ("(a) -> a", toString(requireType("id"))); - CHECK_EQ("(a, b, c, d, e, f, g, h, i, j, k, l, " - "m, n, o, p, q, r, s, t, u, v, w, x, y, z, a1, b1, c1, d1) -> (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, " - "x, y, z, a1, b1, c1, d1)", - toString(requireType("id2"))); + CHECK_EQ( + "(a, b, c, d, e, f, g, h, i, j, k, l, " + "m, n, o, p, q, r, s, t, u, v, w, x, y, z, a1, b1, c1, d1) -> (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, " + "x, y, z, a1, b1, c1, d1)", + toString(requireType("id2")) + ); } TEST_CASE_FIXTURE(Fixture, "toStringDetailed") @@ -481,64 +592,10 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") CHECK("c" == toString(params[2], opts)); } -TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") -{ - ScopedFastFlag sff[] = { - {"DebugLuauSharedSelf", true}, - }; - - CheckResult result = check(R"( - local base = {} - function base:one() return 1 end - - local child = {} - setmetatable(child, {__index=base}) - function child:two() return 2 end - - local inst = {} - setmetatable(inst, {__index=child}) - )"); - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions opts; - - TypeId tType = requireType("inst"); - ToStringResult r = toStringDetailed(tType, opts); - CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); - CHECK(0 == opts.nameMap.types.size()); - - const MetatableType* tMeta = get(tType); - REQUIRE(tMeta); - - TableType* tMeta2 = getMutable(tMeta->metatable); - REQUIRE(tMeta2); - REQUIRE(tMeta2->props.count("__index")); - - const MetatableType* tMeta3 = get(tMeta2->props["__index"].type); - REQUIRE(tMeta3); - - TableType* tMeta4 = getMutable(tMeta3->metatable); - REQUIRE(tMeta4); - REQUIRE(tMeta4->props.count("__index")); - - TableType* tMeta5 = getMutable(tMeta4->props["__index"].type); - REQUIRE(tMeta5); - REQUIRE(tMeta5->props.count("one") > 0); - - TableType* tMeta6 = getMutable(tMeta3->table); - REQUIRE(tMeta6); - REQUIRE(tMeta6->props.count("two") > 0); - - ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); - - std::string twoResult = toString(tMeta6->props["two"].type, opts); - - CHECK_EQ("(a) -> number", oneResult.name); - CHECK_EQ("(b) -> number", twoResult); -} - TEST_CASE_FIXTURE(Fixture, "toStringErrorPack") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function target(callback: nil) return callback(4, "hello") end )"); @@ -562,21 +619,30 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T Type tv1{TableType{}}; TableType* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; - ttv->props["hello"] = {typeChecker.numberType}; - ttv->props["world"] = {typeChecker.numberType}; + ttv->props["hello"] = {builtinTypes->numberType}; + ttv->props["world"] = {builtinTypes->numberType}; TypePackVar tpv1{TypePack{{&tv1}}}; Type tv2{TableType{}}; TableType* bttv = getMutable(&tv2); bttv->state = TableState::Free; - bttv->props["hello"] = {typeChecker.numberType}; + bttv->props["hello"] = {builtinTypes->numberType}; bttv->boundTo = &tv1; TypePackVar tpv2{TypePack{{&tv2}}}; - CHECK_EQ("{| hello: number, world: number |}", toString(&tpv1)); - CHECK_EQ("{| hello: number, world: number |}", toString(&tpv2)); + + if (FFlag::LuauSolverV2) + { + CHECK_EQ("{ hello: number, world: number }", toString(&tpv1)); + CHECK_EQ("{ hello: number, world: number }", toString(&tpv2)); + } + else + { + CHECK_EQ("{| hello: number, world: number |}", toString(&tpv1)); + CHECK_EQ("{| hello: number, world: number |}", toString(&tpv2)); + } } TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_return_type_if_pack_has_an_empty_head_link") @@ -614,7 +680,10 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_inters LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((number) -> ()) & t1 where t1 = () -> t1", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK("(() -> t1) & ((number) -> ()) where t1 = () -> t1" == toString(requireType("a"))); + else + CHECK_EQ("((number) -> ()) & t1 where t1 = () -> t1", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") @@ -654,7 +723,10 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TypeId ty = requireType("map"); const FunctionType* ftv = get(follow(ty)); - CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); + if (FFlag::LuauSolverV2) + CHECK_EQ("map(arr: {a}, fn: (a) -> (b, ...unknown)): {b}", toStringNamedFunction("map", *ftv)); + else + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") @@ -769,65 +841,135 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_ function foo(x: a, y) end )"); - CHECK("(a, b) -> ()" == toString(requireType("foo"))); + if (FFlag::LuauSolverV2) + { + CHECK("(a, unknown) -> ()" == toString(requireType("foo"))); + } + else + CHECK("(a, b) -> ()" == toString(requireType("foo"))); } -TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") +TEST_CASE_FIXTURE(Fixture, "tostring_unsee_ttv_if_array") { - ScopedFastFlag sff[]{ - {"DebugLuauSharedSelf", true}, - }; + CheckResult result = check(R"( + local x: {string} + -- This code is constructed very specifically to use the same (by pointer + -- identity) type in the function twice. + local y: (typeof(x), typeof(x)) -> () + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("y")) == "({string}, {string}) -> ()"); +} + +TEST_CASE_FIXTURE(Fixture, "tostring_error_mismatch") +{ CheckResult result = check(R"( - local foo = {} - function foo:method(arg: string): () + --!strict + function f1() : {a : number, b : string, c : { d : number}} + return { a = 1, b = "b", c = {d = "d"}} end )"); - TypeId parentTy = requireType("foo"); - auto ttv = get(follow(parentTy)); - auto ftv = get(follow(ttv->props.at("method").type)); + std::string expected; + if (FFlag::LuauSolverV2) + expected = + R"(Type pack '{ a: number, b: string, c: { d: string } }' could not be converted into '{ a: number, b: string, c: { d: number } }'; at [0][read "c"][read "d"], string is not exactly number)"; + else + expected = R"(Type + '{ a: number, b: string, c: { d: string } }' +could not be converted into + '{| a: number, b: string, c: {| d: number |} |}' +caused by: + Property 'c' is not compatible. +Type + '{ d: string }' +could not be converted into + '{| d: number |}' +caused by: + Property 'd' is not compatible. +Type 'string' could not be converted into 'number' in an invariant context)"; + + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); + std::string actual = toString(result.errors[0]); + + CHECK(expected == actual); } -TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") +TEST_CASE_FIXTURE(Fixture, "checked_fn_toString") { - ScopedFastFlag sff[]{ - {"DebugLuauSharedSelf", true}, + ScopedFastFlag flags[] = { + {FFlag::LuauSolverV2, true}, }; - CheckResult result = check(R"( - local foo = {} - function foo:method(arg: string): () - end - )"); + auto _result = loadDefinition(R"( +@checked declare function abs(n: number) : number +)"); - ToStringOptions opts; - opts.hideFunctionSelfArgument = true; + auto result = check(Mode::Nonstrict, R"( +local f = abs +)"); - TypeId parentTy = requireType("foo"); - auto ttv = get(follow(parentTy)); - REQUIRE_MESSAGE(ttv, "Expected a table but got " << toString(parentTy, opts)); - TypeId methodTy = follow(ttv->props.at("method").type); - auto ftv = get(methodTy); - REQUIRE_MESSAGE(ftv, "Expected a function but got " << toString(methodTy, opts)); + LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); + TypeId fn = requireType("f"); + CHECK("@checked (number) -> number" == toString(fn)); } -TEST_CASE_FIXTURE(Fixture, "tostring_unsee_ttv_if_array") +TEST_CASE_FIXTURE(Fixture, "read_only_properties") { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + CheckResult result = check(R"( - local x: {string} - -- This code is constructed very specifically to use the same (by pointer - -- identity) type in the function twice. - local y: (typeof(x), typeof(x)) -> () + type A = {x: string} + type B = {read x: string} )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK(toString(requireType("y")) == "({string}, {string}) -> ()"); + CHECK("{ x: string }" == toString(requireTypeAlias("A"), {true})); + CHECK("{ read x: string }" == toString(requireTypeAlias("B"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "cycle_rooted_in_a_pack") +{ + TypeArena arena; + + TypePackId thePack = arena.addTypePack({builtinTypes->numberType, builtinTypes->numberType}); + TypePack* packPtr = getMutable(thePack); + REQUIRE(packPtr); + + const TableType::Props theProps = { + {"BaseField", Property::readonly(builtinTypes->unknownType)}, + {"BaseMethod", Property::readonly(arena.addType(FunctionType{thePack, arena.addTypePack({})}))} + }; + + TypeId theTable = arena.addType(TableType{theProps, {}, TypeLevel{}, TableState::Sealed}); + + packPtr->head[0] = theTable; + + if (FFlag::LuauSolverV2) + CHECK("tp1 where tp1 = { read BaseField: unknown, read BaseMethod: (tp1) -> () }, number" == toString(thePack)); + else + CHECK("tp1 where tp1 = {| BaseField: unknown, BaseMethod: (tp1) -> () |}, number" == toString(thePack)); +} + +TEST_CASE_FIXTURE(Fixture, "correct_stringification_user_defined_type_functions") +{ + TypeFunction user{"user", nullptr}; + TypeFunctionInstanceType tftt{ + NotNull{&user}, + std::vector{builtinTypes->numberType}, // Type Function Arguments + {}, + {AstName{"woohoo"}}, // Type Function Name + }; + + Type tv{tftt}; + + if (FFlag::LuauSolverV2 && FFlag::LuauUserDefinedTypeFunctions2) + CHECK_EQ(toString(&tv, {}), "woohoo"); } TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 87f0b2b8a..188d9682d 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -12,6 +12,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) + TEST_SUITE_BEGIN("TranspilerTests"); TEST_CASE("test_1") @@ -345,7 +347,7 @@ TEST_CASE("always_emit_a_space_after_local_keyword") TEST_CASE_FIXTURE(Fixture, "types_should_not_be_considered_cyclic_if_they_are_not_recursive") { std::string code = R"( - local common: {foo:string} + local common: {foo:string} = {foo = 'foo'} local t = {} t.x = common @@ -353,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "types_should_not_be_considered_cyclic_if_they_are_no )"; std::string expected = R"( - local common: {foo:string} + local common: {foo:string} = {foo = 'foo'} local t:{x:{foo:string},y:{foo:string}}={} t.x = common @@ -529,7 +531,7 @@ until c CHECK_EQ(code, transpile(code, {}, true).code); } -TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignmenr") +TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignment") { std::string code = R"( local a = 1 @@ -537,6 +539,7 @@ a += 2 a -= 3 a *= 4 a /= 5 +a //= 5 a %= 6 a ^= 7 a ..= ' - result' @@ -693,4 +696,13 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") +{ + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + + std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TxnLog.test.cpp b/tests/TxnLog.test.cpp new file mode 100644 index 000000000..b4b183537 --- /dev/null +++ b/tests/TxnLog.test.cpp @@ -0,0 +1,126 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "doctest.h" + +#include "Luau/Scope.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" + +#include "ScopedFlags.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) + +struct TxnLogFixture +{ + TxnLog log{/*useScopes*/ true}; + TxnLog log2{/*useScopes*/ true}; + TypeArena arena; + BuiltinTypes builtinTypes; + + ScopePtr globalScope = std::make_shared(builtinTypes.anyTypePack); + ScopePtr childScope = std::make_shared(globalScope); + + TypeId a = freshType(NotNull{&arena}, NotNull{&builtinTypes}, globalScope.get()); + TypeId b = freshType(NotNull{&arena}, NotNull{&builtinTypes}, globalScope.get()); + TypeId c = freshType(NotNull{&arena}, NotNull{&builtinTypes}, childScope.get()); + + TypeId g = arena.addType(GenericType{"G"}); +}; + +TEST_SUITE_BEGIN("TxnLog"); + +TEST_CASE_FIXTURE(TxnLogFixture, "colliding_union_incoming_type_has_greater_scope") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + log.replace(c, BoundType{a}); + log2.replace(a, BoundType{c}); + + CHECK(nullptr != log.pending(c)); + + log.concatAsUnion(std::move(log2), NotNull{&arena}); + + // 'a has greater scope than 'c, so we expect the incoming binding of 'a to + // be discarded. + + CHECK(nullptr == log.pending(a)); + + const PendingType* pt = log.pending(c); + REQUIRE(pt != nullptr); + + CHECK(!pt->dead); + const BoundType* bt = get_if(&pt->pending.ty); + + CHECK(a == bt->boundTo); + + log.commit(); + + REQUIRE(get(a)); + + const BoundType* bound = get(c); + REQUIRE(bound); + CHECK(a == bound->boundTo); +} + +TEST_CASE_FIXTURE(TxnLogFixture, "colliding_union_incoming_type_has_lesser_scope") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + log.replace(a, BoundType{c}); + log2.replace(c, BoundType{a}); + + CHECK(nullptr != log.pending(a)); + + log.concatAsUnion(std::move(log2), NotNull{&arena}); + + // 'a has greater scope than 'c, so we expect the binding of 'a to be + // discarded, and for that of 'c to be brought in. + + CHECK(nullptr == log.pending(a)); + + const PendingType* pt = log.pending(c); + REQUIRE(pt != nullptr); + + CHECK(!pt->dead); + const BoundType* bt = get_if(&pt->pending.ty); + + CHECK(a == bt->boundTo); + + log.commit(); + + REQUIRE(get(a)); + + const BoundType* bound = get(c); + REQUIRE(bound); + CHECK(a == bound->boundTo); +} + +TEST_CASE_FIXTURE(TxnLogFixture, "colliding_coincident_logs_do_not_create_degenerate_unions") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + log.replace(a, BoundType{b}); + log2.replace(a, BoundType{b}); + + log.concatAsUnion(std::move(log2), NotNull{&arena}); + + log.commit(); + + CHECK("'a" == toString(a)); + CHECK("'a" == toString(b)); +} + +TEST_CASE_FIXTURE(TxnLogFixture, "replacing_persistent_types_is_allowed_but_makes_the_log_radioactive") +{ + persist(g); + + log.replace(g, BoundType{a}); + + CHECK(log.radioactive); +} + +TEST_SUITE_END(); diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp new file mode 100644 index 000000000..68bfff57a --- /dev/null +++ b/tests/TypeFunction.test.cpp @@ -0,0 +1,1266 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeFunction.h" + +#include "Luau/ConstraintSolver.h" +#include "Luau/NotNull.h" +#include "Luau/Type.h" + +#include "ClassFixture.h" +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2) +LUAU_DYNAMIC_FASTINT(LuauTypeFamilyApplicationCartesianProductLimit) + +struct TypeFunctionFixture : Fixture +{ + TypeFunction swapFunction; + + TypeFunctionFixture() + : Fixture(false) + { + swapFunction = TypeFunction{ + /* name */ "Swap", + /* reducer */ + [](TypeId instance, const std::vector& tys, const std::vector& tps, NotNull ctx + ) -> TypeFunctionReductionResult + { + LUAU_ASSERT(tys.size() == 1); + TypeId param = follow(tys.at(0)); + + if (isString(param)) + { + return TypeFunctionReductionResult{ctx->builtins->numberType, false, {}, {}}; + } + else if (isNumber(param)) + { + return TypeFunctionReductionResult{ctx->builtins->stringType, false, {}, {}}; + } + else if (is(param) || is(param) || is(param) || + (ctx->solver && ctx->solver->hasUnresolvedConstraints(param))) + { + return TypeFunctionReductionResult{std::nullopt, false, {param}, {}}; + } + else + { + return TypeFunctionReductionResult{std::nullopt, true, {}, {}}; + } + } + }; + + unfreeze(frontend.globals.globalTypes); + TypeId t = frontend.globals.globalTypes.addType(GenericType{"T"}); + GenericTypeDefinition genericT{t}; + + ScopePtr globalScope = frontend.globals.globalScope; + globalScope->exportedTypeBindings["Swap"] = + TypeFun{{genericT}, frontend.globals.globalTypes.addType(TypeFunctionInstanceType{NotNull{&swapFunction}, {t}, {}})}; + freeze(frontend.globals.globalTypes); + } +}; + +TEST_SUITE_BEGIN("TypeFunctionTests"); + +TEST_CASE_FIXTURE(TypeFunctionFixture, "basic_type_function") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type A = Swap + type B = Swap + type C = Swap + + local x = 123 + local y: Swap = "foo" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("string" == toString(requireTypeAlias("A"))); + CHECK("number" == toString(requireTypeAlias("B"))); + CHECK("Swap" == toString(requireTypeAlias("C"))); + CHECK("string" == toString(requireType("y"))); + CHECK("Type function instance Swap is uninhabited" == toString(result.errors[0])); +}; + +TEST_CASE_FIXTURE(TypeFunctionFixture, "function_as_fn_ret") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local swapper: (T) -> Swap + local a = swapper(123) + local b = swapper("foo") + local c = swapper(false) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("string" == toString(requireType("a"))); + CHECK("number" == toString(requireType("b"))); + CHECK("Swap" == toString(requireType("c"))); + CHECK("Type function instance Swap is uninhabited" == toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(TypeFunctionFixture, "function_as_fn_arg") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local swapper: (Swap) -> T + local a = swapper(123) + local b = swapper(false) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK("unknown" == toString(requireType("a"))); + CHECK("unknown" == toString(requireType("b"))); + CHECK("Type 'number' could not be converted into 'never'" == toString(result.errors[0])); + CHECK("Type 'boolean' could not be converted into 'never'" == toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(TypeFunctionFixture, "resolve_deep_functions") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local x: Swap>> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number" == toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(TypeFunctionFixture, "unsolvable_function") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local impossible: (Swap) -> Swap> + local a = impossible(123) + local b = impossible(true) + )"); + + LUAU_REQUIRE_ERROR_COUNT(6, result); + CHECK(toString(result.errors[0]) == "Type function instance Swap> is uninhabited"); + CHECK(toString(result.errors[1]) == "Type function instance Swap is uninhabited"); + CHECK(toString(result.errors[2]) == "Type function instance Swap> is uninhabited"); + CHECK(toString(result.errors[3]) == "Type function instance Swap is uninhabited"); + CHECK(toString(result.errors[4]) == "Type function instance Swap> is uninhabited"); + CHECK(toString(result.errors[5]) == "Type function instance Swap is uninhabited"); +} + +TEST_CASE_FIXTURE(TypeFunctionFixture, "table_internal_functions") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local t: ({T}) -> {Swap} + local a = t({1, 2, 3}) + local b = t({"a", "b", "c"}) + local c = t({true, false, true}) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(requireType("a")) == "{string}"); + CHECK(toString(requireType("b")) == "{number}"); + // FIXME: table types are constructing a trivial union here. + CHECK(toString(requireType("c")) == "{Swap}"); + CHECK(toString(result.errors[0]) == "Type function instance Swap is uninhabited"); +} + +TEST_CASE_FIXTURE(TypeFunctionFixture, "function_internal_functions") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local f0: (T) -> (() -> T) + local f: (T) -> (() -> Swap) + local a = f(1) + local b = f("a") + local c = f(true) + local d = f0(1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(requireType("a")) == "() -> string"); + CHECK(toString(requireType("b")) == "() -> number"); + CHECK(toString(requireType("c")) == "() -> Swap"); + CHECK(toString(result.errors[0]) == "Type function instance Swap is uninhabited"); +} + +TEST_CASE_FIXTURE(Fixture, "add_function_at_work") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local function add(a, b) + return a + b + end + + local a = add(1, 2) + local b = add(1, "foo") + local c = add("foo", 1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "add"); + CHECK(toString(requireType("c")) == "add"); + CHECK( + toString(result.errors[0]) == + "Operator '+' could not be applied to operands of types number and string; there is no corresponding overload for __add" + ); + CHECK( + toString(result.errors[1]) == + "Operator '+' could not be applied to operands of types string and number; there is no corresponding overload for __add" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "cyclic_add_function_at_work") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type T = add + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mul_function_with_union_of_multiplicatives") +{ + if (!FFlag::LuauSolverV2) + return; + + loadDefinition(R"( + declare class Vec2 + function __mul(self, rhs: number): Vec2 + end + + declare class Vec3 + function __mul(self, rhs: number): Vec3 + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "Vec2 | Vec3"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mul_function_with_union_of_multiplicatives_2") +{ + if (!FFlag::LuauSolverV2) + return; + + loadDefinition(R"( + declare class Vec3 + function __mul(self, rhs: number): Vec3 + function __mul(self, rhs: Vec3): Vec3 + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "Vec3"); +} + +TEST_CASE_FIXTURE(Fixture, "internal_functions_raise_errors") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local function innerSum(a, b) + local _ = a + b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK( + toString(result.errors[0]) == + "Operator '+' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __add" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_functions_can_be_shadowed") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type add = string -- shadow add + + -- this should be ok + function hi(f: add) + return string.format("hi %s", f) + end + + -- this should still work totally fine (and use the real type function) + function plus(a, b) + return a + b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(toString(requireType("hi")) == "(string) -> string"); + CHECK(toString(requireType("plus")) == "(a, b) -> add"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_functions_inhabited_with_normalization") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local useGridConfig : any + local columns = useGridConfig("columns", {}) or 1 + local gutter = useGridConfig('gutter', {}) or 0 + local margin = useGridConfig('margin', {}) or 0 + return function(frameAbsoluteWidth: number) + local cellAbsoluteWidth = (frameAbsoluteWidth - 2 * margin + gutter) / columns - gutter + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_function_works") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = { x: number, y: number, z: number } + type KeysOfMyObject = keyof + + local function ok(idx: KeysOfMyObject): "x" | "y" | "z" return idx end + local function err(idx: KeysOfMyObject): "x" | "y" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"x\" | \"y\"", toString(tpm->wantedTp)); + CHECK_EQ("\"x\" | \"y\" | \"z\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_function_works_with_metatables") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local metatable = { __index = {w = 1} } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type MyObject = typeof(obj) + type KeysOfMyObject = keyof + + local function ok(idx: KeysOfMyObject): "w" | "x" | "y" | "z" return idx end + local function err(idx: KeysOfMyObject): "x" | "y" | "z" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"x\" | \"y\" | \"z\"", toString(tpm->wantedTp)); + CHECK_EQ("\"w\" | \"x\" | \"y\" | \"z\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_single_entry_no_uniontype") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local tbl_A = { abc = "value" } + local tbl_B = { a1 = nil, ["a2"] = nil } + + type keyof_A = keyof + type keyof_B = keyof + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(toString(requireTypeAlias("keyof_A")) == "\"abc\""); + CHECK(toString(requireTypeAlias("keyof_B")) == "\"a1\" | \"a2\""); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_function_errors_if_it_has_nontable_part") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = { x: number, y: number, z: number } + type KeysOfMyObject = keyof + + local function err(idx: KeysOfMyObject): "x" | "y" | "z" return idx end + )"); + + // FIXME(CLI-95289): we should actually only report the type function being uninhabited error at its first use, I think? + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Type 'MyObject | boolean' does not have keys, so 'keyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'MyObject | boolean' does not have keys, so 'keyof' is invalid"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_function_string_indexer") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = { x: number, y: number, z: number } + type MyOtherObject = { [string]: number } + type KeysOfMyOtherObject = keyof + type KeysOfMyObjects = keyof + + local function ok(idx: KeysOfMyOtherObject): "z" return idx end + local function err(idx: KeysOfMyObjects): "z" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"z\"", toString(tpm->wantedTp)); + CHECK_EQ("string", toString(tpm->givenTp)); + + tpm = get(result.errors[1]); + REQUIRE(tpm); + CHECK_EQ("\"z\"", toString(tpm->wantedTp)); + CHECK_EQ("\"x\" | \"y\" | \"z\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_function_common_subset_if_union_of_differing_tables") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = { x: number, y: number, z: number } + type MyOtherObject = { w: number, y: number, z: number } + type KeysOfMyObject = keyof + + local function err(idx: KeysOfMyObject): "z" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"z\"", toString(tpm->wantedTp)); + CHECK_EQ("\"y\" | \"z\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_function_never_for_empty_table") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type KeyofEmpty = keyof<{}> + + local foo = ((nil :: any) :: KeyofEmpty) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("foo")) == "never"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_function_works") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = { x: number, y: number, z: number } + type KeysOfMyObject = rawkeyof + + local function ok(idx: KeysOfMyObject): "x" | "y" | "z" return idx end + local function err(idx: KeysOfMyObject): "x" | "y" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"x\" | \"y\"", toString(tpm->wantedTp)); + CHECK_EQ("\"x\" | \"y\" | \"z\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_function_ignores_metatables") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local metatable = { __index = {w = 1} } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type MyObject = typeof(obj) + type KeysOfMyObject = rawkeyof + + local function ok(idx: KeysOfMyObject): "x" | "y" | "z" return idx end + local function err(idx: KeysOfMyObject): "x" | "y" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"x\" | \"y\"", toString(tpm->wantedTp)); + CHECK_EQ("\"x\" | \"y\" | \"z\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_function_errors_if_it_has_nontable_part") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = { x: number, y: number, z: number } + type KeysOfMyObject = rawkeyof + + local function err(idx: KeysOfMyObject): "x" | "y" | "z" return idx end + )"); + + // FIXME(CLI-95289): we should actually only report the type function being uninhabited error at its first use, I think? + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Type 'MyObject | boolean' does not have keys, so 'rawkeyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'MyObject | boolean' does not have keys, so 'rawkeyof' is invalid"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_function_common_subset_if_union_of_differing_tables") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = { x: number, y: number, z: number } + type MyOtherObject = { w: number, y: number, z: number } + type KeysOfMyObject = rawkeyof + + local function err(idx: KeysOfMyObject): "z" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"z\"", toString(tpm->wantedTp)); + CHECK_EQ("\"y\" | \"z\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_function_never_for_empty_table") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type RawkeyofEmpty = rawkeyof<{}> + + local foo = ((nil :: any) :: RawkeyofEmpty) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("foo")) == "never"); +} + +TEST_CASE_FIXTURE(ClassFixture, "keyof_type_function_works_on_classes") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = keyof + + local function ok(idx: KeysOfMyObject): "BaseMethod" | "BaseField" | "Touched" return idx end + local function err(idx: KeysOfMyObject): "BaseMethod" return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("\"BaseMethod\"", toString(tpm->wantedTp)); + CHECK_EQ("\"BaseField\" | \"BaseMethod\" | \"Touched\"", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(ClassFixture, "keyof_type_function_errors_if_it_has_nonclass_part") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = keyof + + local function err(idx: KeysOfMyObject): "BaseMethod" | "BaseField" return idx end + )"); + + // FIXME(CLI-95289): we should actually only report the type function being uninhabited error at its first use, I think? + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Type 'BaseClass | boolean' does not have keys, so 'keyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'BaseClass | boolean' does not have keys, so 'keyof' is invalid"); +} + +TEST_CASE_FIXTURE(ClassFixture, "keyof_type_function_common_subset_if_union_of_differing_classes") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = keyof + + local function ok(idx: KeysOfMyObject): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "keyof_type_function_works_with_parent_classes_too") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = keyof + + local function ok(idx: KeysOfMyObject): "BaseField" | "BaseMethod" | "Method" | "Touched" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "binary_type_function_works_with_default_argument") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type result = mul + + local function thunk(): result return 5 * 4 end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("() -> number" == toString(requireType("thunk"))); +} + +TEST_CASE_FIXTURE(ClassFixture, "vector2_multiply_is_overloaded") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local v = Vector2.New(1, 2) + + local v2 = v * 1.5 + local v3 = v * v + local v4 = v * "Hello" -- line 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(5 == result.errors[0].location.begin.line); + CHECK(5 == result.errors[0].location.end.line); + + CHECK("Vector2" == toString(requireType("v2"))); + CHECK("Vector2" == toString(requireType("v3"))); + CHECK("mul" == toString(requireType("v4"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_rfc_example") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local animals = { + cat = { speak = function() print "meow" end }, + dog = { speak = function() print "woof woof" end }, + monkey = { speak = function() print "oo oo" end }, + fox = { speak = function() print "gekk gekk" end } + } + + type AnimalType = keyof + + function speakByType(animal: AnimalType) + animals[animal].speak() + end + + speakByType("dog") -- ok + speakByType("cactus") -- errors + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("\"cat\" | \"dog\" | \"fox\" | \"monkey\"", toString(tm->wantedType)); + CHECK_EQ("\"cactus\"", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_oss_crash_gh1161") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local EnumVariants = { + ["a"] = 1, ["b"] = 2, ["c"] = 3 + } + + type EnumKey = keyof + + function fnA(i: T): keyof end + + function fnB(i: EnumKey) end + + local result = fnA(EnumVariants) + fnB(result) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(get(result.errors[0])); + CHECK(get(result.errors[1])); +} + +TEST_CASE_FIXTURE(TypeFunctionFixture, "fuzzer_numeric_binop_doesnt_assert_on_generalizeFreeType") +{ + CheckResult result = check(R"( +Module 'l0': +local _ = (67108864)(_ >= _).insert +do end +do end +_(...,_(_,_(_()),_())) +(67108864)()() +_(_ ~= _ // _,l0)(_(_({n0,})),_(_),_) +_(setmetatable(_,{[...]=_,})) + +)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "cyclic_concat_function_at_work") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type T = concat + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "exceeded_distributivity_limits") +{ + if (!FFlag::LuauSolverV2) + return; + + ScopedFastInt sfi{DFInt::LuauTypeFamilyApplicationCartesianProductLimit, 10}; + + loadDefinition(R"( + declare class A + function __mul(self, rhs: unknown): A + end + + declare class B + function __mul(self, rhs: unknown): B + end + + declare class C + function __mul(self, rhs: unknown): C + end + + declare class D + function __mul(self, rhs: unknown): D + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "didnt_quite_exceed_distributivity_limits") +{ + if (!FFlag::LuauSolverV2) + return; + + // We duplicate the test here because we want to make sure the test failed + // due to exceeding the limits specifically, rather than any possible reasons. + ScopedFastInt sfi{DFInt::LuauTypeFamilyApplicationCartesianProductLimit, 20}; + + loadDefinition(R"( + declare class A + function __mul(self, rhs: unknown): A + end + + declare class B + function __mul(self, rhs: unknown): B + end + + declare class C + function __mul(self, rhs: unknown): C + end + + declare class D + function __mul(self, rhs: unknown): D + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_equivalence_with_distributivity") +{ + if (!FFlag::LuauSolverV2) + return; + + loadDefinition(R"( + declare class A + function __mul(self, rhs: unknown): A + end + + declare class B + function __mul(self, rhs: unknown): B + end + + declare class C + function __mul(self, rhs: unknown): C + end + + declare class D + function __mul(self, rhs: unknown): D + end + )"); + + CheckResult result = check(R"( + type T = mul + type U = mul | mul | mul | mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "A | B"); + CHECK(toString(requireTypeAlias("U")) == "A | A | B | B"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "we_shouldnt_warn_that_a_reducible_type_function_is_uninhabited") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + +local Debounce = false +local Active = false + +local function Use(Mode) + + if Mode ~= nil then + + if Mode == false and Active == false then + return + else + Active = not Mode + end + + Debounce = false + end + Active = not Active + +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type IdxAType = index + type IdxBType = index> + + local function ok(idx: IdxAType): string return idx end + local function ok2(idx: IdxBType): string | number | boolean return idx end + local function err(idx: IdxAType): boolean return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("boolean", toString(tpm->wantedTp)); + CHECK_EQ("string", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_wait_for_pending_no_crash") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local PlayerData = { + Coins = 0, + Level = 1, + Exp = 0, + MaxExp = 100 + } + type Keys = index> + -- This function makes it think that there's going to be a pending expansion + local function UpdateData(key: Keys, value) + PlayerData[key] = value + end + UpdateData("Coins", 2) + )"); + + // Should not crash! +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works_w_array") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local MyObject = {"hello", 1, true} + type IdxAType = index + + local function ok(idx: IdxAType): string | number | boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works_w_generic_types") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local function access(tbl: T & {}, key: K): index + return tbl[key] + end + + local subjects = { + english = "boring", + math = "fun" + } + + local key: "english" = "english" + local a: string = access(subjects, key) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_errors_w_bad_indexer") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type errType1 = index + type errType2 = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Property '\"d\"' does not exist on type 'MyObject'"); + CHECK(toString(result.errors[1]) == "Property 'boolean' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_errors_w_var_indexer") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + local key = "a" + + type errType1 = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Second argument to index is not a valid index type"); + CHECK(toString(result.errors[1]) == "Unknown type 'key'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works_w_union_type_indexer") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + + type idxType = index + local function ok(idx: idxType): string | number return idx end + + type errType = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"a\" | \"d\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works_w_union_type_indexee") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type MyObject2 = {a: number} + + type idxTypeA = index + local function ok(idx: idxTypeA): string | number return idx end + + type errType = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject | MyObject2'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_rfc_alternative_section") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string} + type MyObject2 = {a: string, b: number} + + local function edgeCase(param: MyObject) + type unknownType = index + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(ClassFixture, "index_type_function_works_on_classes") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = index + + local function ok(idx: KeysOfMyObject): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "index_type_function_works_on_classes_with_parents") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = index + + local function ok(idx: KeysOfMyObject): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works_w_index_metatables") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local exampleClass = { Foo = "text", Bar = true } + + local exampleClass2 = setmetatable({ Foo = 8 }, { __index = exampleClass }) + type exampleTy2 = index + local function ok(idx: exampleTy2): number return idx end + + local exampleClass3 = setmetatable({ Bar = 5 }, { __index = exampleClass }) + type exampleTy3 = index + local function ok2(idx: exampleTy3): string return idx end + + type exampleTy4 = index + local function ok3(idx: exampleTy4): string | number return idx end + + type errTy = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"Car\"' does not exist on type 'exampleClass2'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_function_works") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type RawAType = rawget + type RawBType = rawget> + local function ok(idx: RawAType): string return idx end + local function ok2(idx: RawBType): string | number | boolean return idx end + local function err(idx: RawAType): boolean return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("boolean", toString(tpm->wantedTp)); + CHECK_EQ("string", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_function_works_w_array") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local MyObject = {"hello", 1, true} + type RawAType = rawget + local function ok(idx: RawAType): string | number | boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_function_errors_w_var_indexer") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + local key = "a" + type errType1 = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Second argument to rawget is not a valid index type"); + CHECK(toString(result.errors[1]) == "Unknown type 'key'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_function_works_w_union_type_indexer") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type rawType = rawget + local function ok(idx: rawType): string | number return idx end + type errType = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"a\" | \"d\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_function_works_w_union_type_indexee") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type MyObject2 = {a: number} + type rawTypeA = rawget + local function ok(idx: rawTypeA): string | number return idx end + type errType = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject | MyObject2'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_function_works_w_index_metatables") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local exampleClass = { Foo = "text", Bar = true } + local exampleClass2 = setmetatable({ Foo = 8 }, { __index = exampleClass }) + type exampleTy2 = rawget + local function ok(idx: exampleTy2): number return idx end + local exampleClass3 = setmetatable({ Bar = 5 }, { __index = exampleClass }) + type errType = rawget + type errType2 = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Property '\"Foo\"' does not exist on type 'exampleClass3'"); + CHECK(toString(result.errors[1]) == "Property '\"Bar\" | \"Foo\"' does not exist on type 'exampleClass3'"); +} + +TEST_CASE_FIXTURE(ClassFixture, "rawget_type_function_errors_w_classes") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type PropsOfMyObject = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'"); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_len_type_function_follow") +{ + // Should not fail assertions + check(R"( + local _ + _ = true + for l0=_,_,# _ do + end + for l0=_,_ do + if _ then + _ += _ + end + end + )"); +} + +TEST_SUITE_END(); diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp new file mode 100644 index 000000000..b6160b3c9 --- /dev/null +++ b/tests/TypeFunction.user.test.cpp @@ -0,0 +1,1233 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "ClassFixture.h" +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) +LUAU_FASTFLAG(LuauUserTypeFunFixRegister) +LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite) + +TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_nil(arg) + return arg + end + type type_being_serialized = nil + local function ok(idx: serialize_nil): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getnil() + local ty = types.singleton(nil) + if ty:is("nil") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnil<>): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_unknown(arg) + return arg + end + type type_being_serialized = unknown + local function ok(idx: serialize_unknown): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getunknown() + local ty = types.unknown + if ty:is("unknown") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getunknown<>): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_never(arg) + return arg + end + type type_being_serialized = never + local function ok(idx: serialize_never): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getnever() + local ty = types.never + if ty:is("never") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnever<>): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_any(arg) + return arg + end + type type_being_serialized = any + local function ok(idx: serialize_any): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getany() + local ty = types.any + if ty:is("any") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getany<>): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_bool(arg) + return arg + end + type type_being_serialized = boolean + local function ok(idx: serialize_bool): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getboolean() + local ty = types.boolean + if ty:is("boolean") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolean<>): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_num(arg) + return arg + end + type type_being_serialized = number + local function ok(idx: serialize_num): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getnumber() + local ty = types.number + if ty:is("number") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnumber<>): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_str(arg) + return arg + end + type type_being_serialized = string + local function ok(idx: serialize_str): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getstring() + local ty = types.string + if ty:is("string") then + return ty + end + -- this should never be returned + return types.boolean + end + local function ok(idx: getstring<>): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_boolsingleton(arg) + return arg + end + type type_being_serialized = true + local function ok(idx: serialize_boolsingleton): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getboolsingleton() + local ty = types.singleton(true) + if ty:is("singleton") and ty:value() then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolsingleton<>): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_strsingleton(arg) + return arg + end + type type_being_serialized = "popcorn and movies!" + local function ok(idx: serialize_strsingleton): "popcorn and movies!" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getstrsingleton() + local ty = types.singleton("hungry hippo") + if ty:is("singleton") and ty:value() == "hungry hippo" then + return ty + end + -- this should never be returned + return types.number + end + local function ok(idx: getstrsingleton<>): "hungry hippo" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_union(arg) + return arg + end + type type_being_serialized = number | string | boolean + -- forcing an error here to check the exact type of the union + local function ok(idx: serialize_union): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getunion() + local ty = types.unionof(types.string, types.number, types.boolean) + if ty:is("union") then + -- creating a copy of `ty` + local arr = {} + for _, value in ty:components() do + table.insert(arr, value) + end + return types.unionof(table.unpack(arr)) + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the union + local function ok(idx: getunion<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_intersection(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number } & { boolean: boolean, string: string } + -- forcing an error here to check the exact type of the intersection + local function ok(idx: serialize_intersection): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getintersection() + local tbl1 = types.newtable(nil, nil, nil) + tbl1:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl1:setproperty(types.singleton("number"), types.number) -- {boolean: boolean, number: number} + local tbl2 = types.newtable(nil, nil, nil) + tbl2:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl2:setproperty(types.singleton("string"), types.string) -- {boolean: boolean, string: string} + local ty = types.intersectionof(tbl1, tbl2) + if ty:is("intersection") then + -- creating a copy of `ty` + local arr = {} + for index, value in ty:components() do + table.insert(arr, value) + end + return types.intersectionof(table.unpack(arr)) + end + -- this should never be returned + return types.string + end + -- forcing an error here to check the exact type of the intersection + local function ok(idx: getintersection<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getnegation() + local ty = types.negationof(types.string) + if ty:is("negation") then + return ty + end + -- this should never be returned + return types.number + end + + -- forcing an error here to check the exact type of the negation + local function ok(idx: getnegation<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "~string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_table(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number, [string]: number } + -- forcing an error here to check the exact type of the table + local function ok(idx: serialize_table): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [string]: number, boolean: boolean, number: number }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function gettable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number] = boolean} + ty:setproperty(types.singleton("number"), types.string) -- {string: number, number: string, [number] = boolean} + ty:setproperty(types.singleton("string"), nil) -- {number: string, [number] = boolean} + local ret = types.newtable(nil, nil, nil) -- {} + -- creating a copy of `ty` + for k, v in ty:properties() do + ret:setreadproperty(k, v.read) + ret:setwriteproperty(k, v.write) + end + if ret:is("table") then + ret:setindexer(types.boolean, types.string) -- {number: string, [boolean] = string} + return ret -- {number: string, [boolean] = string} + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the table + local function ok(idx: gettable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [boolean]: string, number: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getmetatable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + metatbl:setmetatable(types.newtable(nil, indexer, nil)) -- { { }, @metatable { [number]: boolean } } + local ret = metatbl:metatable() + if metatbl:is("table") and metatbl:metatable() then + return ret -- { @metatable { [number]: boolean } } + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getmetatable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{boolean}"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_func(arg) + return arg + end + type type_being_serialized = (boolean, number, nil) -> (...string) + local function ok(idx: serialize_func): (boolean, number, nil) -> (...string) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getfunction() + local ty = types.newfunction(nil, nil) -- () -> () + ty:setparameters({types.string, types.number}, nil) -- (string, number) -> () + ty:setreturns(nil, types.boolean) -- (string, number) -> (...boolean) + if ty:is("function") then + -- creating a copy of `ty` parameters + local arr = {} + for index, val in ty:parameters().head do + table.insert(arr, val) + end + return types.newfunction({head = arr}, ty:returns()) -- (string, number) -> (...boolean) + end + -- this should never be returned + return types.number + end + local function ok(idx: getfunction<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "(string, number) -> (...boolean)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_class(arg) + return arg + end + local function ok(idx: serialize_class): BaseClass return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + + CheckResult result = check(R"( + type function getclass(arg) + local props = arg:properties() + local indexer = arg:indexer() + local metatable = arg:metatable() + return types.newtable(props, indexer, metatable) + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getclass): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ BaseField: number, read BaseMethod: (BaseClass, number) -> (), read Touched: Connection }"); +} + +TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag udtfRwFix{FFlag::LuauUserTypeFunFixNoReadWrite, true}; + + + CheckResult result = check(R"( + type function getclass(arg) + local props = arg:properties() + local table = types.newtable(props) + local singleton = types.singleton("BaseMethod") + + if table:writeproperty(singleton) then + return types.singleton(true) + else + return types.singleton(false) + end + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getclass): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "false"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function checkmut() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(props, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if metatbl:is("table") and metatbl:metatable() then + return metatbl -- { @metatable { [number]: boolean }, { } } + end + -- this should never be returned + return types.number + end + local function ok(idx: checkmut<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable {boolean}, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function getcopy() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metaty = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + local copy = types.copy(metaty) + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if copy:is("table") and copy:metatable() then + return copy -- { { }, @metatable { [number]: boolean, string: number } } + end + -- this should never be returned + return types.number + end + local function ok(idx: getcopy<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable { [number]: boolean, string: number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_cycle(arg) + return arg + end + type basety = { + first: basety2 + } + type basety2 = { + second: basety + } + local function ok(idx: serialize_cycle): basety return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function badmetatable() + return types.newtable(nil, nil, types.number) + end + local function bad(arg: badmetatable<>) end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'badmetatable' type function errored at runtime: [string \"badmetatable\"]:3: types.newtable: expected to be given a table " + "type as a metatable, but got number instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function serialize_cycle2(arg) + return arg + end + type Employee = { + name: string, + department: Department? + } + type Department = { + name: string, + manager: Employee?, + employees: { Employee }, + company: Company? + } + type Company = { + name: string, + departments: { Department } + } + local function ok(idx: serialize_cycle2): Company return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function errors_if_string(arg) + if arg:is("string") then + local a = 1 + error("We are in a math class! not english") + end + return arg + end + local function ok(idx: errors_if_string): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'errors_if_string' type function errored at runtime: [string \"errors_if_string\"]:5: We are in a math class! not english"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function hello(arg) + error(type(arg)) + end + local function ok(idx: hello): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: userdata"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function hello() + local p1 = types.string + local p2 = types.string + local t1 = types.newtable(nil, nil, nil) + t1:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + local t2 = types.newtable(nil, nil, nil) + t2:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + if p1 == p2 and t1 == t2 then + return types.number + end + end + local function ok(idx: hello<>): number return idx end + )"); + + LUAU_CHECK_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function hello(arg) + local arr = arg:properties() + end + local function ok(idx: hello<() -> ()>): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: type.properties: expected self to be either a table or class, " + "but got function instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function foo() + return "hi" + end + type function bar() + return types.singleton(foo()) + end + local function ok(idx: bar<>): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "\"hi\""); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function foo() + if not glob then + glob = 'a' + else + glob ..= 'b' + end + + return glob + end + type function bar(prefix) + return types.singleton(prefix:value() .. foo()) + end + local function ok1(idx: bar<'x'>): nil return idx end + local function ok2(idx: bar<'y'>): nil return idx end + )"); + + // We are only checking first errors, others are mostly duplicates + LUAU_CHECK_ERROR_COUNT(8, result); + CHECK(toString(result.errors[0]) == R"('bar' type function errored at runtime: [string "foo"]:4: attempt to modify a readonly table)"); + CHECK(toString(result.errors[1]) == R"(Type function instance bar<"x"> is uninhabited)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function optionify(tbl) + if not tbl:is("table") then + error("Argument is not a table") + end + for k, v in tbl:properties() do + tbl:setproperty(k, types.unionof(v.read, types.singleton(nil))) + end + return tbl + end + type Person = { + name: string, + age: number, + alive: boolean + } + local function ok(idx: optionify): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ age: number?, alive: boolean?, name: string? }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function illegal(arg) + gcinfo() -- this should error + + return arg -- this should not be reached + end + + local function ok(idx: illegal): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'illegal' type function errored at runtime: [string \"illegal\"]:3: this function is not supported in type functions"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function foo(tbl) + local count = 0 + for k,v in tbl:properties() do count += 1 end + if count < 100 then + tbl:setproperty(types.singleton(`m{count}`), types.string) + foo(tbl) + end + for i = 1,100 do table.create(10000) end + return tbl + end + type Test = {} + local function ok(idx: foo): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recovery_no_upvalues") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag userDefinedTypeFunctionsSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag userDefinedTypeFunctions{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag userDefinedTypeFunctionNoEvaluation{FFlag::LuauUserDefinedTypeFunctionNoEvaluation, true}; + + CheckResult result = check(R"( + local var + + type function save_upvalue(arg) + var = 1 + return arg + end + + type test = "test" + local function ok(idx: save_upvalue): "test" + return idx + end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == R"(Type function cannot reference outer local 'var')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_follow") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag userDefinedTypeFunctionsSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag userDefinedTypeFunctions{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type t0 = any + type function t0() + return types.any + end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == R"(Redefinition of type 't0', previously defined at line 2)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strip_indexer") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + + CheckResult result = check(R"( + type function stripindexer(tbl) + if not tbl:is("table") then + error("can only strip the indexer on a table!") + end + tbl:setindexer(types.never, types.never) + return tbl + end + + type map = { [number]: string, foo: string } + -- forcing an error here to check the exact type + local function ok(tbl: stripindexer): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ foo: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_type_methods_on_types") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + + CheckResult result = check(R"( + type function test(x) + return if types.is(x, "number") then types.string else types.boolean + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:3: attempt to call a nil value)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_types_functions_on_type") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + + CheckResult result = check(R"( + type function test(x) + return x.singleton("a") + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:3: attempt to call a nil value)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_metatable_writes") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + + CheckResult result = check(R"( + type function test(x) + local a = x.__index + a.is = function() return false end + return types.singleton(x.is("number")) + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:4: attempt to index nil with 'is')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_eq_field") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + + CheckResult result = check(R"( + type function test(x) + return types.singleton(x.__eq(x, types.number)) + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:3: attempt to call a nil value)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tag_field") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + + CheckResult result = check(R"( + type function test(x) + return types.singleton(x.tag) + end + + local function ok1(tbl: test): never return tbl end + local function ok2(tbl: test): never return tbl end + local function ok3(tbl: test<{}>): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK(toString(result.errors[0]) == R"(Type pack '"number"' could not be converted into 'never'; at [0], "number" is not a subtype of never)"); + CHECK(toString(result.errors[1]) == R"(Type pack '"string"' could not be converted into 'never'; at [0], "string" is not a subtype of never)"); + CHECK(toString(result.errors[2]) == R"(Type pack '"table"' could not be converted into 'never'; at [0], "table" is not a subtype of never)"); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index a2fc0c75e..53f134f30 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -4,11 +4,13 @@ #include "doctest.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/AstQuery.h" using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2) TEST_SUITE_BEGIN("TypeAliases"); @@ -71,39 +73,50 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - CHECK(result.errors[0] == TypeError{ - Location{{1, 21}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - builtinTypes->numberType, - builtinTypes->stringType, - }, - }); + CHECK( + result.errors[0] == + TypeError{ + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + builtinTypes->numberType, + builtinTypes->stringType, + }, + } + ); } else { - CHECK(result.errors[0] == TypeError{ - Location{{1, 8}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - builtinTypes->numberType, - builtinTypes->stringType, - }, - }); + CHECK( + result.errors[0] == + TypeError{ + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + builtinTypes->numberType, + builtinTypes->stringType, + }, + } + ); } } TEST_CASE_FIXTURE(Fixture, "mismatched_generic_type_param") { + // We erroneously report an extra error in this case when the new solver is enabled. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type T = (A...) -> () )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(toString(result.errors[0]) == - "Generic type 'A' is used as a variadic type parameter; consider changing 'A' to 'A...' in the generic argument list"); + CHECK( + toString(result.errors[0]) == + "Generic type 'A' is used as a variadic type parameter; consider changing 'A' to 'A...' in the generic argument list" + ); CHECK(result.errors[0].location == Location{{1, 21}, {1, 25}}); } @@ -114,8 +127,10 @@ TEST_CASE_FIXTURE(Fixture, "mismatched_generic_pack_type_param") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(toString(result.errors[0]) == - "Variadic type parameter 'A...' is used as a regular generic type; consider changing 'A...' to 'A' in the generic argument list"); + CHECK( + toString(result.errors[0]) == + "Variadic type parameter 'A...' is used as a regular generic type; consider changing 'A...' to 'A' in the generic argument list" + ); CHECK(result.errors[0].location == Location{{1, 24}, {1, 25}}); } @@ -159,16 +174,18 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_whe CheckResult result = check(R"( --!strict type Node = { Parent: Node?; } - local node: Node; - node.Parent = 1 + + function f(node: Node) + node.Parent = 1 + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); + REQUIRE_MESSAGE(tm, result.errors[0]); CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") @@ -188,8 +205,9 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") TEST_CASE_FIXTURE(Fixture, "generic_aliases") { - ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; - + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + }; CheckResult result = check(R"( type T = { v: a } local x: T = { v = 123 } @@ -198,24 +216,16 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - const char* expectedError; - if (FFlag::LuauTypeMismatchInvarianceInError) - expectedError = "Type 'bad' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; - else - expectedError = "Type 'bad' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; - + const std::string expected = R"(Type '{ v: string }' could not be converted into 'T'; at [read "v"], string is not exactly number)"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 44}}); - CHECK(toString(result.errors[0]) == expectedError); + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") { - ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + }; CheckResult result = check(R"( type T = { v: a } @@ -225,27 +235,18 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - std::string expectedError; - if (FFlag::LuauTypeMismatchInvarianceInError) - expectedError = "Type 'bad' could not be converted into 'U'\n" - "caused by:\n" - " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; - else - expectedError = "Type 'bad' could not be converted into 'U'\n" - "caused by:\n" - " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + const std::string expected = + R"(Type '{ t: { v: string } }' could not be converted into 'U'; at [read "t"][read "v"], string is not exactly number)"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); - CHECK(toString(result.errors[0]) == expectedError); + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") { + // CLI-116108 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict type T = { f: a, g: U } @@ -276,7 +277,7 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") // We had a UAF in this example caused by not cloning type function arguments ModulePtr module = frontend.moduleResolver.getModule("MainModule"); unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); freeze(module->interfaceTypes); module->internalTypes.clear(); module->astTypes.clear(); @@ -329,7 +330,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") @@ -344,13 +345,20 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + if (FFlag::LuauSolverV2) + CHECK_EQ("t1 where t1 = ({ a: t1 }) -> string", toString(tm->wantedType)); + else + CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } // Check that recursive intersection type doesn't generate an OOM TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") { + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; // FIXME + CheckResult result = check(R"( function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any end @@ -374,15 +382,19 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") const std::string code = R"( type A = {v:T, b:B} type B = {v:T, a:A} - local aa:A - local bb = aa + + function f(a: A) + return a + end )"; const std::string expected = R"( type A = {v:T, b:B} type B = {v:T, a:A} - local aa:A - local bb:A=aa + + function f(a: A): A + return a + end )"; CHECK_EQ(expected, decorateWithTypes(code)); @@ -396,18 +408,18 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") CheckResult result = check(R"( type A = () -> (number, B) type B = () -> (string, A) - local a: A - local b: B )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); - CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); + CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireTypeAlias("A"))); + CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireTypeAlias("B"))); } TEST_CASE_FIXTURE(Fixture, "generic_param_remap") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + const std::string code = R"( -- An example of a forwarded use of a type that has different type arguments than parameters type A = {t:T, u:U, next:A?} @@ -520,7 +532,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId ty = getGlobalBinding(frontend, "table"); + TypeId ty = getGlobalBinding(frontend.globals, "table"); CHECK(toString(ty) == "typeof(table)"); @@ -532,11 +544,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -type Cool = { a: number, b: string } -local c: Cool = { a = 1, b = "s" } -type NotCool = Cool -)"); + type Cool = { a: number, b: string } + local c: Cool = { a = 1, b = "s" } + type NotCool = Cool + )"); LUAU_REQUIRE_NO_ERRORS(result); std::optional ty = requireType("c"); @@ -551,6 +565,8 @@ type NotCool = Cool TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type Cool = { a: number, b: string } type NotCool = Cool @@ -612,16 +628,16 @@ type X = Import.X TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_generic_type") { fileResolver.source["game/A"] = R"( -export type X = { a: T, b: U, C: X? } -return {} + export type X = { a: T, b: U, C: X? } + return {} )"; CheckResult aResult = frontend.check("game/A"); LUAU_REQUIRE_NO_ERRORS(aResult); CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X + local Import = require(game.A) + type X = Import.X )"); LUAU_REQUIRE_NO_ERRORS(bResult); @@ -634,8 +650,8 @@ type X = Import.X CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); bResult = check(R"( -local Import = require(game.A) -type X = Import.X + local Import = require(game.A) + type X = Import.X )"); LUAU_REQUIRE_NO_ERRORS(bResult); @@ -645,8 +661,16 @@ type X = Import.X ty2 = lookupType("X"); REQUIRE(ty2); - CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); - CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); + if (FFlag::LuauSolverV2) + { + CHECK(toString(*ty1, {true}) == "t1 where t1 = { C: t1?, a: T, b: U }"); + CHECK(toString(*ty2, {true}) == "t1 where t1 = { C: t1?, a: U, b: T }"); + } + else + { + CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); + CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); + } } TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") @@ -688,6 +712,9 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") { + // CLI-116108 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -- OK because forwarded types are used with their parameters. type Tree = { data: T, children: Forest } @@ -699,6 +726,9 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") { + // CLI-116108 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -- Not OK because forwarded types are used with different types than their parameters. type Forest = {Tree<{T}>} @@ -720,6 +750,9 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") { + // CLI-116108 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type Tree1 = { data: T, children: {Tree2} } type Tree2 = { data: U, children: {Tree1} } @@ -823,42 +856,14 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni local d: FutureType = { smth = true } -- missing error, 'd' is resolved to 'any' )"); - CHECK_EQ("{| foo: number |}", toString(requireType("d"), {true})); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ foo: number }", toString(requireType("d"), {true})); + else + CHECK_EQ("{| foo: number |}", toString(requireType("d"), {true})); LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") -{ - ScopedFastFlag sff[] = { - {"DebugLuauSharedSelf", true}, - }; - - CheckResult result = check(R"( - local B = {} - B.bar = 4 - - function B:smth1() - local self: FutureIntersection = self - self.foo = 4 - return 4 - end - - function B:smth2() - local self: FutureIntersection = self - self.bar = 5 -- error, even though we should have B part with bar - end - - type A = { foo: typeof(B.smth1({foo=3})) } -- trick toposort into sorting functions before types - type B = typeof(B) - - type FutureIntersection = A & B - )"); - - // TODO: shared self causes this test to break in bizarre ways. - LUAU_REQUIRE_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") { CheckResult result = check(R"( @@ -870,6 +875,9 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") { + // CLI-116108 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -- this would be an infinite type if we allowed it type Tree = { data: T, children: {Tree<{T}>} } @@ -880,6 +888,9 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") TEST_CASE_FIXTURE(Fixture, "report_shadowed_aliases") { + // CLI-116110 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + // We allow a previous type alias to depend on a future type alias. That exact feature enables a confusing example, like the following snippet, // which has the type alias FakeString point to the type alias `string` that which points to `number`. CheckResult result = check(R"( @@ -922,4 +933,267 @@ TEST_CASE_FIXTURE(Fixture, "cannot_create_cyclic_type_with_unknown_module") CHECK(toString(result.errors[0]) == "Unknown type 'B.AAA'"); } +TEST_CASE_FIXTURE(Fixture, "type_alias_locations") +{ + check(R"( + type T = number + + do + type T = string + type X = boolean + end + )"); + + ModulePtr mod = getMainModule(); + REQUIRE(!mod->scopes.empty()); + + REQUIRE(mod->scopes[0].second->typeAliasNameLocations.count("T") > 0); + CHECK(mod->scopes[0].second->typeAliasNameLocations["T"] == Location(Position(1, 13), 1)); + + ScopePtr doScope = findScopeAtPosition(*mod, Position{4, 0}); + REQUIRE(doScope); + + REQUIRE(doScope->typeAliasNameLocations.count("T") > 0); + CHECK(doScope->typeAliasNameLocations["T"] == Location(Position(4, 17), 1)); + + REQUIRE(doScope->typeAliasNameLocations.count("X") > 0); + CHECK(doScope->typeAliasNameLocations["X"] == Location(Position(5, 17), 1)); +} + +/* + * We had a bug in DCR where substitution would improperly clone a + * PendingExpansionType. + * + * This cloned type did not have a matching constraint to expand it, so it was + * left dangling and unexpanded forever. + * + * We must also delay the dispatch a constraint if doing so would require + * unifying a PendingExpansionType. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_lose_track_of_PendingExpansionTypes_after_substitution") +{ + // CLI-114134 - We need egraphs to properly simplify these types. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + fileResolver.source["game/ReactCurrentDispatcher"] = R"( + export type BasicStateAction = ((S) -> S) | S + export type Dispatch = (A) -> () + + export type Dispatcher = { + useState: (initialState: (() -> S) | S) -> (S, Dispatch>), + } + + return {} + )"; + + // Note: This script path is actually as short as it can be. Any shorter + // and we somehow fail to surface the bug. + fileResolver.source["game/React/React/ReactHooks"] = R"( + local RCD = require(script.Parent.Parent.Parent.ReactCurrentDispatcher) + + local function resolveDispatcher(): RCD.Dispatcher + return (nil :: any) :: RCD.Dispatcher + end + + function useState( + initialState: (() -> S) | S + ): (S, RCD.Dispatch>) + local dispatcher = resolveDispatcher() + return dispatcher.useState(initialState) + end + )"; + + CheckResult result = frontend.check("game/React/React/ReactHooks"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_thing_from_roact") +{ + CheckResult result = check(R"( + type Map = { [K]: V } + type Set = { [T]: boolean } + + type FiberRoot = { + pingCache: Map | Map>)> | nil, + } + + type Wakeable = { + andThen: (self: Wakeable) -> nil | Wakeable, + } + + local function attachPingListener(root: FiberRoot, wakeable: Wakeable, lanes: number) + local pingCache: Map | Map>)> | nil = root.pingCache + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * It is sometimes possible for type alias resolution to produce a TypeId that + * belongs to a different module. + * + * We must not mutate any fields of the resulting type when this happens. The + * memory has been frozen. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "alias_expands_to_bare_reference_to_imported_type") +{ + fileResolver.source["game/A"] = R"( + --!strict + export type Object = {[string]: any} + return {} + )"; + + fileResolver.source["game/B"] = R"( + local A = require(script.Parent.A) + + type Object = A.Object + type ReadOnly = T + + local function f(): ReadOnly + return nil :: any + end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_types_record_the_property_locations") +{ + CheckResult result = check(R"( + type Table = { + create: () -> () + } + + local x: Table + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + auto ty = requireTypeAlias("Table"); + + auto ttv = Luau::get(follow(ty)); + REQUIRE(ttv); + + auto propIt = ttv->props.find("create"); + REQUIRE(propIt != ttv->props.end()); + + CHECK_EQ(propIt->second.location, std::nullopt); + CHECK_EQ(propIt->second.typeLocation, Location({2, 12}, {2, 18})); +} + +TEST_CASE_FIXTURE(Fixture, "typeof_is_not_a_valid_alias_name") +{ + CheckResult result = check(R"( + type typeof = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK("Type aliases cannot be named typeof" == toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "fuzzer_bug_doesnt_crash") +{ + CheckResult result = check(R"( +type t0 = (t0) +)"); + LUAU_REQUIRE_ERRORS(result); +} + + +TEST_CASE_FIXTURE(Fixture, "recursive_type_alias_warns") +{ + CheckResult result = check(R"( +type Foo = Foo +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto occursCheckError = get(result.errors[0]); + REQUIRE(occursCheckError); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_type_alias_bad_pack_use_warns") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( +type Foo = Foo +)"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + auto occursCheckFailed = get(result.errors[1]); + REQUIRE(occursCheckFailed); + + auto swappedGeneric = get(result.errors[2]); + REQUIRE(swappedGeneric); + CHECK(swappedGeneric->name == "T"); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_aliases") +{ + CheckResult result = check(R"( +type Foo = Bar +type Bar = Foo +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto err = get(result.errors[0]); + REQUIRE(err); +} + +TEST_CASE_FIXTURE(Fixture, "should_also_occurs_check") +{ + CheckResult result = check(R"( +type Foo = Foo | string +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto err = get(result.errors[0]); + REQUIRE(err); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_adds_reduce_constraint_for_type_function") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauUserDefinedTypeFunctions2) + return; + + CheckResult result = check(R"( + type plus = add + + local sum: plus = 10 + )"); + + LUAU_CHECK_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") +{ + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag noUDTFimpl{FFlag::LuauUserDefinedTypeFunctions2, false}; + + CheckResult result = check(R"( + type function foo() + return nil + end + )"); + LUAU_CHECK_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "This syntax is not supported"); +} + +TEST_CASE_FIXTURE(Fixture, "bound_type_in_alias_segfault") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!nonstrict + type Map = {[ K]: V} + function foo:bar(): Config end + type Config = Map & { fields: FieldConfigMap} + export type FieldConfig = {[ string]: any} + export type FieldConfigMap = Map> + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index d5f953746..4029924a7 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -7,27 +7,106 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(DebugLuauMagicTypes); + using namespace Luau; TEST_SUITE_BEGIN("AnnotationTests"); -TEST_CASE_FIXTURE(Fixture, "check_against_annotations") +TEST_CASE_FIXTURE(Fixture, "initializers_are_checked_against_annotations") { CheckResult result = check("local a: number = \"Hello Types!\""); LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "check_multi_assign") +TEST_CASE_FIXTURE(Fixture, "check_multi_initialize") { - CheckResult result = check("local a: number, b: string = \"994\", 888"); - CHECK_EQ(2, result.errors.size()); + CheckResult result = check(R"( + local a: number, b: string = "one", 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(get(result.errors[0])); + CHECK(get(result.errors[1])); } TEST_CASE_FIXTURE(Fixture, "successful_check") { - CheckResult result = check("local a: number, b: string = 994, \"eight eighty eight\""); + CheckResult result = check(R"( + local a: number, b: string = 1, "two" + )"); + LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "assignments_are_checked_against_annotations") +{ + CheckResult result = check(R"( + local x: number = 1 + x = "two" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "multi_assign_checks_against_annotations") +{ + CheckResult result = check(R"( + local a: number, b: string = 1, "two" + a, b = "one", 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(Location{{2, 15}, {2, 20}} == result.errors[0].location); + CHECK(Location{{2, 22}, {2, 23}} == result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "assignment_cannot_transform_a_table_property_type") +{ + CheckResult result = check(R"( + local a = {x=0} + a.x = "one" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(Location{{2, 14}, {2, 19}} == result.errors[0].location); +} + +TEST_CASE_FIXTURE(Fixture, "assignments_to_unannotated_parameters_can_transform_the_type") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function f(x) + x = 0 + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(unknown) -> number" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assignments_to_annotated_parameters_are_checked") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function f(x: string) + x = 0 + return x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(Location{{2, 16}, {2, 17}} == result.errors[0].location); + + CHECK("(string) -> number" == toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") @@ -40,6 +119,22 @@ TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "assignment_also_checks_subtyping") +{ + CheckResult result = check(R"( + function f(): number? + return nil + end + local x: number = 1 + local y: number? = f() + x = y + y = x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(Location{{6, 12}, {6, 13}} == result.errors[0].location); +} + TEST_CASE_FIXTURE(Fixture, "function_parameters_can_have_annotations") { CheckResult result = check(R"( @@ -86,7 +181,7 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") REQUIRE_EQ(1, tp->head.size()); - REQUIRE_EQ(typeChecker.anyType, follow(tp->head[0])); + REQUIRE_EQ(builtinTypes->anyType, follow(tp->head[0])); } TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") @@ -133,14 +228,17 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(result.errors[0] == TypeError{ - Location{{1, 17}, {1, 28}}, - getMainSourceModule()->name, - UnknownSymbol{ - "IDoNotExist", - UnknownSymbol::Context::Type, - }, - }); + CHECK( + result.errors[0] == + TypeError{ + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + } + ); } TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") @@ -166,17 +264,33 @@ TEST_CASE_FIXTURE(Fixture, "infer_type_of_value_a_via_typeof_with_assignment") a = "foo" )"); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + if (FFlag::LuauSolverV2) + { + CHECK("string?" == toString(requireType("a"))); + CHECK("nil" == toString(requireType("b"))); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{4, 12}, Position{4, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK( + result.errors[0] == (TypeError{Location{Position{2, 29}, Position{2, 30}}, TypeMismatch{builtinTypes->nilType, builtinTypes->numberType}}) + ); + } + else + { + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + result.errors[0], + (TypeError{Location{Position{4, 12}, Position{4, 17}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}}) + ); + } } TEST_CASE_FIXTURE(Fixture, "table_annotation") { CheckResult result = check(R"( - local x: {a: number, b: string} + local x: {a: number, b: string} = {a=2, b="three"} local y = x.a local z = x.b )"); @@ -329,7 +443,7 @@ TEST_CASE_FIXTURE(Fixture, "self_referential_type_alias") std::optional incr = get(oTable->props, "incr"); REQUIRE(incr); - const FunctionType* incrFunc = get(incr->type); + const FunctionType* incrFunc = get(incr->type()); REQUIRE(incrFunc); std::optional firstArg = first(incrFunc->argTypes); @@ -376,7 +490,7 @@ TEST_CASE_FIXTURE(Fixture, "two_type_params") { CheckResult result = check(R"( type Map = {[K]: V} - local m: Map = {}; + local m: Map = {} local a = m['foo'] local b = m[9] -- error here )"); @@ -442,10 +556,10 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") local bb:B )"); - TypeId fType = requireType("aa"); - const AnyType* ftv = get(follow(fType)); - REQUIRE(ftv != nullptr); - REQUIRE(!result.errors.empty()); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + OccursCheckFailed* ocf = get(result.errors[0]); + REQUIRE(ocf); } TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") @@ -459,7 +573,7 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") )"); TypeId fType = requireType("aa"); - REQUIRE(follow(fType) == typeChecker.numberType); + REQUIRE(follow(fType) == builtinTypes->numberType); LUAU_REQUIRE_NO_ERRORS(result); } @@ -480,7 +594,7 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") const TypeFun& a = mod.exportedTypeBindings["A"]; CHECK(isInArena(a.type, mod.interfaceTypes)); - CHECK(!isInArena(a.type, typeChecker.globalTypes)); + CHECK(!isInArena(a.type, frontend.globals.globalTypes)); std::optional exportsType = first(mod.returnType); REQUIRE(exportsType); @@ -488,7 +602,7 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); - TypeId n = exportsTable->props["n"].type; + TypeId n = exportsTable->props["n"].type(); REQUIRE(n != nullptr); CHECK(isInArena(n, mod.interfaceTypes)); @@ -543,23 +657,23 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); - TypeId aType = exportsTable->props["a"].type; + TypeId aType = exportsTable->props["a"].type(); REQUIRE(aType); - TypeId bType = exportsTable->props["b"].type; + TypeId bType = exportsTable->props["b"].type(); REQUIRE(bType); CHECK(isInArena(recordType, mod.interfaceTypes)); CHECK(isInArena(aType, mod.interfaceTypes)); CHECK(isInArena(bType, mod.interfaceTypes)); - CHECK_EQ(recordType, aType); - CHECK_EQ(recordType, bType); + CHECK(toString(recordType, {true}) == toString(aType, {true})); + CHECK(toString(recordType, {true}) == toString(bType, {true})); } TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -585,7 +699,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -611,7 +725,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -641,7 +755,8 @@ struct AssertionCatcher { tripped = 0; oldhook = Luau::assertHandler(); - Luau::assertHandler() = [](const char* expr, const char* file, int line, const char* function) -> int { + Luau::assertHandler() = [](const char* expr, const char* file, int line, const char* function) -> int + { ++tripped; return 0; }; @@ -661,39 +776,44 @@ int AssertionCatcher::tripped; TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + ScopedFastFlag sffs{FFlag::DebugLuauMagicTypes, true}; AssertionCatcher ac; - CHECK_THROWS_AS(check(R"( + CHECK_THROWS_AS( + check(R"( local a: _luau_ice = 55 )"), - InternalCompilerError); + InternalCompilerError + ); LUAU_ASSERT(1 == AssertionCatcher::tripped); } TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + ScopedFastFlag sffs{FFlag::DebugLuauMagicTypes, true}; bool caught = false; - frontend.iceHandler.onInternalError = [&](const char*) { + frontend.iceHandler.onInternalError = [&](const char*) + { caught = true; }; - CHECK_THROWS_AS(check(R"( + CHECK_THROWS_AS( + check(R"( local a: _luau_ice = 55 )"), - InternalCompilerError); + InternalCompilerError + ); CHECK_EQ(true, caught); } TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", false}; + ScopedFastFlag sffs{FFlag::DebugLuauMagicTypes, false}; // We only care that this does not throw check(R"( @@ -705,11 +825,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") { static std::vector output; output.clear(); - Luau::setPrintLine([](const std::string& s) { - output.push_back(s); - }); + Luau::setPrintLine( + [](const std::string& s) + { + output.push_back(s); + } + ); - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + ScopedFastFlag sffs{FFlag::DebugLuauMagicTypes, true}; CheckResult result = check(R"( local a: _luau_print @@ -722,7 +845,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") TEST_CASE_FIXTURE(Fixture, "luau_print_is_not_special_without_the_flag") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", false}; + ScopedFastFlag sffs{FFlag::DebugLuauMagicTypes, false}; CheckResult result = check(R"( local a: _luau_print @@ -731,6 +854,18 @@ TEST_CASE_FIXTURE(Fixture, "luau_print_is_not_special_without_the_flag") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "luau_print_incomplete") +{ + ScopedFastFlag sffs{FFlag::DebugLuauMagicTypes, true}; + + CheckResult result = check(R"( + local a: _luau_print + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("_luau_print requires one generic parameter", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "instantiate_type_fun_should_not_trip_rbxassert") { CheckResult result = check(R"( @@ -761,6 +896,7 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_union_type") { CheckResult result = check(R"( type T = T | T + local x : T )"); LUAU_REQUIRE_ERROR_COUNT(1, result); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 9988a1fc5..445ce072b 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2); + TEST_SUITE_BEGIN("TypeInferAnyError"); TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") @@ -30,7 +32,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + if (FFlag::LuauSolverV2) + CHECK("any?" == toString(requireType("a"))); + else + CHECK(builtinTypes->anyType == requireType("a")); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") @@ -48,13 +53,16 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK("any?" == toString(requireType("a"))); + else + CHECK("any" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") { CheckResult result = check(R"( - local bar: any + local bar = nil :: any local a for b in bar do @@ -64,13 +72,33 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK("any?" == toString(requireType("a"))); + else + CHECK("any" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") { CheckResult result = check(R"( - local bar: any + local bar = nil :: any + + local a + for b in bar() do + a = b + end + )"); + + if (FFlag::LuauSolverV2) + CHECK("any?" == toString(requireType("a"))); + else + CHECK("any" == toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any_pack") +{ + CheckResult result = check(R"( + function bar(): ...any end local a for b in bar() do @@ -80,7 +108,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK("any?" == toString(requireType("a"))); + else + CHECK("any" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") @@ -94,7 +125,16 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*error-type*", toString(requireType("a"))); + + if (FFlag::LuauSolverV2) + { + // Bug: We do not simplify at the right time + CHECK_EQ("*error-type*?", toString(requireType("a"))); + } + else + { + CHECK_EQ("*error-type*", toString(requireType("a"))); + } } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -108,9 +148,21 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("*error-type*", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + { + // CLI-97375(awe): `bar()` is returning `nil` here, which isn't wrong necessarily, + // but then we're signaling an additional error for the access on `nil`. + LUAU_REQUIRE_ERROR_COUNT(2, result); + + // Bug: We do not simplify at the right time + CHECK_EQ("*error-type*?", toString(requireType("a"))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*error-type*", toString(requireType("a"))); + } } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -169,7 +221,7 @@ TEST_CASE_FIXTURE(Fixture, "can_subscript_any") TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") { CheckResult result = check(R"( - local foo: any = {} + local foo = ({} :: any) local bar = #foo )"); @@ -195,7 +247,7 @@ TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") REQUIRE(ttv); REQUIRE(ttv->props.count("prop")); - REQUIRE_EQ("any", toString(ttv->props["prop"].type)); + REQUIRE_EQ("any", toString(ttv->props["prop"].type())); } TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") @@ -209,7 +261,7 @@ TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("A"); - CHECK_EQ(aType, typeChecker.anyType); + CHECK_EQ(aType, builtinTypes->anyType); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") @@ -225,7 +277,10 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK_EQ("*error-type*", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("any", toString(requireType("a"))); + else + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -234,7 +289,10 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - CHECK_EQ("*error-type*", toString(requireType("a"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("any", toString(requireType("a"))); + else + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") @@ -248,7 +306,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_comp )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("b"))); + + if (FFlag::LuauSolverV2) + CHECK_EQ("any?", toString(requireType("b"))); + else + CHECK_EQ("any", toString(requireType("b"))); } TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") @@ -343,4 +405,42 @@ stat = stat and tonumber(stat) or stat LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_of_any_calls") +{ + CheckResult result = check(R"( + local function testFunc(input: {any}) + end + + local v = {true} + + testFunc(v) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_any_can_have_props") +{ + // *blocked-130* ~ hasProp any & ~(false?), "_status" + CheckResult result = check(R"( +function foo(x: any, y) + if x then + return x._status + end + return y +end +)"); + + CHECK("(any, any) -> any" == toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "cast_to_table_of_any") +{ + CheckResult result = check(R"( + local v = {true} :: {any} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 860dcfd03..7f73f8e2c 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" #include "Fixture.h" @@ -8,8 +9,10 @@ using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauMatchReturnsOptionalString); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAG(LuauTypestateBuiltins) +LUAU_FASTFLAG(LuauStringFormatArityFix) TEST_SUITE_BEGIN("BuiltinTests"); @@ -106,7 +109,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_concat_returns_string") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("r")); + CHECK_EQ(*builtinTypes->stringType, *requireType("r")); } TEST_CASE_FIXTURE(BuiltinsFixture, "sort") @@ -133,6 +136,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict local t = {'one', 'two', 'three'} @@ -141,12 +146,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((a, a) -> boolean)?' + const std::string expected = R"(Type + '(number, number) -> boolean' +could not be converted into + '((string, string) -> boolean)?' caused by: - None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(a, a) -> boolean' + None of the union options are compatible. For example: +Type + '(number, number) -> boolean' +could not be converted into + '(string, string) -> boolean' caused by: - Argument #1 type is not compatible. Type 'string' could not be converted into 'number')", - toString(result.errors[0])); + Argument #1 type is not compatible. +Type 'string' could not be converted into 'number')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "strings_have_methods") @@ -156,7 +169,7 @@ TEST_CASE_FIXTURE(Fixture, "strings_have_methods") )LUA"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") @@ -166,7 +179,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); } TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") @@ -356,6 +369,24 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_unpacks_arg_types_correctly") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_on_union_of_tables") +{ + CheckResult result = check(R"( + type A = {tag: "A", x: number} + type B = {tag: "B", y: string} + + type T = A | B + + type X = typeof( + setmetatable({} :: T, {}) + ) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("{ @metatable { }, A } | { @metatable { }, B }" == toString(requireTypeAlias("X"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_2_args_overload") { CheckResult result = check(R"( @@ -365,7 +396,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_ )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.stringType, requireType("s")); + CHECK_EQ(builtinTypes->stringType, requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_3_args_overload") @@ -387,7 +418,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| [number]: boolean | number | string, n: number |}", toString(requireType("t"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ [number]: boolean | number | string, n: number }", toString(requireType("t"))); + else + CHECK_EQ("{| [number]: boolean | number | string, n: number |}", toString(requireType("t"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_variadic") @@ -402,7 +436,10 @@ local t = table.pack(f()) )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| [number]: number | string, n: number |}", toString(requireType("t"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ [number]: number | string, n: number }", toString(requireType("t"))); + else + CHECK_EQ("{| [number]: number | string, n: number |}", toString(requireType("t"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_reduce") @@ -412,14 +449,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_reduce") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| [number]: boolean | number, n: number |}", toString(requireType("t"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ [number]: boolean | number, n: number }", toString(requireType("t"))); + else + CHECK_EQ("{| [number]: boolean | number, n: number |}", toString(requireType("t"))); result = check(R"( local t = table.pack("a", "b", "c") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| [number]: string, n: number |}", toString(requireType("t"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ [number]: string, n: number }", toString(requireType("t"))); + else + CHECK_EQ("{| [number]: string, n: number |}", toString(requireType("t"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") @@ -429,7 +472,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); } TEST_CASE_FIXTURE(BuiltinsFixture, "getfenv") @@ -446,9 +489,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "os_time_takes_optional_date_table") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n1")); - CHECK_EQ(*typeChecker.numberType, *requireType("n2")); - CHECK_EQ(*typeChecker.numberType, *requireType("n3")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n1")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n2")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n3")); } TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") @@ -461,8 +504,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") CHECK("thread" == toString(requireType("co"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "buffer_is_a_type") +{ + CheckResult result = check(R"( + local b = buffer.create(10) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("buffer" == toString(requireType("b"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function nifty(x, y) print(x, y) @@ -500,6 +555,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_wrap_anything_goes") TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_should_not_mutate_persisted_types") { + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( local string = string @@ -544,6 +602,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_count_mismatch") TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") { + // CLI-115690 + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( --!strict string.format("%s", 123) @@ -552,8 +614,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(tm->wantedType, typeChecker.stringType); - CHECK_EQ(tm->givenType, typeChecker.numberType); + CHECK_EQ(tm->wantedType, builtinTypes->stringType); + CHECK_EQ(tm->givenType, builtinTypes->numberType); } TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_tostring_specifier") @@ -640,18 +702,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") local _ = function(l0,...) end local _ = function() - _(_); - _ += select(_()) + _(_); + _ += select(_()) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Argument count mismatch. Function '_' expects at least 1 argument, but none are specified", toString(result.errors[0])); - CHECK_EQ("Argument count mismatch. Function 'select' expects 1 argument, but none are specified", toString(result.errors[1])); + if (FFlag::LuauSolverV2) + { + // Counterintuitively, the parameter l0 is unconstrained and therefore it is valid to pass nil. + // The new solver therefore considers that parameter to be optional. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Argument count mismatch. Function expects at least 1 argument, but none are specified" == toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Argument count mismatch. Function '_' expects at least 1 argument, but none are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'select' expects 1 argument, but none are specified", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") { + // CLI-115720 + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( select(5432598430953240958) )"); @@ -663,6 +739,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") TEST_CASE_FIXTURE(BuiltinsFixture, "select_slightly_out_of_range") { + // CLI-115720 + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( select(3, "a", 1) )"); @@ -693,6 +773,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_string_head") { + // CLI-115720 + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( --!nonstrict local function f(...) @@ -704,11 +788,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("string", toString(requireType("foo"))); - else - CHECK_EQ("any", toString(requireType("foo"))); - + CHECK_EQ("any", toString(requireType("foo"))); CHECK_EQ("any", toString(requireType("bar"))); CHECK_EQ("any", toString(requireType("baz"))); CHECK_EQ("any", toString(requireType("quux"))); @@ -722,8 +802,21 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(tm->wantedType, typeChecker.stringType); - CHECK_EQ(tm->givenType, typeChecker.numberType); + CHECK_EQ(tm->wantedType, builtinTypes->stringType); + CHECK_EQ(tm->givenType, builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_trivial_arity") +{ + ScopedFastFlag sff{FFlag::LuauStringFormatArityFix, true}; + + CheckResult result = check(R"( + string.format() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Argument count mismatch. Function 'string.format' expects at least 1 argument, but none are specified", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") @@ -766,16 +859,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_use_correct_argument3") TEST_CASE_FIXTURE(BuiltinsFixture, "debug_traceback_is_crazy") { CheckResult result = check(R"( -local co: thread = ... --- debug.traceback takes thread?, message?, level? - yes, all optional! -debug.traceback() -debug.traceback(nil, 1) -debug.traceback("msg") -debug.traceback("msg", 1) -debug.traceback(co) -debug.traceback(co, "msg") -debug.traceback(co, "msg", 1) -)"); + function f(co: thread) + -- debug.traceback takes thread?, message?, level? - yes, all optional! + debug.traceback() + debug.traceback(nil, 1) + debug.traceback("msg") + debug.traceback("msg", 1) + debug.traceback(co) + debug.traceback(co, "msg") + debug.traceback(co, "msg", 1) + end + )"); LUAU_REQUIRE_NO_ERRORS(result); } @@ -783,13 +877,13 @@ debug.traceback(co, "msg", 1) TEST_CASE_FIXTURE(BuiltinsFixture, "debug_info_is_crazy") { CheckResult result = check(R"( -local co: thread, f: ()->() = ... - --- debug.info takes thread?, level, options or function, options -debug.info(1, "n") -debug.info(co, 1, "n") -debug.info(f, "n") -)"); + function f(co: thread, f: () -> ()) + -- debug.info takes thread?, level, options or function, options + debug.info(1, "n") + debug.info(co, 1, "n") + debug.info(f, "n") + end + )"); LUAU_REQUIRE_NO_ERRORS(result); } @@ -860,9 +954,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_report_all_type_errors_at_corr string.format("%s%d%s", 1, "hello", true) )"); - TypeId stringType = typeChecker.stringType; - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; + TypeId stringType = builtinTypes->stringType; + TypeId numberType = builtinTypes->numberType; + TypeId booleanType = builtinTypes->booleanType; LUAU_REQUIRE_ERROR_COUNT(6, result); @@ -893,7 +987,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); + if (FFlag::LuauSolverV2) + CHECK_EQ( + "Type 'number?' could not be converted into 'number'; type number?[1] (nil) is not a subtype of number (number)", + toString(result.errors[0]) + ); + else + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type2") @@ -908,6 +1008,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type2") TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") { + // This test makes no sense with type states and I think it generally makes no sense under the new solver. + // TODO: clip. + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( local f = math.sin local function g(x) return math.sin(x) end @@ -937,7 +1042,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); + + if (FFlag::LuauSolverV2) + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); + else + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") @@ -952,8 +1061,27 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types3") +{ + CheckResult result = check(R"( + local function f(x: (number | boolean)?) + assert(x) + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); + else // without the annotation, the old solver doesn't infer the best return type here + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") { + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( local function f(...: number?) return assert(...) @@ -966,6 +1094,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pa TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") { + if (FFlag::LuauSolverV2) + { + // CLI-114134 - egraph simplification + return; + } + CheckResult result = check(R"( local function f(x: nil) return assert(x, "hmm") @@ -992,24 +1126,128 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") local c = tf3[2] local d = tf1.b + + local a2 = t1.a + local b2 = t2.b + local c2 = t3[2] )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + CHECK("Key 'b' not found in table '{ read a: number }'" == toString(result.errors[0])); + else if (FFlag::LuauSolverV2) + CHECK("Key 'b' not found in table '{ a: number }'" == toString(result.errors[0])); + else + CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); + CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); + + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + { + CHECK_EQ("{ read a: number }", toString(requireTypeAtPosition({15, 19}))); + CHECK_EQ("{ read b: string }", toString(requireTypeAtPosition({16, 19}))); + CHECK_EQ("{boolean}", toString(requireTypeAtPosition({17, 19}))); + } CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); - CHECK_EQ("*error-type*", toString(requireType("d"))); + + if (FFlag::LuauSolverV2) + CHECK_EQ("any", toString(requireType("d"))); + else + CHECK_EQ("*error-type*", toString(requireType("d"))); + + CHECK_EQ("number", toString(requireType("a2"))); + CHECK_EQ("string", toString(requireType("b2"))); + CHECK_EQ("boolean", toString(requireType("c2"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_does_not_retroactively_block_mutation") +{ + CheckResult result = check(R"( + local t1 = {a = 42} + + t1.q = ":3" + + local tf1 = table.freeze(t1) + + local a = tf1.a + local b = t1.a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + + if (FFlag::LuauTypestateBuiltins) + { + CHECK_EQ("t1 | { read a: number, read q: string }", toString(requireType("t1"))); + // before the assignment, it's `t1` + CHECK_EQ("t1", toString(requireTypeAtPosition({3, 8}))); + // after the assignment, it's read-only. + CHECK_EQ("{ read a: number, read q: string }", toString(requireTypeAtPosition({8, 18}))); + } + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_no_generic_table") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + --!strict + type k = { + read k: string, + } + + function _(): k + return table.freeze({ + k = "", + }) + end + )"); + + if (FFlag::LuauTypestateBuiltins) + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_non_tables") +{ + CheckResult result = check(R"( + --!strict + table.freeze(42) + )"); + + // this does not error in the new solver without the typestate builtins functionality. + if (FFlag::LuauSolverV2 && !FFlag::LuauTypestateBuiltins) + { + LUAU_REQUIRE_NO_ERRORS(result); + return; + } + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + CHECK_EQ(toString(tm->wantedType), "table"); + else + CHECK_EQ(toString(tm->wantedType), "{- -}"); + CHECK_EQ(toString(tm->givenType), "number"); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { + // In the new solver, nil can certainly be used where a generic is required, so all generic parameters are optional. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -local a = {b=setmetatable} -a.b() -a:b() -a:b({}) + local a = {b=setmetatable} + a.b() + a:b() + a:b({}) )"); LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'a.b' expects 2 arguments, but none are specified"); @@ -1019,19 +1257,19 @@ a:b({}) TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") { CheckResult result = check(R"( -local function f(a: typeof(f)) end -)"); + local function f(a: typeof(f)) end + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") { - TypeId mathTy = requireType(typeChecker.globalScope, "math"); + TypeId mathTy = requireType(frontend.globals.globalScope, "math"); REQUIRE(mathTy); TableType* ttv = getMutable(mathTy); REQUIRE(ttv); - const FunctionType* ftv = get(ttv->props["frexp"].type); + const FunctionType* ftv = get(ttv->props["frexp"].type()); REQUIRE(ftv); auto original = ftv->level; @@ -1058,15 +1296,12 @@ end TEST_CASE_FIXTURE(Fixture, "string_match") { CheckResult result = check(R"( - local s:string + local s: string = "hello" local p = s:match("foo") )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - CHECK_EQ(toString(requireType("p")), "string?"); - else - CHECK_EQ(toString(requireType("p")), "string"); + CHECK_EQ(toString(requireType("p")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") @@ -1077,18 +1312,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(Fixture, "gmatch_capture_types2") @@ -1099,18 +1325,9 @@ TEST_CASE_FIXTURE(Fixture, "gmatch_capture_types2") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") @@ -1127,10 +1344,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") CHECK_EQ(acm->expected, 1); CHECK_EQ(acm->actual, 4); - if (FFlag::LuauMatchReturnsOptionalString) - CHECK_EQ(toString(requireType("a")), "string?"); - else - CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("a")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") @@ -1147,18 +1361,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens CHECK_EQ(acm->expected, 3); CHECK_EQ(acm->actual, 4); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "string?"); - CHECK_EQ(toString(requireType("c")), "number?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "string"); - CHECK_EQ(toString(requireType("c")), "number"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "string?"); + CHECK_EQ(toString(requireType("c")), "number?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") @@ -1175,16 +1380,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_igno CHECK_EQ(acm->expected, 2); CHECK_EQ(acm->actual, 3); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") @@ -1195,16 +1392,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "number?"); - CHECK_EQ(toString(requireType("b")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "number"); - CHECK_EQ(toString(requireType("b")), "string"); - } + CHECK_EQ(toString(requireType("a")), "number?"); + CHECK_EQ(toString(requireType("b")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") @@ -1252,18 +1441,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") @@ -1279,18 +1459,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") @@ -1301,18 +1472,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1330,18 +1492,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1359,18 +1512,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(tm->wantedType), "boolean?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1393,4 +1537,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(requireType("e")), "number?"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function StringSplit(input, separator) + string.find(input, separator) + if not separator then + separator = "%s+" + end + end + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.cfa.test.cpp b/tests/TypeInfer.cfa.test.cpp new file mode 100644 index 000000000..e097e18e4 --- /dev/null +++ b/tests/TypeInfer.cfa.test.cpp @@ -0,0 +1,977 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ControlFlowAnalysis"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return") +{ + CheckResult result = check(R"( + local function f(x: string?) + if not x then + return + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if not record.value then + break + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 34}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if not record.value then + continue + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_return") +{ + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({9, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + elseif not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + return + elseif not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_rand_return_elif_not_y_return") +{ + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif math.random() > 0.5 then + return + elseif not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif math.random() > 0.5 then + break + elseif not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_rand_continue_elif_not_y_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + elseif math.random() > 0.5 then + continue + elseif not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_rand_return_elif_not_y_fallthrough") +{ + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif math.random() > 0.5 then + return + elseif not y then + + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_fallthrough") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif math.random() > 0.5 then + break + elseif not recordY.value then + + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_rand_continue_elif_not_y_fallthrough") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + elseif math.random() > 0.5 then + continue + elseif not recordY.value then + + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_not_z_return") +{ + CheckResult result = check(R"( + local function f(x: string?, y: string?, z: string?) + if not x then + return + elseif not y then + + elseif not z then + return + end + + local foo = x + local bar = y + local baz = z + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({11, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({12, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_fallthrough_elif_not_z_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + break + elseif not recordY.value then + + elseif not recordZ.value then + break + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_fallthrough_elif_not_z_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + continue + elseif not recordY.value then + + elseif not recordZ.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + continue + elseif not recordY.value then + error("Y value not defined") + elseif not recordZ.value then + + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_not_z_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + return + elseif not recordY.value then + + elseif not recordZ.value then + break + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_if_not_x_return") +{ + CheckResult result = check(R"( + local function f(x: string?) + do + if not x then + return + end + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_record_do_if_not_x_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + do + if not record.value then + break + end + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({9, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_record_do_if_not_x_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + do + if not record.value then + continue + end + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({9, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_isnt_guaranteed_to_run_first") +{ + CheckResult result = check(R"( + local function f(x: string?) + while math.random() > 0.5 do + if not x then + return + end + + local foo = x + end + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_is_guaranteed_to_run_first") +{ + CheckResult result = check(R"( + local function f(x: string?) + repeat + if not x then + return + end + + local foo = x + until math.random() > 0.5 + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); // TODO: This is wrong, should be `string`. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_is_guaranteed_to_run_first_2") +{ + CheckResult result = check(R"( + local function f(x: string?) + for i = 1, 10 do + if not x then + return + end + + local foo = x + end + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); // TODO: This is wrong, should be `string`. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_then_error") +{ + CheckResult result = check(R"( + local function f(x: string?) + if not x then + error("oops") + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_then_assert_false") +{ + CheckResult result = check(R"( + local function f(x: string?) + if not x then + assert(false) + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_if_not_y_return") +{ + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + end + + if not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_if_not_y_break") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + end + + if not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + end + + if not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_throw") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + end + + if not recordY.value then + error("Y value not defined") + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_if_not_y_continue") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + end + + if not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out") +{ + CheckResult result = check(R"( + local function f(x: string?) + if typeof(x) == "string" then + return + else + type Foo = number + end + + local foo: Foo = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Unknown type 'Foo'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out_breaking") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if typeof(record.value) == "string" then + break + else + type Foo = number + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Unknown type 'Foo'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out_continuing") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if typeof(record.value) == "string" then + continue + else + type Foo = number + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Unknown type 'Foo'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope") +{ + // In CG, we walk the block to prototype aliases. We then visit the block in-order, which will resolve the prototype to a real type. + // That second walk assumes that the name occurs in the same `Scope` that the prototype walk had. If we arbitrarily change scope midway + // through, we'd invoke UB. + CheckResult result = check(R"( + local function f(x: string?) + type Foo = number + + if typeof(x) == "string" then + return + end + + local foo: Foo = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'nil' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope_breaking") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + type Foo = number + + if typeof(record.value) == "string" then + break + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'nil' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope_continuing") +{ + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + type Foo = number + + if typeof(record.value) == "string" then + continue + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'nil' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") +{ + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", error: E } + type Result = Ok | Err + + local function map(result: Result, f: (T) -> U): Result + if result.tag == "ok" then + local tag = result.tag + local val = result.value + + return { tag = "ok", value = f(result.value) } + end + + local tag = result.tag + local err = result.error + + return result + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("\"ok\"", toString(requireTypeAtPosition({7, 35}))); + CHECK_EQ("T", toString(requireTypeAtPosition({8, 35}))); + + CHECK_EQ("\"err\"", toString(requireTypeAtPosition({13, 31}))); + CHECK_EQ("E", toString(requireTypeAtPosition({14, 31}))); + + CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions_breaking") +{ + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", error: E } + type Result = Ok | Err + + local function process(results: {Result}) + for _, result in results do + if result.tag == "ok" then + local tag = result.tag + local val = result.value + + break + end + + local tag = result.tag + local err = result.error + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("\"ok\"", toString(requireTypeAtPosition({8, 39}))); + CHECK_EQ("T", toString(requireTypeAtPosition({9, 39}))); + + CHECK_EQ("\"err\"", toString(requireTypeAtPosition({14, 35}))); + CHECK_EQ("E", toString(requireTypeAtPosition({15, 35}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions_continuing") +{ + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", error: E } + type Result = Ok | Err + + local function process(results: {Result}) + for _, result in results do + if result.tag == "ok" then + local tag = result.tag + local val = result.value + + continue + end + + local tag = result.tag + local err = result.error + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("\"ok\"", toString(requireTypeAtPosition({8, 39}))); + CHECK_EQ("T", toString(requireTypeAtPosition({9, 39}))); + + CHECK_EQ("\"err\"", toString(requireTypeAtPosition({14, 35}))); + CHECK_EQ("E", toString(requireTypeAtPosition({15, 35}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_assert_x") +{ + CheckResult result = check(R"( + local function f(x: string?) + do + assert(x) + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index becc88aa6..16751559f 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -7,15 +7,49 @@ #include "Fixture.h" #include "ClassFixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; using std::nullopt; -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError); +LUAU_FASTFLAG(LuauSolverV2); TEST_SUITE_BEGIN("TypeInferClasses"); +TEST_CASE_FIXTURE(ClassFixture, "Luau.Analyze.CLI_crashes_on_this_test") +{ + CheckResult result = check(R"( + local CircularQueue = {} +CircularQueue.__index = CircularQueue + +function CircularQueue:new() + local newCircularQueue = { + head = nil, + } + setmetatable(newCircularQueue, CircularQueue) + + return newCircularQueue +end + +function CircularQueue:push() + local newListNode + + if self.head then + newListNode = { + prevNode = self.head.prevNode, + nextNode = self.head, + } + newListNode.prevNode.nextNode = newListNode + newListNode.nextNode.prevNode = newListNode + end +end + +return CircularQueue + + )"); +} + TEST_CASE_FIXTURE(ClassFixture, "call_method_of_a_class") { CheckResult result = check(R"( @@ -94,6 +128,8 @@ TEST_CASE_FIXTURE(ClassFixture, "we_can_infer_that_a_parameter_must_be_a_particu TEST_CASE_FIXTURE(ClassFixture, "we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function makeClone(o) return BaseClass.Clone(o) @@ -111,7 +147,35 @@ TEST_CASE_FIXTURE(ClassFixture, "we_can_report_when_someone_is_trying_to_use_a_t )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeMismatch* tm = get(result.errors[0]); + TypeMismatch* tm = get(result.errors.at(0)); + REQUIRE(tm != nullptr); + + CHECK_EQ("Oopsies", toString(tm->givenType)); + CHECK_EQ("BaseClass", toString(tm->wantedType)); +} + +TEST_CASE_FIXTURE(ClassFixture, "we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class_using_new_solver") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function makeClone(o) + return BaseClass.Clone(o) + end + + type Oopsies = { read BaseMethod: (Oopsies, number) -> ()} + + local oopsies: Oopsies = { + BaseMethod = function (self: Oopsies, i: number) + print('gadzooks!') + end + } + + makeClone(oopsies) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeMismatch* tm = get(result.errors.at(0)); REQUIRE(tm != nullptr); CHECK_EQ("Oopsies", toString(tm->givenType)); @@ -170,6 +234,9 @@ TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class_using_string") TEST_CASE_FIXTURE(ClassFixture, "cannot_unify_class_instance_with_primitive") { + // This is allowed in the new solver + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local v = Vector2.New(0, 5) v = 444 @@ -186,7 +253,7 @@ TEST_CASE_FIXTURE(ClassFixture, "warn_when_prop_almost_matches") LUAU_REQUIRE_ERROR_COUNT(1, result); - auto err = get(result.errors[0]); + auto err = get(result.errors.at(0)); REQUIRE(err != nullptr); REQUIRE_EQ(1, err->candidates.size()); @@ -290,7 +357,7 @@ TEST_CASE_FIXTURE(ClassFixture, "table_properties_are_invariant") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(6, result.errors[0].location.begin.line); + CHECK_EQ(6, result.errors.at(0).location.begin.line); CHECK_EQ(13, result.errors[1].location.begin.line); } @@ -313,7 +380,7 @@ TEST_CASE_FIXTURE(ClassFixture, "table_indexers_are_invariant") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(6, result.errors[0].location.begin.line); + CHECK_EQ(6, result.errors.at(0).location.begin.line); CHECK_EQ(13, result.errors[1].location.begin.line); } @@ -330,9 +397,17 @@ TEST_CASE_FIXTURE(ClassFixture, "table_class_unification_reports_sane_errors_for foo(a) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - REQUIRE_EQ("Key 'w' not found in class 'Vector2'", toString(result.errors[0])); - REQUIRE_EQ("Key 'x' not found in class 'Vector2'. Did you mean 'X'?", toString(result.errors[1])); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Type 'Vector2' could not be converted into '{ Y: number, w: number, x: number }'" == toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + REQUIRE_EQ("Key 'w' not found in class 'Vector2'", toString(result.errors.at(0))); + REQUIRE_EQ("Key 'x' not found in class 'Vector2'. Did you mean 'X'?", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_order") @@ -345,7 +420,7 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde LUAU_REQUIRE_ERROR_COUNT(2, result); - REQUIRE_EQ("Type 'BaseClass' could not be converted into 'number'", toString(result.errors[0])); + REQUIRE_EQ("Type 'BaseClass' could not be converted into 'number'", toString(result.errors.at(0))); REQUIRE_EQ("Type 'number' could not be converted into 'BaseClass'", toString(result.errors[1])); } @@ -359,7 +434,7 @@ b.X = 2 -- real Vector2.X is also read-only )"); LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[0])); + CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors.at(0))); CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[1])); CHECK_EQ("Key 'Z' not found in class 'Vector2'", toString(result.errors[2])); CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[3])); @@ -378,14 +453,27 @@ b(a) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 'Vector2' could not be converted into '{- X: a, Y: string -}' + + if (FFlag::LuauSolverV2) + { + CHECK("Type 'number' could not be converted into 'string'" == toString(result.errors.at(0))); + } + else + { + const std::string expected = R"(Type 'Vector2' could not be converted into '{- X: number, Y: string -}' caused by: - Property 'Y' is not compatible. Type 'number' could not be converted into 'string')", - toString(result.errors[0])); + Property 'Y' is not compatible. +Type 'number' could not be converted into 'string')"; + + CHECK_EQ(expected, toString(result.errors.at(0))); + } } TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") { + // CLI-116433 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local i = ChildClass.New() type ChildClass = { x: number } @@ -393,7 +481,7 @@ local a: ChildClass = i )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors[0])); + CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors.at(0))); } TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") @@ -422,8 +510,6 @@ TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") TEST_CASE_FIXTURE(ClassFixture, "index_instance_property") { - ScopedFastFlag luauAllowIndexClassParameters{"LuauAllowIndexClassParameters", true}; - CheckResult result = check(R"( local function execute(object: BaseClass, name: string) print(object[name]) @@ -431,13 +517,11 @@ TEST_CASE_FIXTURE(ClassFixture, "index_instance_property") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Attempting a dynamic property access on type 'BaseClass' is unsafe and may cause exceptions at runtime", toString(result.errors[0])); + CHECK_EQ("Attempting a dynamic property access on type 'BaseClass' is unsafe and may cause exceptions at runtime", toString(result.errors.at(0))); } TEST_CASE_FIXTURE(ClassFixture, "index_instance_property_nonstrict") { - ScopedFastFlag luauAllowIndexClassParameters{"LuauAllowIndexClassParameters", true}; - CheckResult result = check(R"( --!nonstrict @@ -455,19 +539,47 @@ TEST_CASE_FIXTURE(ClassFixture, "type_mismatch_invariance_required_for_error") type A = { x: ChildClass } type B = { x: BaseClass } -local a: A +local a: A = { x = ChildClass.New() } local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass' in an invariant context)"); + + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == "Type 'A' could not be converted into 'B'; at [read \"x\"], ChildClass is not exactly BaseClass"); else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + { + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass')"); + Property 'x' is not compatible. +Type 'ChildClass' could not be converted into 'BaseClass' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors.at(0))); + } +} + +TEST_CASE_FIXTURE(ClassFixture, "optional_class_casts_work_in_new_solver") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type A = { x: ChildClass } + type B = { x: BaseClass } + + local a = { x = ChildClass.New() } :: A + local opt_a = a :: A? + local b = { x = BaseClass.New() } :: B + local opt_b = b :: B? + local b_from_a = a :: B + local b_from_opt_a = opt_a :: B + local opt_b_from_a = a :: B? + local opt_b_from_opt_a = opt_a :: B? + local a_from_b = b :: A + local a_from_opt_b = opt_b :: A + local opt_a_from_b = b :: A? + local opt_a_from_opt_b = opt_b :: A? + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(ClassFixture, "callable_classes") @@ -481,4 +593,300 @@ TEST_CASE_FIXTURE(ClassFixture, "callable_classes") CHECK_EQ("number", toString(requireType("y"))); } +TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") +{ + // Test reading from an index + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x.stringKey + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x["stringKey"] + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local str : string + local y = x[str] -- Index with a non-const string + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x[7] -- Index with a numeric key + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + + // Test writing to an index + { + CheckResult result = check(R"( + local x : IndexableClass + x.stringKey = 42 + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + x["stringKey"] = 42 + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local str : string + x[str] = 42 -- Index with a non-const string + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + { + CheckResult result = check(R"( + local x : IndexableClass + x[1] = 42 -- Index with a numeric key + )"); + LUAU_REQUIRE_NO_ERRORS(result); + } + + // Try to index the class using an invalid type for the key (key type is 'number | string'.) + { + CheckResult result = check(R"( + local x : IndexableClass + local y = x[true] + )"); + + if (FFlag::LuauSolverV2) + CHECK( + "Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0)) + ); + else + CHECK_EQ( + toString(result.errors.at(0)), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" + ); + } + { + CheckResult result = check(R"( + local x : IndexableClass + x[true] = 42 + )"); + + if (FFlag::LuauSolverV2) + CHECK( + "Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0)) + ); + else + CHECK_EQ( + toString(result.errors.at(0)), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" + ); + } + + // Test type checking for the return type of the indexer (i.e. a number) + { + CheckResult result = check(R"( + local x : IndexableClass + x.key = "string value" + )"); + + if (FFlag::LuauSolverV2) + { + // Disabled for now. CLI-115686 + } + else + CHECK_EQ(toString(result.errors.at(0)), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableClass + local str : string = x.key + )"); + CHECK_EQ(toString(result.errors.at(0)), "Type 'number' could not be converted into 'string'"); + } + + // Check that we string key are rejected if the indexer's key type is not compatible with string + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + x.key = 1 + )"); + CHECK_EQ(toString(result.errors.at(0)), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + x["key"] = 1 + )"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(result.errors.at(0)), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + else + CHECK_EQ(toString(result.errors.at(0)), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local str : string + x[str] = 1 -- Index with a non-const string + )"); + CHECK_EQ(toString(result.errors.at(0)), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local y = x.key + )"); + CHECK_EQ(toString(result.errors.at(0)), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local y = x["key"] + )"); + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == "Key 'key' not found in class 'IndexableNumericKeyClass'"); + else + CHECK_EQ(toString(result.errors.at(0)), "Type 'string' could not be converted into 'number'"); + } + { + CheckResult result = check(R"( + local x : IndexableNumericKeyClass + local str : string + local y = x[str] -- Index with a non-const string + )"); + CHECK_EQ(toString(result.errors.at(0)), "Type 'string' could not be converted into 'number'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "read_write_class_properties") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + TypeArena& arena = frontend.globals.globalTypes; + + unfreeze(arena); + + TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test", {}}); + getMutable(instanceType)->props = {{"Parent", Property::rw(instanceType)}}; + + // + + TypeId workspaceType = arena.addType(ClassType{"Workspace", {}, nullopt, nullopt, {}, {}, "Test", {}}); + + TypeId scriptType = + arena.addType(ClassType{"Script", {{"Parent", Property::rw(workspaceType, instanceType)}}, instanceType, nullopt, {}, {}, "Test", {}}); + + TypeId partType = arena.addType(ClassType{ + "Part", + {{"BrickColor", Property::rw(builtinTypes->stringType)}, {"Parent", Property::rw(workspaceType, instanceType)}}, + instanceType, + nullopt, + {}, + {}, + "Test", + {} + }); + + getMutable(workspaceType)->props = {{"Script", Property::readonly(scriptType)}, {"Part", Property::readonly(partType)}}; + + frontend.globals.globalScope->bindings[frontend.globals.globalNames.names->getOrAdd("script")] = Binding{scriptType}; + + freeze(arena); + + CheckResult result = check(R"( + script.Parent.Part.BrickColor = 0xFFFFFF + script.Parent.Part.Parent = script + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(Location{{1, 40}, {1, 48}} == result.errors[0].location); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK(builtinTypes->stringType == tm->wantedType); + CHECK(builtinTypes->numberType == tm->givenType); +} + +TEST_CASE_FIXTURE(ClassFixture, "cannot_index_a_class_with_no_indexer") +{ + CheckResult result = check(R"( + local a = BaseClass.New() + + local c = a[1] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_MESSAGE( + get(result.errors[0]), "Expected DynamicPropertyLookupOnClassesUnsafe but got " << result.errors[0] + ); + + CHECK(builtinTypes->errorType == requireType("c")); +} + +TEST_CASE_FIXTURE(ClassFixture, "cyclic_tables_are_assumed_to_be_compatible_with_classes") +{ + /* + * This is technically documenting a case where we are intentionally + * unsound. + * + * Our builtins are essentially defined like so: + * + * declare class BaseClass + * BaseField: number + * function BaseMethod(self, number): () + * read Touched: Connection + * end + * + * declare class Connection + * Connect: (Connection, (BaseClass) -> ()) -> () + * end + * + * The type we infer for `onTouch` is + * + * (t1) -> () where t1 = { read BaseField: unknown, read BaseMethod: (t1, number) -> () } + * + * In order to validate that onTouch can be passed to Connect, we must + * verify the following relation: + * + * BaseClass <: t1 where t1 = { read BaseField: unknown, read BaseMethod: (t1, number) -> () } + * + * However, the cycle between the table and the function gums up the works + * here and the worst thing is that it's perfectly reasonable in principle. + * Just from these types, we cannot see that BaseMethod will only be passed + * t1. Without that guarantee, BaseClass cannot be used as a subtype of t1. + * + * I think the theoretically-correct way to untangle this would be to infer + * t1 as a bounded existential type. + * + * For now, we have a subtyping has a rule that provisionally substitutes + * the table for the class type when performing the subtyping test. We + * essentially assume that, for all cyclic functions, that the table and the + * class are mutually subtypes of one another. + * + * For more information, read uses of Subtyping::substitutions. + */ + + CheckResult result = check(R"( + local c = BaseClass.New() + + function requiresNothing() end + + function onTouch(other) + requiresNothing(other:BaseMethod(0)) + print(other.BaseField) + end + + c.Touched:Connect(onTouch) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 2a681d1a6..e1eaf5e92 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -19,13 +19,13 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_simple") declare foo2: typeof(foo) )"); - TypeId globalFooTy = getGlobalBinding(frontend, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend.globals, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - TypeId globalBarTy = getGlobalBinding(frontend, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend.globals, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend.globals, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); CheckResult result = check(R"( @@ -48,20 +48,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") declare function var(...: any): string )"); - TypeId globalFooTy = getGlobalBinding(frontend, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend.globals, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - std::optional globalAsdfTy = frontend.getGlobalScope()->lookupType("Asdf"); + std::optional globalAsdfTy = frontend.globals.globalScope->lookupType("Asdf"); REQUIRE(bool(globalAsdfTy)); CHECK_EQ(toString(globalAsdfTy->type), "number | string"); - TypeId globalBarTy = getGlobalBinding(frontend, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend.globals, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend.globals, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); - TypeId globalVarTy = getGlobalBinding(frontend, "var"); + TypeId globalVarTy = getGlobalBinding(frontend.globals, "var"); CHECK_EQ(toString(globalVarTy), "(...any) -> string"); @@ -77,25 +77,35 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_scope") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult parseFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult parseFailResult = frontend.loadDefinitionFile( + frontend.globals, + frontend.globals.globalScope, + R"( declare foo )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); REQUIRE(!parseFailResult.success); - std::optional fooTy = tryGetGlobalBinding(frontend, "foo"); + std::optional fooTy = tryGetGlobalBinding(frontend.globals, "foo"); CHECK(!fooTy.has_value()); - LoadDefinitionFileResult checkFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + LoadDefinitionFileResult checkFailResult = frontend.loadDefinitionFile( + frontend.globals, + frontend.globals.globalScope, + R"( local foo: string = 123 declare bar: typeof(foo) )", - "@test"); + "@test", + /* captureComments */ false + ); REQUIRE(!checkFailResult.success); - std::optional barTy = tryGetGlobalBinding(frontend, "bar"); + std::optional barTy = tryGetGlobalBinding(frontend.globals, "bar"); CHECK(!barTy.has_value()); } @@ -139,15 +149,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_classes") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile( + frontend.globals, + frontend.globals.globalScope, + R"( declare class A X: number X: string end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); CHECK_EQ(result.parseResult.errors.size(), 0); @@ -160,15 +175,20 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile( + frontend.globals, + frontend.globals.globalScope, + R"( type NotAClass = {} declare class Foo extends NotAClass end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); CHECK_EQ(result.parseResult.errors.size(), 0); @@ -181,16 +201,21 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile( + frontend.globals, + frontend.globals.globalScope, + R"( declare class Foo extends Bar end declare class Bar extends Foo end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); } @@ -228,10 +253,14 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_function_prop") declare class Foo X: (number) -> string end + + declare Foo: { + new: () -> Foo + } )"); CheckResult result = check(R"( - local x: Foo + local x: Foo = Foo.new() local prop = x.X )"); @@ -248,10 +277,14 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_class_function_args") y: (a: number, b: string) -> string end + + declare Foo: { + new: () -> Foo + } )"); CheckResult result = check(R"( - local x: Foo + local x: Foo = Foo.new() local methodRef1 = x.foo1 local methodRef2 = x.foo2 local prop = x.y @@ -281,16 +314,16 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") } )"); - std::optional xBinding = typeChecker.globalScope->linearSearchForBinding("x"); + std::optional xBinding = frontend.globals.globalScope->linearSearchForBinding("x"); REQUIRE(bool(xBinding)); // note: loadDefinition uses the @test package name. CHECK_EQ(xBinding->documentationSymbol, "@test/global/x"); - std::optional fooTy = typeChecker.globalScope->lookupType("Foo"); + std::optional fooTy = frontend.globals.globalScope->lookupType("Foo"); REQUIRE(bool(fooTy)); CHECK_EQ(fooTy->type->documentationSymbol, "@test/globaltype/Foo"); - std::optional barTy = typeChecker.globalScope->lookupType("Bar"); + std::optional barTy = frontend.globals.globalScope->lookupType("Bar"); REQUIRE(bool(barTy)); CHECK_EQ(barTy->type->documentationSymbol, "@test/globaltype/Bar"); @@ -299,7 +332,7 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") REQUIRE_EQ(barClass->props.count("prop"), 1); CHECK_EQ(barClass->props["prop"].documentationSymbol, "@test/globaltype/Bar.prop"); - std::optional yBinding = typeChecker.globalScope->linearSearchForBinding("y"); + std::optional yBinding = frontend.globals.globalScope->linearSearchForBinding("y"); REQUIRE(bool(yBinding)); CHECK_EQ(yBinding->documentationSymbol, "@test/global/y"); @@ -319,9 +352,25 @@ TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_re declare function myFunc(): MyClass )"); - std::optional myClassTy = typeChecker.globalScope->lookupType("MyClass"); + std::optional myClassTy = frontend.globals.globalScope->lookupType("MyClass"); REQUIRE(bool(myClassTy)); CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass"); + + ClassType* cls = getMutable(myClassTy->type); + REQUIRE(bool(cls)); + REQUIRE_EQ(cls->props.count("myMethod"), 1); + + const auto& method = cls->props["myMethod"]; + CHECK_EQ(method.documentationSymbol, "@test/globaltype/MyClass.myMethod"); + + FunctionType* function = getMutable(method.type()); + REQUIRE(function); + + REQUIRE(function->definition.has_value()); + CHECK(function->definition->definitionModuleName == "@test"); + CHECK(function->definition->definitionLocation == Location({2, 12}, {2, 35})); + CHECK(!function->definition->varargLocation.has_value()); + CHECK(function->definition->originalNameLocation == Location({2, 21}, {2, 29})); } TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types") @@ -330,7 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type export type Evil = string )"); - std::optional ty = typeChecker.globalScope->lookupType("Evil"); + std::optional ty = frontend.globals.globalScope->lookupType("Evil"); REQUIRE(bool(ty)); CHECK_EQ(ty->type->documentationSymbol, std::nullopt); } @@ -394,10 +443,40 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") CHECK_EQ(toString(requireType("y")), "string"); } + +TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") +{ + loadDefinition(R"( + declare class Foo + [number]: string + end + )"); + + CheckResult result = check(R"( + local x: Foo + local y = x[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const ClassType* ctv = get(requireType("x")); + REQUIRE(ctv != nullptr); + + REQUIRE(bool(ctv->indexer)); + + CHECK_EQ(*ctv->indexer->indexType, *builtinTypes->numberType); + CHECK_EQ(*ctv->indexer->indexResultType, *builtinTypes->stringType); + + CHECK_EQ(toString(requireType("y")), "string"); +} + TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile( + frontend.globals, + frontend.globals.globalScope, + R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () @@ -408,10 +487,33 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") Channel: Channel end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); + + REQUIRE(result.success); +} + +TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") +{ + LoadDefinitionFileResult result = loadDefinition(R"( + declare class Foo + end + )"); REQUIRE(result.success); + + CHECK_EQ(result.sourceModule.name, "@test"); + CHECK_EQ(result.sourceModule.humanReadableName, "@test"); + + std::optional fooTy = frontend.globals.globalScope->lookupType("Foo"); + REQUIRE(fooTy); + + const ClassType* ctv = get(fooTy->type); + + REQUIRE(ctv); + CHECK_EQ(ctv->definitionModuleName, "@test"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 7c2e451a6..3803df7c1 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -8,16 +8,54 @@ #include "Luau/Type.h" #include "Luau/VisitType.h" +#include "ClassFixture.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTINT(LuauTarjanChildLimit); TEST_SUITE_BEGIN("TypeInferFunctions"); +TEST_CASE_FIXTURE(Fixture, "general_case_table_literal_blocks") +{ + CheckResult result = check(R"( +--!strict +function f(x : {[any]: number}) + return x +end + +local Foo = {bar = "$$$"} + +f({[Foo.bar] = 0}) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overload_resolution") +{ + CheckResult result = check(R"( + type A = (number) -> string + type B = (string) -> number + + local function foo(f: A & B) + return f(1), f("five") + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId t = requireType("foo"); + const FunctionType* fooType = get(requireType("foo")); + REQUIRE(fooType != nullptr); + + CHECK(toString(t) == "(((number) -> string) & ((string) -> number)) -> (string, number)"); +} + TEST_CASE_FIXTURE(Fixture, "tc_function") { CheckResult result = check("function five() return 5 end"); @@ -29,13 +67,36 @@ TEST_CASE_FIXTURE(Fixture, "tc_function") TEST_CASE_FIXTURE(Fixture, "check_function_bodies") { - CheckResult result = check("function myFunction() local a = 0 a = true end"); + CheckResult result = check(R"( + function myFunction(): number + local a = 0 + a = true + return a + end + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.booleanType, - }})); + if (FFlag::LuauSolverV2) + { + const TypePackMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK(toString(tm->wantedTp) == "number"); + CHECK(toString(tm->givenTp) == "boolean"); + } + else + { + CHECK_EQ( + result.errors[0], + (TypeError{ + Location{Position{3, 16}, Position{3, 20}}, + TypeMismatch{ + builtinTypes->numberType, + builtinTypes->booleanType, + } + }) + ); + } } TEST_CASE_FIXTURE(Fixture, "cannot_hoist_interior_defns_into_signature") @@ -52,11 +113,17 @@ TEST_CASE_FIXTURE(Fixture, "cannot_hoist_interior_defns_into_signature") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(result.errors[0] == TypeError{Location{{1, 28}, {1, 29}}, getMainSourceModule()->name, - UnknownSymbol{ - "T", - UnknownSymbol::Context::Type, - }}); + CHECK( + result.errors[0] == + TypeError{ + Location{{1, 28}, {1, 29}}, + getMainSourceModule()->name, + UnknownSymbol{ + "T", + UnknownSymbol::Context::Type, + } + } + ); } TEST_CASE_FIXTURE(Fixture, "infer_return_type") @@ -70,7 +137,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") std::vector retVec = flatten(takeFiveType->retTypes).first; REQUIRE(!retVec.empty()); - REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); + REQUIRE_EQ(*follow(retVec[0]), *builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") @@ -78,7 +145,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") CheckResult result = check("function take_five() return 5 end local five = take_five()"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); + CHECK_EQ(*builtinTypes->numberType, *follow(requireType("five"))); } TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") @@ -92,7 +159,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{builtinTypes->numberType}})); } TEST_CASE_FIXTURE(Fixture, "generalize_table_property") @@ -111,7 +178,7 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_property") const TableType* tt = get(follow(t)); REQUIRE(tt); - TypeId fooTy = tt->props.at("foo").type; + TypeId fooTy = tt->props.at("foo").type(); CHECK("(a) -> a" == toString(fooTy)); } @@ -156,7 +223,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") REQUIRE(ttv); REQUIRE(ttv->props.count("f")); - TypeId k = ttv->props["f"].type; + TypeId k = ttv->props["f"].type(); REQUIRE(k); } @@ -169,14 +236,27 @@ TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_ LUAU_REQUIRE_ERROR_COUNT(2, result); - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + if (FFlag::LuauSolverV2) + { + GenericError* g = get(result.errors[0]); + REQUIRE(g); + CHECK(g->message == "None of the overloads for function that accept 1 arguments are compatible."); + } + else + { + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); + } ExtraInformation* ei = get(result.errors[1]); REQUIRE(ei); - CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); + + if (FFlag::LuauSolverV2) + CHECK("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number" == ei->message); + else + CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); } TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") @@ -208,8 +288,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argume TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") @@ -229,6 +309,9 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") TEST_CASE_FIXTURE(Fixture, "too_many_arguments") { + // This is not part of the new non-strict specification currently. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!nonstrict @@ -345,6 +428,22 @@ TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") LUAU_REQUIRE_NO_ERRORS(result); } +// We had a bug where we'd look up the type of a recursive call using the DFG, +// not the bindings tables. As a result, we would erroneously use the +// generalized type of foo() in this recursive fragment. This creates a +// constraint cycle that doesn't always work itself out. +// +// The fix is for the DFG node within the scope of foo() to retain the +// ungeneralized type of foo. +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_calls_must_refer_to_the_ungeneralized_type") +{ + CheckResult result = check(R"( + function foo() + string.format('%s: %s', "51", foo()) + end + )"); +} + TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") { CheckResult result = check(R"( @@ -379,13 +478,27 @@ TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") { - CheckResult result = check(R"( - local d - d:foo() - d:foo() - )"); + if (FFlag::LuauSolverV2) + { + CheckResult result = check(R"( + local function f(d) + d:foo() + d:foo() + end + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + CheckResult result = check(R"( + local d + d:foo() + d:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "local_function") @@ -488,6 +601,9 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") { + // This is not part of the spec for the new non-strict mode currently. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!nonstrict function foo(): number @@ -515,7 +631,7 @@ TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotat local i = 0 function most_of_the_natural_numbers(): number? if i < 10 then - i = i + 1 + i += 1 return i else return nil @@ -619,7 +735,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") REQUIRE_EQ(1, argVec.size()); const TableType* argType = get(follow(argVec[0])); - REQUIRE(argType != nullptr); + REQUIRE_MESSAGE(argType != nullptr, argVec[0]); CHECK(bool(argType->indexer)); } @@ -640,7 +756,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") end end - function mergesort(arr, comp) + function mergesort(arr: {T}, comp: (T, T) -> boolean) local work = {} for i = 1, #arr do work[i] = arr[i] @@ -704,7 +820,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mutual_recursion") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "toposort_doesnt_break_mutual_recursion") @@ -765,6 +880,9 @@ TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_pro TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") { + // new non-strict mode spec does not include this error yet. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!nonstrict @@ -846,15 +964,27 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); - - CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ - typeChecker.stringType, - typeChecker.numberType, - }})); + CHECK_EQ( + result.errors[0], + (TypeError{ + Location{Position{3, 12}, Position{3, 18}}, + TypeMismatch{ + builtinTypes->numberType, + builtinTypes->stringType, + } + }) + ); + + CHECK_EQ( + result.errors[1], + (TypeError{ + Location{Position{3, 20}, Position{3, 23}}, + TypeMismatch{ + builtinTypes->stringType, + builtinTypes->numberType, + } + }) + ); } TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_leak_free_types") @@ -877,11 +1007,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_lea opts.exhaustive = true; opts.maxTableLength = 0; - CHECK_EQ("{any}", toString(requireType("tab"), opts)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{string}", toString(requireType("tab"), opts)); + else + CHECK_EQ("{any}", toString(requireType("tab"), opts)); } TEST_CASE_FIXTURE(Fixture, "too_many_return_values") { + // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict @@ -903,6 +1039,9 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values") TEST_CASE_FIXTURE(Fixture, "too_many_return_values_in_parentheses") { + // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict @@ -924,6 +1063,9 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values_in_parentheses") TEST_CASE_FIXTURE(Fixture, "too_many_return_values_no_function") { + // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict @@ -964,13 +1106,25 @@ TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Return); - CHECK_EQ(acm->expected, 2); - CHECK_EQ(acm->actual, 1); + auto tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK("number, string" == toString(tpm->wantedTp)); + CHECK("number" == toString(tpm->givenTp)); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Return); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 1); + } } TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") @@ -992,13 +1146,19 @@ TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") REQUIRE(tm1); CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType)); + if (FFlag::LuauSolverV2) + CHECK_EQ("(unknown, unknown) -> number", toString(tm1->givenType)); + else + CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType)); auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType)); + if (FFlag::LuauSolverV2) + CHECK_EQ("(unknown, unknown) -> number", toString(tm1->givenType)); + else + CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType)); } TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") @@ -1015,8 +1175,11 @@ TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") LUAU_REQUIRE_NO_ERRORS(result); TypeId type = requireTypeAtPosition(Position(6, 14)); - CHECK_EQ("(tbl, number, number) -> number", toString(type)); - auto ftv = get(type); + if (FFlag::LuauSolverV2) + CHECK_EQ("(unknown, number, number) -> number", toString(type)); + else + CHECK_EQ("(tbl, number, number) -> number", toString(type)); + auto ftv = get(follow(type)); REQUIRE(ftv); CHECK(ftv->hasSelf); } @@ -1058,13 +1221,20 @@ TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") LUAU_REQUIRE_ERRORS(result); CHECK_EQ("string", toString(requireType("x"))); - CHECK_EQ("number", toString(requireType("y"))); + // the new solver does not currently "favor" arity-matching overloads when the call itself is ill-typed. + if (FFlag::LuauSolverV2) + CHECK_EQ("string", toString(requireType("y"))); + else + CHECK_EQ("number", toString(requireType("y"))); // Should this be string|number? CHECK_EQ("string", toString(requireType("z"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { + // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + // Simple direct arg to arg propagation CheckResult result = check(R"( type Table = { x: number, y: number } @@ -1120,21 +1290,28 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); + std::string expected; if (FFlag::LuauInstantiateInSubtyping) { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + expected = R"(Type + '(number, number, a) -> number' +could not be converted into + '(number, number) -> number' caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)"; } else { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + expected = R"(Type + '(number, number, *error-type*) -> number' +could not be converted into + '(number, number) -> number' caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)"; } + CHECK_EQ(expected, toString(result.errors[0])); + // Infer from variadic packs into elements result = check(R"( function f(a: (...number) -> number) return a(1, 2) end @@ -1172,110 +1349,99 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { - // Simple direct arg to arg propagation + // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(a) return a.x + a.y end) +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end +return sum(2, 3, function(a, b) return a + b end) )"); LUAU_REQUIRE_NO_ERRORS(result); - // An optional function is accepted, but since we already provide a function, nil can be ignored result = check(R"( -type Table = { x: number, y: number } -local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end -f(function(a) return a.x + a.y end) +local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end +local a = {1, 2, 3} +local r = map(a, function(a) return a + a > 100 end) )"); LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{boolean}", toString(requireType("r"))); - // Make sure self calls match correct index - result = check(R"( -type Table = { x: number, y: number } -local x = {} -x.b = {x = 1, y = 2} -function x:f(a: (Table) -> number) return a(self.b) end -x:f(function(a) return a.x + a.y end) + check(R"( +local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end +local a = {1, 2, 3} +local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) )"); LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); +} - // Mix inferred and explicit argument types - result = check(R"( -function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end -f(function(a: number, b, c) return c and a + b or b - a end) - )"); +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") +{ + // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; - LUAU_REQUIRE_NO_ERRORS(result); + CheckResult result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - // Anonymous function has a variadic pack - result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(...) return select(1, ...).z end) +local g12: typeof(g1) & typeof(g2) + +g12(1, function(x) return x + x end) +g12(1, 2, function(x, y) return x + y end) )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + LUAU_REQUIRE_NO_ERRORS(result); - // Can't accept more arguments than provided result = check(R"( -function f(a: (a: number, b: number) -> number) return a(1, 2) end -f(function(a, b, c, ...) return a + b end) - )"); - - LUAU_REQUIRE_ERRORS(result); +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - if (FFlag::LuauInstantiateInSubtyping) - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' -caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' -caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - } +local g12: typeof(g1) & typeof(g2) - // Infer from variadic packs into elements - result = check(R"( -function f(a: (...number) -> number) return a(1, 2) end -f(function(a, b) return a + b end) +g12({x=1}, function(x) return {x=-x.x} end) +g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) )"); LUAU_REQUIRE_NO_ERRORS(result); +} - // Infer from variadic packs into variadic packs - result = check(R"( -type Table = { x: number, y: number } -function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end -f(function(a, ...) local b = ... return b.z end) +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_lib_function_function_argument") +{ + CheckResult result = check(R"( +local a = {{x=4}, {x=7}, {x=1}} +table.sort(a, function(x, y) return x.x < y.x end) )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + LUAU_REQUIRE_NO_ERRORS(result); +} - // Return type inference - result = check(R"( -type Table = { x: number, y: number } -function f(a: (number) -> Table) return a(4) end -f(function(x) return x * 2 end) +TEST_CASE_FIXTURE(Fixture, "variadic_any_is_compatible_with_a_generic_TypePack") +{ + CheckResult result = check(R"( + --!strict + local function f(...) return ... end + local g = function(...) return f(...) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + LUAU_REQUIRE_NO_ERRORS(result); +} - // Return type doesn't inference 'nil' - result = check(R"( - function f(a: (number) -> nil) return a(4) end - f(function(x) print(x) end) +// https://github.com/luau-lang/luau/issues/767 +TEST_CASE_FIXTURE(BuiltinsFixture, "variadic_any_is_compatible_with_a_generic_TypePack_2") +{ + CheckResult result = check(R"( + local function somethingThatsAny(...: any) + print(...) + end + + local function x(...: T...) + somethingThatsAny(...) -- Failed to unify variadic type packs + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1283,6 +1449,9 @@ f(function(x) return x * 2 end) TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") { + // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type Table = { x: number, y: number } local f: (Table) -> number = function(t) return t.x + t.y end @@ -1314,11 +1483,18 @@ local function i(): ...{string|number} end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // `h` regresses in the new solver, the return type is not being pushed into the body. + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_ERROR_COUNT(1, result); + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") { + // FIXME: CLI-116111 test disabled until type path stringification is improved + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type A = (number, number) -> string type B = (number) -> string @@ -1328,13 +1504,20 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' + const std::string expected = R"(Type + '(number, number) -> string' +could not be converted into + '(number) -> string' caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") { + // FIXME: CLI-116111 test disabled until type path stringification is improved + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type A = (number, number) -> string type B = (number, string) -> string @@ -1344,13 +1527,21 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' + const std::string expected = R"(Type + '(number, number) -> string' +could not be converted into + '(number, string) -> string' caused by: - Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); + Argument #2 type is not compatible. +Type 'string' could not be converted into 'number')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") { + // FIXME: CLI-116111 test disabled until type path stringification is improved + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type A = (number, number) -> (number) type B = (number, number) -> (number, boolean) @@ -1360,13 +1551,20 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' + const std::string expected = R"(Type + '(number, number) -> number' +could not be converted into + '(number, number) -> (number, boolean)' caused by: - Function only returns 1 value, but 2 are required here)"); + Function only returns 1 value, but 2 are required here)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") { + // FIXME: CLI-116111 test disabled until type path stringification is improved + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type A = (number, number) -> string type B = (number, number) -> number @@ -1376,13 +1574,21 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' + const std::string expected = R"(Type + '(number, number) -> string' +could not be converted into + '(number, number) -> number' caused by: - Return type is not compatible. Type 'string' could not be converted into 'number')"); + Return type is not compatible. +Type 'string' could not be converted into 'number')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") { + // FIXME: CLI-116111 test disabled until type path stringification is improved + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type A = (number, number) -> (number, string) type B = (number, number) -> (number, boolean) @@ -1392,10 +1598,14 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' + const std::string expected = R"(Type + '(number, number) -> (number, string)' +could not be converted into + '(number, number) -> (number, boolean)' caused by: - Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); + Return #2 type is not compatible. +Type 'string' could not be converted into 'boolean')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_quantify_right_type") @@ -1458,9 +1668,24 @@ t.f = function(x) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'string' could not be converted into 'number')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ( + toString(result.errors[0]), + R"(Type function instance add depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time)" + ); + CHECK_EQ( + toString(result.errors[1]), + R"(Type function instance add depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time)" + ); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'string' could not be converted into 'number')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); + } } TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2") @@ -1485,12 +1710,16 @@ TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_th if (!result.errors.empty()) { for (const auto& e : result.errors) - printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + MESSAGE(e.moduleName << " " << toString(e.location) << ": " << toString(e)); } } TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time3") { + // This test regresses in the new solver, but is sort of nonsensical insofar as `foo` is known to be `nil`, so it's "right" to not be able to call + // it. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local foo @@ -1517,13 +1746,32 @@ t.f = function(x) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(string) -> string' could not be converted into '((number) -> number)?' + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + toString(result.errors[0]), + R"(Type function instance add depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this time)" + ); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type + '(string) -> string' +could not be converted into + '((number) -> number)?' caused by: - None of the union options are compatible. For example: Type '(string) -> string' could not be converted into '(number) -> number' + None of the union options are compatible. For example: +Type + '(string) -> string' +could not be converted into + '(number) -> number' caused by: - Argument #1 type is not compatible. Type 'number' could not be converted into 'string')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); + Argument #1 type is not compatible. +Type 'number' could not be converted into 'string')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); + } } TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") @@ -1538,6 +1786,9 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { + // FIXME: CLI-116122 bug where `t:b` does not check against the type from the indexer annotation on `t`. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local t: {[string]: () -> number} = {} @@ -1546,10 +1797,15 @@ function t:b() return 2 end -- not OK )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number' + CHECK_EQ( + R"(Type + '(*error-type*) -> number' +could not be converted into + '() -> number' caused by: Argument count mismatch. Function expects 1 argument, but none are specified)", - toString(result.errors[0])); + toString(result.errors[0]) + ); } TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") @@ -1575,6 +1831,9 @@ TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") { + // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 @@ -1600,6 +1859,9 @@ wrapper(test) TEST_CASE_FIXTURE(BuiltinsFixture, "too_few_arguments_variadic_generic2") { + // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 @@ -1638,6 +1900,10 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_unknown") { + // This test only makes sense for the old solver + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( local function foo(f: (unknown) -> (), x) f(x) @@ -1669,6 +1935,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_infer_parameter_types_for_functions_from_their_ LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("(a) -> a", toString(requireType("f"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("({ read p: { read q: unknown } }) -> ~(false?)?", toString(requireType("g"))); + else + CHECK_EQ("({+ p: {+ q: nil +} +}) -> nil", toString(requireType("g"))); } TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_calling_with_self") @@ -1713,24 +1983,43 @@ u.b().foo() )"); LUAU_REQUIRE_ERROR_COUNT(9, result); - CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified"); - CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified"); - CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified"); - CHECK_EQ(toString(result.errors[3]), "Argument count mismatch. Function 'string.find' expects 2 to 4 arguments, but none are specified"); - CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified"); - CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified"); - CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); - CHECK_EQ(toString(result.errors[7]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); - CHECK_EQ(toString(result.errors[8]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); + if (FFlag::LuauSolverV2) + { + // These improvements to the error messages are currently regressed in the new type solver. + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function expects 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function expects 1 to 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function expects 1 to 3 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[3]), "Argument count mismatch. Function expects 2 to 4 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function expects 2 to 3 arguments, but only 1 is specified"); + CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[7]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[8]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); + } + else + { + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[3]), "Argument count mismatch. Function 'string.find' expects 2 to 4 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified"); + CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[7]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[8]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); + } } // This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_error_nonstrict") { + // This behavior is not part of the current specification of the new type solver. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( ---!nonstrict -local function foo(a, b) end -foo(string.find("hello", "e")) + --!nonstrict + local function foo(a, b) end + foo(string.find("hello", "e")) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -1739,6 +2028,9 @@ foo(string.find("hello", "e")) TEST_CASE_FIXTURE(Fixture, "luau_subtyping_is_np_hard") { + // The case that _should_ succeed here (`z = x`) does not currently in the new solver. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict @@ -1771,16 +2063,15 @@ z = y -- Not OK, so the line is colorable )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") " - "-> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & " - "((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> " - "(\"blue\" | \"red\") -> false'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '(("blue" | "red") -> ("blue" | "red") -> ("blue" | "red") -> boolean) & (("blue" | "red") -> ("blue") -> ("blue") -> false) & (("blue" | "red") -> ("red") -> ("red") -> false) & (("blue") -> ("blue") -> ("blue" | "red") -> false) & (("red") -> ("red") -> ("blue" | "red") -> false)' +could not be converted into + '("blue" | "red") -> ("blue" | "red") -> ("blue" | "red") -> false'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1799,7 +2090,6 @@ TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1815,12 +2105,20 @@ TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK(6 == result.errors[0].location.begin.line); + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + CHECK("() -> ()" == toString(tm1->wantedType)); + CHECK("function" == toString(tm1->givenType)); + CHECK(7 == result.errors[1].location.begin.line); + auto tm2 = get(result.errors[1]); + REQUIRE(tm2); + CHECK("(T) -> T" == toString(tm2->wantedType)); + CHECK("function" == toString(tm2->givenType)); } TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1851,4 +2149,849 @@ end LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + ScopedFastInt sfi{FInt::LuauTarjanChildLimit, 1}; + + CheckResult result = check(R"( + function f(t) + t.x.y.z = 441 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_MESSAGE(get(result.errors[0]), "Expected UnificationTooComplex but got: " << toString(result.errors[0])); +} + +/* We had a bug under DCR where instantiated type packs had a nullptr scope. + * + * This caused an issue with promotion. + */ +TEST_CASE_FIXTURE(Fixture, "instantiated_type_packs_must_have_a_non_null_scope") +{ + CheckResult result = check(R"( + function pcall(...: (A...) -> R...): (boolean, R...) + return nil :: any + end + + type Dispatch = (A) -> () + + function mountReducer() + dispatchAction() + return nil :: any + end + + function dispatchAction() + end + + function useReducer(): Dispatch + local result, setResult = pcall(mountReducer) + return setResult + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inner_frees_become_generic_in_dcr") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + function f(x) + local z = x + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + std::optional ty = findTypeAtPosition(Position{3, 19}); + REQUIRE(ty); + CHECK(get(follow(*ty))); +} + +TEST_CASE_FIXTURE(Fixture, "function_exprs_are_generalized_at_signature_scope_not_enclosing") +{ + CheckResult result = check(R"( + local foo + local bar + + -- foo being a function expression is deliberate: the bug we're testing + -- only existed for function expressions, not for function statements. + foo = function(a) + return bar + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + CHECK(toString(requireType("foo")) == "((unknown) -> nil)?"); + else + { + // note that b is not in the generic list; it is free, the unconstrained type of `bar`. + CHECK(toString(requireType("foo")) == "(a) -> b"); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible") +{ + CheckResult result = check(R"( + local function foo(x: a, y: a?) + return x + end + local vec2 = { x = 5, y = 7 } + local ret: number = foo(vec2, { x = 5 }) + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto tm = get(result.errors[0]); + REQUIRE(tm); + CHECK("number" == toString(tm->wantedType)); + CHECK("{ x: number }" == toString(tm->givenType)); + } + else + { + // In the old solver, this produces a very strange result: + // + // Here, we instantiate `(x: a, y: a?) -> a` with a fresh type `'a` for `a`. + // In argument #1, we unify `vec2` with `'a`. + // This is ok, so we record an equality constraint `'a` with `vec2`. + // In argument #2, we unify `{ x: number }` with `'a?`. + // This fails because `'a` has equality constraint with `vec2`, + // so `{ x: number } <: vec2?`, which is false. + // + // If the unifications were to be committed, then it'd result in the following type error: + // + // Type '{ x: number }' could not be converted into 'vec2?' + // caused by: + // [...] Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y' + // + // However, whenever we check the argument list, if there's an error, we don't commit the unifications, so it actually looks like this: + // + // Type '{ x: number }' could not be converted into 'a?' + // caused by: + // [...] Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y' + // + // Then finally, that generic is left floating free, and since the function returns that generic, + // that free type is then later bound to `number`, which succeeds and mutates the type graph. + // This again changes the type error where `a` becomes bound to `number`. + // + // Type '{ x: number }' could not be converted into 'number?' + // caused by: + // [...] Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y' + // + // Uh oh, that type error is extremely confusing for people who doesn't know how that went down. + // Really, what should happen is we roll each argument incompatibility into a union type, but that needs local type inference. + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + const std::string expected = R"(Type '{ x: number }' could not be converted into 'vec2?' +caused by: + None of the union options are compatible. For example: +Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y')"; + CHECK_EQ(expected, toString(result.errors[0])); + CHECK_EQ("Type 'vec2' could not be converted into 'number'", toString(result.errors[1])); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2") +{ + CheckResult result = check(R"( + local function f(x: a, y: a): a + return if math.random() > 0.5 then x else y + end + + local z: boolean = f(5, "five") + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto tm = get(result.errors[0]); + REQUIRE(tm); + CHECK("boolean" == toString(tm->wantedType)); + CHECK("number | string" == toString(tm->givenType)); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'boolean'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "attempt_to_call_an_intersection_of_tables") +{ + CheckResult result = check(R"( + local function f(t: { x: number } & { y: string }) + t() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(result.errors[0]), "Cannot call a value of type { x: number } & { y: string }"); + else + CHECK_EQ(toString(result.errors[0]), "Cannot call a value of type {| x: number |}"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "attempt_to_call_an_intersection_of_tables_with_call_metamethod") +{ + CheckResult result = check(R"( + type Callable = typeof(setmetatable({}, { + __call = function(self, ...) return ... end + })) + + local function f(t: Callable & { x: number }) + t() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_packs_are_not_variadic") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local function apply(f: (a, b...) -> c..., x: a) + return f(x) + end + + local function add(x: number, y: number) + return x + y + end + + apply(add, 5) + )"); + + // FIXME: this errored at some point, but doesn't anymore. + // the desired behavior here is erroring. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_before_num_or_str") +{ + CheckResult result = check(R"( + function num() + return 5 + end + + local function num_or_str() + if math.random() > 0.5 then + return num() + else + return "some string" + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == "Type pack 'string' could not be converted into 'number'; at [0], string is not a subtype of number"); + else + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("() -> number", toString(requireType("num_or_str"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_after_num_or_str") +{ + CheckResult result = check(R"( + local function num_or_str() + if math.random() > 0.5 then + return num() + else + return "some string" + end + end + + function num() + return 5 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == "Type pack 'string' could not be converted into 'number'; at [0], string is not a subtype of number"); + else + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + CHECK_EQ("() -> number", toString(requireType("num_or_str"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "apply_of_lambda_with_inferred_and_explicit_types") +{ + CheckResult result = check(R"( + local function apply(f, x) return f(x) end + local x = apply(function(x: string): number return 5 end, "hello!") + + local function apply_explicit(f: (A) -> B..., x: A): B... return f(x) end + local x = apply_explicit(function(x: string): number return 5 end, "hello!") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "regex_benchmark_string_format_minimization") +{ + CheckResult result = check(R"( + (nil :: any)(function(n) + if tonumber(n) then + n = tonumber(n) + elseif n ~= nil then + string.format("invalid argument #4 to 'sub': number expected, got %s", typeof(n)) + end + end); + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "subgeneric_type_function_super_monomorphic") +{ + CheckResult result = check(R"( +local a: (number, number) -> number = function(a, b) return a - b end + +a = function(a, b) return a + b end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "simple_unannotated_mutual_recursion") +{ + // CLI-117118 - TypeInferFunctions.simple_unannotated_mutual_recursion relies on unstable assertions to pass. + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( +function even(n) + if n == 0 then + return true + else + return odd(n - 1) + end +end + +function odd(n) + if n == 0 then + return false + elseif n == 1 then + return true + else + return even(n - 1) + end +end +)"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(5, result); + // CLI-117117 Constraint solving is incomplete inTypeInferFunctions.simple_unannotated_mutual_recursion + CHECK(get(result.errors[0])); + // This check is unstable between different machines and different runs of DCR because it depends on string equality between + // blocked type numbers, which is not guaranteed. + bool r = toString(result.errors[1]) == "Type pack '*blocked-tp-1*' could not be converted into 'boolean'; type *blocked-tp-1*.tail() " + "(*blocked-tp-1*) is not a subtype of boolean (boolean)"; + CHECK(r); + CHECK( + toString(result.errors[2]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub" + ); + CHECK( + toString(result.errors[3]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub" + ); + CHECK( + toString(result.errors[4]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub" + ); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Unknown type used in - operation; consider adding a type annotation to 'n'"); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "simple_lightly_annotated_mutual_recursion") +{ + CheckResult result = check(R"( +function even(n: number) + if n == 0 then + return true + else + return odd(n - 1) + end +end + +function odd(n: number) + if n == 0 then + return false + elseif n == 1 then + return true + else + return even(n - 1) + end +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(number) -> boolean", toString(requireType("even"))); + CHECK_EQ("(number) -> boolean", toString(requireType("odd"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") +{ + if (!FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( +function fib(n) + return n < 2 and 1 or fib(n-1) + fib(n-2) +end +)"); + + LUAU_REQUIRE_ERRORS(result); + auto err = get(result.errors.back()); + LUAU_ASSERT(err); + CHECK("number" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("number" == toString(err->recommendedArgs[0].second)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") +{ + if (!FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( +function fib(n, u) + return (n or u) and (n < u and n + fib(n,u)) +end +)"); + + LUAU_REQUIRE_ERRORS(result); + auto err = get(result.errors.back()); + LUAU_ASSERT(err); + CHECK("number" == toString(err->recommendedReturn)); + REQUIRE(err->recommendedArgs.size() == 2); + CHECK("number" == toString(err->recommendedArgs[0].second)); + CHECK("number" == toString(err->recommendedArgs[1].second)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type_2") +{ + if (!FFlag::LuauSolverV2) + return; + + // Make sure the error types are cloned to module interface + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( + local function escape_fslash(pre) + return (#pre % 2 == 0 and '\\' or '') .. pre .. '.' + end + )"); + + LUAU_REQUIRE_ERRORS(result); + auto err = get(result.errors.back()); + REQUIRE(err); + CHECK("a" == toString(err->ty)); +} + +TEST_CASE_FIXTURE(Fixture, "local_function_fwd_decl_doesnt_crash") +{ + CheckResult result = check(R"( + local foo + + local function bar() + foo() + end + + function foo() + end + + bar() + )"); + + // This test verifies that an ICE doesn't occur, so the bulk of the test is + // just from running check above. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_callback_property") +{ + CheckResult result = check(R"( + function print(x: number) end + + type Point = {x: number, y: number} + local T : {callback: ((Point) -> ())?} = {} + + T.callback = function(p) -- No error here + print(p.z) -- error here. Point has no property z + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauSolverV2) + { + auto tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK("((Point) -> ())?" == toString(tm->wantedType)); + CHECK("({ read z: number }) -> ()" == toString(tm->givenType)); + + Location location = result.errors[0].location; + CHECK(location.begin.line == 6); + CHECK(location.end.line == 8); + } + else + { + CHECK_MESSAGE(get(result.errors[0]), "Expected UnknownProperty but got " << result.errors[0]); + + Location location = result.errors[0].location; + CHECK(location.begin.line == 7); + CHECK(location.end.line == 7); + } +} + +TEST_CASE_FIXTURE(ClassFixture, "bidirectional_inference_of_class_methods") +{ + CheckResult result = check(R"( + local c = ChildClass.New() + + -- Instead of reporting that the lambda is the wrong type, report that we are using its argument improperly. + c.Touched:Connect(function(other) + print(other.ThisDoesNotExist) + end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* err = get(result.errors[0]); + REQUIRE(err); + + CHECK("ThisDoesNotExist" == err->key); + CHECK("BaseClass" == toString(err->table)); +} + +TEST_CASE_FIXTURE(Fixture, "pass_table_literal_to_function_expecting_optional_prop") +{ + CheckResult result = check(R"( + type T = {prop: number?} + + function f(t: T) end + + f({prop=5}) + f({}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_infer_overloaded_functions") +{ + CheckResult result = check(R"( + function getR6Attachments(model) + model:FindFirstChild("Right Leg") + model:FindFirstChild("Left Leg") + model:FindFirstChild("Torso") + model:FindFirstChild("Torso") + model:FindFirstChild("Head") + model:FindFirstChild("Left Arm") + model:FindFirstChild("Right Arm") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + CHECK("(t1) -> () where t1 = { read FindFirstChild: (t1, string) -> (...unknown) }" == toString(requireType("getR6Attachments"))); + else + CHECK("(t1) -> () where t1 = {+ FindFirstChild: (t1, string) -> (a...) +}" == toString(requireType("getR6Attachments"))); +} + +TEST_CASE_FIXTURE(Fixture, "param_y_is_bounded_by_x_of_type_string") +{ + CheckResult result = check(R"( + local function f(x: string, y) + x = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(string, string) -> ()" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_that_could_return_anything_is_compatible_with_function_that_is_expected_to_return_nothing") +{ + CheckResult result = check(R"( + -- We infer foo : (g: (number) -> (...unknown)) -> () + function foo(g) + g(0) + end + + -- a requires a function that returns no values + function a(f: ((number) -> ()) -> ()) + end + + -- "Returns an unknown number of values" is close enough to "returns no values." + a(foo) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "self_application_does_not_segfault") +{ + (void)check(R"( + function f(a) + f(f) + return f(), a + end + )"); + + // We only care that type checking completes without tripping a crash or an assertion. +} + +TEST_CASE_FIXTURE(Fixture, "function_definition_in_a_do_block") +{ + CheckResult result = check(R"( + local f + do + function f() + end + end + f() + )"); + + // We are predominantly interested in this test not crashing. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_definition_in_a_do_block_with_global") +{ + CheckResult result = check(R"( + function f() print("a") end + do + function f() + print("b") + end + end + f() + )"); + + // We are predominantly interested in this test not crashing. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzzer_alias_global_function_doesnt_hit_nil_assert") +{ + CheckResult result = check(R"( +function _() +end +local function l0() + function _() + end +end +_ = _ +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzzer_bug_missing_follow_causes_assertion") +{ + CheckResult result = check(R"( +local _ = ({_=function() +return _ +end,}),true,_[_()] +for l0=_[_[_[`{function(l0) +end}`]]],_[_.n6[_[_.n6]]],_[_[_.n6[_[_.n6]]]] do +_ += if _ then "" +end +return _ +)"); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_call_union_of_functions") +{ + CheckResult result = check(R"( + local f: (() -> ()) | (() -> () -> ()) = nil :: any + f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + std::string expected = R"(Cannot call a value of the union type: + | () -> () + | () -> () -> () +We are unable to determine the appropriate result type for such a call.)"; + + CHECK(expected == toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun") +{ + (void)check(R"( + local _ = function() + end ~= _ + + while (_) do + _,_,_,_,_,_,_,_,_,_._,_ = nil + function _(...):()->() + end + function _(...):any + _ ..= ... + end + _,_,_,_,_,_,_,_,_,_,_ = nil + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") +{ + CheckResult result = check(R"( + function foo(player) + local success,result = player:thing() + if(success) then + return "Successfully posted message."; + elseif(not result) then + return false; + else + return result; + end + end + )"); + + if (FFlag::LuauSolverV2) + { + // The new solver should ideally be able to do better here, but this is no worse than the old solver. + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + CHECK(toString(tm1->wantedTp) == "string"); + CHECK(toString(tm1->givenTp) == "boolean"); + + auto tm2 = get(result.errors[1]); + REQUIRE(tm2); + CHECK(toString(tm2->wantedTp) == "string"); + CHECK(toString(tm2->givenTp) == "buffer | class | function | number | string | table | thread | true"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK(toString(tm->wantedType) == "string"); + CHECK(toString(tm->givenType) == "boolean"); + } +} + +TEST_CASE_FIXTURE(Fixture, "captured_local_is_assigned_a_function") +{ + CheckResult result = check(R"( + local f + + local function g() + f() + end + + function f() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "error_suppression_propagates_through_function_calls") +{ + CheckResult result = check(R"( + function first(x: any) + return pairs(x)(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(any) -> (any?, any)" == toString(requireType("first"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_normalizer_out_of_resources") +{ + // This luau code should finish typechecking, not segfault upon dereferencing + // the normalized type + CheckResult result = check(R"( + Module 'l0': +local _ = true,...,_ +if ... then +while _:_(_._G) do +do end +_ = _ and _ +_ = 0 and {# _,} +local _ = "CCCCCCCCCCCCCCCCCCCCCCCCCCC" +local l0 = require(module0) +end +local function l0() +end +elseif _ then +l0 = _ +end +do end +while _ do +_ = if _ then _ elseif _ then _,if _ then _ else _ +_ = _() +do end +do end +if _ then +end +end +_ = _,{} + + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "overload_resolution_crash_when_argExprs_is_smaller_than_type_args") +{ + CheckResult result = check(R"( +--!strict +local parseError +type Set = {[T]: any} +local function captureDependencies( + saveToSet: Set, + callback: (...any) -> any, + ... +) + local data = table.pack(xpcall(callback, parseError, ...)) + end +)"); +} + +TEST_CASE_FIXTURE(Fixture, "unpack_depends_on_rhs_pack_to_be_fully_resolved") +{ + CheckResult result = check(R"( +--!strict +local function id(x) + return x +end +local u,v = id(3), id(id(44)) +)"); + + CHECK_EQ(builtinTypes->numberType, requireType("v")); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 0ba889c89..a84a0206c 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -7,10 +7,11 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" -LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(LuauSolverV2); using namespace Luau; @@ -26,6 +27,8 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_function") local y: number = id(37) )"); LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(builtinTypes->stringType, requireType("x")); + CHECK_EQ(builtinTypes->numberType, requireType("y")); } TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") @@ -38,6 +41,40 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") local y: number = id(37) )"); LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(builtinTypes->stringType, requireType("x")); + CHECK_EQ(builtinTypes->numberType, requireType("y")); +} + +TEST_CASE_FIXTURE(Fixture, "check_generic_local_function2") +{ + CheckResult result = check(R"( + local function id(x:a): a + return x + end + local x = id("hi") + local y = id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(builtinTypes->stringType, requireType("x")); + CHECK_EQ(builtinTypes->numberType, requireType("y")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "unions_and_generics") +{ + CheckResult result = check(R"( + type foo = (T | {T}) -> T + local foo = (nil :: any) :: foo + + type Test = number | {number} + local res = foo(1 :: Test) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + CHECK_EQ("number | {number}", toString(requireType("res"))); + else // in the old solver, this just totally falls apart + CHECK_EQ("a", toString(requireType("res"))); } TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") @@ -106,6 +143,8 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local t: { m: (number)->number } = { m = function(x:number) return x+1 end } local function id(x:a):a return x end @@ -143,12 +182,12 @@ TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") { CheckResult result = check(R"( - local id2 - local function id1(x:a):a + function id1(x:a):a local y: string = id2("hi") local z: number = id2(37) return x end + function id2(x:a):a local y: string = id1("hi") local z: number = id1(37) @@ -158,8 +197,72 @@ TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") +TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions_unannotated") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + function id1(x) + local y: string = id2("hi") + local z: number = id2(37) + return x + end + + function id2(x) + local y: string = id1("hi") + local z: number = id1(37) + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions_errors") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + function id1(x) + local y: string = id2(37) -- odd + local z: number = id2("hi") -- even + return x + end + + function id2(x) + local y: string = id1(37) -- odd + local z: number = id1("hi") -- even + return x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + + // odd errors + for (int i = 0; i < 4; i += 2) + { + TypeMismatch* tm = get(result.errors[i]); + REQUIRE(tm); + CHECK_EQ("string", toString(tm->wantedType)); + CHECK_EQ("number", toString(tm->givenType)); + } + + // even errors + for (int i = 1; i < 4; i += 2) + { + TypeMismatch* tm = get(result.errors[i]); + REQUIRE(tm); + CHECK_EQ("number", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + } +} + +TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types_old_solver") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type T = { id: (a) -> a } local x: T = { id = function(x:a):a return x end } @@ -169,8 +272,23 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types_new_solver") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type T = { read id: (a) -> a } + local x: T = { id = function(x:a):a return x end } + local y: string = x.id("hi") + local z: number = x.id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "generic_factories") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -192,6 +310,8 @@ TEST_CASE_FIXTURE(Fixture, "generic_factories") TEST_CASE_FIXTURE(Fixture, "factories_of_generics") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -272,26 +392,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") -{ - ScopedFastFlag sff{"DebugLuauSharedSelf", true}; - - CheckResult result = check(R"( - local x = {} - function x:id(x) return x end - function x:f(): string return self:id("hello") end - function x:g(): number return self:id(37) end - )"); - - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_NO_ERRORS(result); - else - { - // TODO: Quantification should be doing the conversion, not normalization. - LUAU_REQUIRE_ERRORS(result); - } -} - TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") { CheckResult result = check(R"( @@ -303,7 +403,6 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); - // TODO: Should typecheck but currently errors CLI-54277 LUAU_REQUIRE_ERRORS(result); } @@ -367,7 +466,15 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") -- so this assignment should fail local b: boolean = f(true) )"); - LUAU_REQUIRE_ERRORS(result); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") @@ -383,7 +490,14 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") local y: number = id(37) end )"); - LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") @@ -533,7 +647,12 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") function f(...: a...): any return (...) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + // This should really error, but the error from the old solver is wrong. + // `a...` is a generic type pack, and we don't know that it will be non-empty, thus this code may not work. + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") @@ -548,13 +667,27 @@ TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - SwappedGenericTypeParameter* fErr = get(result.errors[0]); + SwappedGenericTypeParameter* fErr; + SwappedGenericTypeParameter* gErr; + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(3, result); + // The first error here is an unknown symbol that is redundant with the `fErr`. + fErr = get(result.errors[1]); + gErr = get(result.errors[2]); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + fErr = get(result.errors[0]); + gErr = get(result.errors[1]); + } + REQUIRE(fErr); CHECK_EQ(fErr->name, "T"); CHECK_EQ(fErr->kind, SwappedGenericTypeParameter::Pack); - SwappedGenericTypeParameter* gErr = get(result.errors[1]); REQUIRE(gErr); CHECK_EQ(gErr->name, "T"); CHECK_EQ(gErr->kind, SwappedGenericTypeParameter::Type); @@ -640,17 +773,19 @@ return exports LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") +TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names_old_solver") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -local function f(a: T, ...: U...) end + local function f(a: T, ...: U...) end -f(1, 2, 3) + f(1, 2, 3) )"); LUAU_REQUIRE_NO_ERRORS(result); - auto ty = findTypeAtPosition(Position(3, 0)); + auto ty = findTypeAtPosition(Position(3, 8)); REQUIRE(ty); ToStringOptions opts; opts.functionTypeArguments = true; @@ -659,6 +794,8 @@ f(1, 2, 3) TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_types") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type C = () -> () type D = () -> () @@ -674,6 +811,8 @@ local d: D = c TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type C = () -> () type D = () -> () @@ -684,14 +823,16 @@ local d: D = c LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); + CHECK_EQ( + toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)" + ); } TEST_CASE_FIXTURE(BuiltinsFixture, "generic_functions_dont_cache_type_parameters") { CheckResult result = check(R"( --- See https://github.com/Roblox/luau/issues/332 +-- See https://github.com/luau-lang/luau/issues/332 -- This function has a type parameter with the same name as clones, -- so if we cache type parameter names for functions these get confused. -- function id(x : Z) : Z @@ -725,28 +866,29 @@ y.a.c = y )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - { - CHECK_EQ(toString(result.errors[0]), - R"(Type 'y' could not be converted into 'T' -caused by: - Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' -caused by: - Property 'd' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); - } + + if (FFlag::LuauSolverV2) + CHECK( + toString(result.errors.at(0)) == + R"(Type '{ a: { c: nil, d: number }, b: number }' could not be converted into 'T'; type { a: { c: nil, d: number }, b: number }[read "a"][read "c"] (nil) is not exactly T[read "a"][read "c"][0] (T))" + ); else { - CHECK_EQ(toString(result.errors[0]), - R"(Type 'y' could not be converted into 'T' + const std::string expected = R"(Type 'y' could not be converted into 'T' caused by: - Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' + Property 'a' is not compatible. +Type '{ c: T?, d: number }' could not be converted into 'U' caused by: - Property 'd' is not compatible. Type 'number' could not be converted into 'string')"); + Property 'd' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } } TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict type Dispatcher = { @@ -765,6 +907,8 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict type Dispatcher = { @@ -783,6 +927,8 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict type Dispatcher = { @@ -801,6 +947,8 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function test(a: number) return 1 @@ -818,6 +966,8 @@ wrapper(test) TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function test2(a: number, b: string) return 1 @@ -844,8 +994,8 @@ TEST_CASE_FIXTURE(Fixture, "generic_function") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("(a) -> a", toString(requireType("id"))); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.nilType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->nilType, *requireType("b")); } TEST_CASE_FIXTURE(Fixture, "generic_table_method") @@ -865,7 +1015,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") REQUIRE(tTable != nullptr); REQUIRE(tTable->props.count("bar")); - TypeId barType = tTable->props["bar"].type; + TypeId barType = tTable->props["bar"].type(); REQUIRE(barType != nullptr); const FunctionType* ftv = get(follow(barType)); @@ -874,7 +1024,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") std::vector args = flatten(ftv->argTypes).first; TypeId argType = args.at(1); - CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); } TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") @@ -900,7 +1050,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") std::optional fooProp = get(t->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* foo = get(follow(fooProp->type)); + const FunctionType* foo = get(follow(fooProp->type())); REQUIRE(bool(foo)); std::optional ret_ = first(foo->retTypes); @@ -942,12 +1092,12 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") TypeId arg = follow(*optionArg); const TableType* argTable = get(arg); - REQUIRE(argTable != nullptr); + REQUIRE_MESSAGE(argTable != nullptr, "Expected table but got " << toString(arg)); std::optional methodProp = get(argTable->props, "method"); REQUIRE(bool(methodProp)); - const FunctionType* methodFunction = get(methodProp->type); + const FunctionType* methodFunction = get(follow(methodProp->type())); REQUIRE(methodFunction != nullptr); std::optional methodArg = first(methodFunction->argTypes); @@ -975,7 +1125,11 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); - if (FFlag::LuauInstantiateInSubtyping) + // The new solver does not attempt to instantiate generics here, so if + // either the instantiate in subtyping flag _or_ the new solver flags + // are set, assert that we're getting back the original generic + // function definition. + if (FFlag::LuauInstantiateInSubtyping || FFlag::LuauSolverV2) CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); else CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); @@ -998,7 +1152,11 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); - if (FFlag::LuauInstantiateInSubtyping) + // The new solver does not attempt to instantiate generics here, so if + // either the instantiate in subtyping flag _or_ the new solver flags + // are set, assert that we're getting back the original generic + // function definition. + if (FFlag::LuauInstantiateInSubtyping || FFlag::LuauSolverV2) CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); else CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); @@ -1014,7 +1172,10 @@ local a: Self )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "Table"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireType("a")), "Table
"); + else + CHECK_EQ(toString(requireType("a")), "Table"); } TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") @@ -1030,30 +1191,54 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*error-type*", toString(*t0)); - - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { - return get(err); - }); + if (FFlag::LuauSolverV2) + CHECK_EQ("any", toString(*t0)); + else + CHECK_EQ("*error-type*", toString(*t0)); + + auto it = std::find_if( + result.errors.begin(), + result.errors.end(), + [](TypeError& err) + { + return get(err); + } + ); CHECK(it != result.errors.end()); } TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { - CheckResult result = check(R"( - local function sum(x: a, y: a, f: (a, a) -> a) - return f(x, y) - end - return sum(2, 3, function(a, b) return a + b end) - )"); - LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + { + CheckResult result = check(R"( + local function sum(x: a, y: a, f: (a, a) -> add) + return f(x, y) + end + return sum(2, 3, function(a: T, b: T): add return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + CheckResult result = check(R"( + local function sum(x: a, y: a, f: (a, a) -> a) + return f(x, y) + end + return sum(2, 3, function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument_2") { CheckResult result = check(R"( - local function map(arr: {a}, f: (a) -> b) + local function map(arr: {a}, f: (a) -> b): {b} local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) @@ -1061,7 +1246,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument_2") return r end local a = {1, 2, 3} - local r = map(a, function(a) return a + a > 100 end) + local r = map(a, function(a: number) return a + a > 100 end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1079,11 +1264,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument_3") return r end local a = {1, 2, 3} - local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + local r = foldl(a, {s=0,c=0}, function(a: {s: number, c: number}, b: number) return {s = a.s + b, c = a.c + 1} end) )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); + if (FFlag::LuauSolverV2) + REQUIRE_EQ("{ c: number, s: number } | { c: number, s: number }", toString(requireType("r"))); + else + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") @@ -1117,22 +1305,44 @@ table.sort(a, function(x, y) return x.x < y.x end) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_infer_generic_functions") { - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -local function sumrec(f: typeof(sum)) - return sum(2, 3, function(a, b) return a + b end) -end + CheckResult result; -local b = sumrec(sum) -- ok -local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred - )"); + if (FFlag::LuauSolverV2) + { + result = check(R"( + local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end - LUAU_REQUIRE_NO_ERRORS(result); + local function sumrec(f: typeof(sum)) + return sum(2, 3, function(a: T, b: T): add return a + b end) + end + + local b = sumrec(sum) -- ok + local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + result = check(R"( + local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end + + local function sumrec(f: typeof(sum)) + return sum(2, 3, function(a, b) return a + b end) + end + + local b = sumrec(sum) -- ok + local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } } + TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { CheckResult result = check(R"( @@ -1149,7 +1359,9 @@ TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics1") { - // https://github.com/Roblox/luau/issues/484 + // CLI-114507: temporarily changed to have a cast for `object` to silence false positive error + + // https://github.com/luau-lang/luau/issues/484 CheckResult result = check(R"( --!strict type MyObject = { @@ -1159,7 +1371,7 @@ local object: MyObject = { getReturnValue = function(cb: () -> U): U return cb() end, -} +} :: MyObject type ComplexObject = { id: T, @@ -1177,7 +1389,7 @@ local complex: ComplexObject = { TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics2") { - // https://github.com/Roblox/luau/issues/484 + // https://github.com/luau-lang/luau/issues/484 CheckResult result = check(R"( --!strict type MyObject = { @@ -1188,20 +1400,45 @@ type ComplexObject = { nested: MyObject } -local complex2: ComplexObject = nil +function f(complex: ComplexObject) + local x = complex.nested.getReturnValue(function(): string + return "" + end) + + local y = complex.nested.getReturnValue(function() + return 3 + end) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} -local x = complex2.nested.getReturnValue(function(): string - return "" -end) +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics3") +{ + // This minimization was useful for debugging a particular issue with + // cyclic types under local type inference. -local y = complex2.nested.getReturnValue(function() - return 3 -end) + CheckResult result = check(R"( + local getReturnValue: (cb: () -> V) -> V = nil :: any + + local y = getReturnValue(function() return nil :: any end) )"); LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "quantify_functions_with_no_generics") +{ + CheckResult result = check(R"( + function foo(f, x) + return f(x) + end + )"); + + CHECK("((a) -> (b...), a) -> (b...)" == toString(requireType("foo"))); +} + TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_generic") { CheckResult result = check(R"( @@ -1213,8 +1450,21 @@ TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_gen CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "no_extra_quantification_for_generic_functions") +{ + CheckResult result = check(R"( + function foo(f : (X) -> Y, x: X) + return f(x) + end + )"); + + CHECK("((X) -> Y, X) -> Y" == toString(requireType("foo"))); +} + TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict type Array = { [number]: T } @@ -1249,7 +1499,8 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "higher_rank_polymorphism_should_not_accept_instantiated_arguments") { ScopedFastFlag sffs[] = { - {"LuauInstantiateInSubtyping", true}, + {FFlag::LuauSolverV2, false}, + {FFlag::LuauInstantiateInSubtyping, true}, }; CheckResult result = check(R"( @@ -1288,4 +1539,87 @@ TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_and_generalization_play_nice" CHECK("string" == toString(requireType("b"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_no_cyclic_intersections") +{ + CheckResult result = check(R"( + local f, t, n = pairs({"foo"}) + local k, v = f(t) + )"); + + CHECK("({string}, number?) -> (number?, string)" == toString(requireType("f"))); + CHECK("{string}" == toString(requireType("t"))); + CHECK("number?" == toString(requireType("k"))); + CHECK("string" == toString(requireType("v"))); +} + +TEST_CASE_FIXTURE(Fixture, "missing_generic_type_parameter") +{ + CheckResult result = check(R"( + function f(x: T): T return x end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + REQUIRE(get(result.errors[0])); + REQUIRE(get(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "generic_implicit_explicit_name_clash") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + auto result = check(R"( + function apply(func, argument: a) + return func(argument) + end + )"); + + CHECK("((a) -> (b...), a) -> (b...)" == toString(requireType("apply"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_functions_work_in_subtyping") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local function addOne(x: T): add return x + 1 end + + local function six(): number + return addOne(5) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_subtyping_nested_bounds_with_new_mappings") +{ + // Test shows how going over mapped generics in a subtyping check can generate more mapped generics when making a subtyping check between bounds. + // It has previously caused iterator invalidation in the new solver, but this specific test doesn't trigger a UAF, only shows an example. + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( +type Dispatch = (A) -> () +type BasicStateAction = ((S) -> S) | S + +function updateReducer(reducer: (S, A) -> S, initialArg: I, init: ((I) -> S)?): (S, Dispatch) + return 1 :: any +end + +function basicStateReducer(state: S, action: BasicStateAction): S + return action +end + +function updateState(initialState: (() -> S) | S): (S, Dispatch>) + return updateReducer(basicStateReducer, initialState) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ea6fff773..50e285050 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -4,10 +4,12 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2); TEST_SUITE_BEGIN("IntersectionTypes"); @@ -16,14 +18,15 @@ TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") CheckResult result = check(R"( type A = (number) -> (string) type B = (string) -> (number) - local f:A & B - local b = f(10) -- b is a string - local c = f("a") -- c is a number + + local function foo(f: A & B) + return f(10), f("a") + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("b"), typeChecker.stringType); - CHECK_EQ(requireType("c"), typeChecker.numberType); + + CHECK_EQ("(((number) -> string) & ((string) -> number)) -> (string, number)", toString(requireType("foo"))); } TEST_CASE_FIXTURE(Fixture, "table_combines") @@ -31,6 +34,7 @@ TEST_CASE_FIXTURE(Fixture, "table_combines") CheckResult result = check(R"( type A={a:number} type B={b:string} + local c:A & B = {a=10, b="s"} )"); @@ -42,6 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "table_combines_missing") CheckResult result = check(R"( type A={a:number} type B={b:string} + local c:A & B = {a=10} )"); @@ -62,8 +67,10 @@ TEST_CASE_FIXTURE(Fixture, "table_extra_ok") CheckResult result = check(R"( type A={a:number} type B={b:string} - local c:A & B - local d:A = c + + local function f(t: A & B): A + return t + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -75,9 +82,10 @@ TEST_CASE_FIXTURE(Fixture, "fx_intersection_as_argument") type A = (number) -> (string) type B = (string) -> (number) type C = (A) -> (number) - local f:A & B - local g:C - local b = g(f) + + local function foo(f: A & B, g: C) + return g(f) + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -89,9 +97,10 @@ TEST_CASE_FIXTURE(Fixture, "fx_union_as_argument_fails") type A = (number) -> (string) type B = (string) -> (number) type C = (A) -> (number) - local f:A | B - local g:C - local b = g(f) + + local function foo(f: A | B, g: C) + return g(f) + end )"); REQUIRE(!result.errors.empty()); @@ -101,10 +110,11 @@ TEST_CASE_FIXTURE(Fixture, "argument_is_intersection") { CheckResult result = check(R"( type A = (number | boolean) -> number - local f: A - f(5) - f(true) + local function foo(f: A) + f(5) + f(true) + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -113,59 +123,40 @@ TEST_CASE_FIXTURE(Fixture, "argument_is_intersection") TEST_CASE_FIXTURE(Fixture, "should_still_pick_an_overload_whose_arguments_are_unions") { CheckResult result = check(R"( - type A = (number | boolean) -> number - type B = (string | nil) -> string - local f: A & B + type A = (number) -> string + type B = (string) -> number - local a1, a2 = f(1), f(true) - local b1, b2 = f("foo"), f(nil) + local function foo(f: A & B) + return f(1), f("five") + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("a1"), *typeChecker.numberType); - CHECK_EQ(*requireType("a2"), *typeChecker.numberType); - - CHECK_EQ(*requireType("b1"), *typeChecker.stringType); - CHECK_EQ(*requireType("b2"), *typeChecker.stringType); + CHECK_EQ("(((number) -> string) & ((string) -> number)) -> (string, number)", toString(requireType("foo"))); } TEST_CASE_FIXTURE(Fixture, "propagates_name") { - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CheckResult result = check(R"( - type A={a:number} - type B={b:string} - - local c:A&B - local b = c - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK("{| a: number, b: string |}" == toString(requireType("b"))); - } - else - { - const std::string code = R"( - type A={a:number} - type B={b:string} + const std::string code = R"( + type A={a:number} + type B={b:string} - local c:A&B - local b = c - )"; + local function f(t: A & B) + return t + end + )"; - const std::string expected = R"( - type A={a:number} - type B={b:string} + const std::string expected = R"( + type A={a:number} + type B={b:string} - local c:A&B - local b:A&B=c - )"; + local function f(t: A & B): A&B + return t + end + )"; - CHECK_EQ(expected, decorateWithTypes(code)); - } + CHECK_EQ(expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guaranteed_to_exist") @@ -173,16 +164,18 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante CheckResult result = check(R"( type A = {x: {y: number}} type B = {x: {y: number}} - local t: A & B - local r = t.x + local function f(t: A & B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK("{| y: number |}" == toString(requireType("r"))); + + if (FFlag::LuauSolverV2) + CHECK("(A & B) -> { y: number }" == toString(requireType("f"))); else - CHECK("{| y: number |} & {| y: number |}" == toString(requireType("r"))); + CHECK("(A & B) -> {| y: number |} & {| y: number |}" == toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") @@ -190,21 +183,18 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_dep CheckResult result = check(R"( type A = {x: {y: {z: {thing: string}}}} type B = {x: {y: {z: {thing: string}}}} - local t: A & B - local r = t.x.y.z.thing + local function f(t: A & B) + return t.x.y.z.thing + end )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ("string", toString(requireType("r"))); - } + if (FFlag::LuauSolverV2) + CHECK_EQ("(A & B) -> string", toString(requireType("f"))); else - { - CHECK_EQ("string & string", toString(requireType("r"))); - } + CHECK_EQ("(A & B) -> string & string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") @@ -212,16 +202,18 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") CheckResult result = check(R"( type A = {x: number} type B = {x: string} - local t: A & B - local r = t.x + local function f(t: A & B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("never", toString(requireType("r"))); + + if (FFlag::LuauSolverV2) + CHECK_EQ("(A & B) -> never", toString(requireType("f"))); else - CHECK_EQ("number & string", toString(requireType("r"))); + CHECK_EQ("(A & B) -> number & string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_the_property") @@ -229,13 +221,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_ CheckResult result = check(R"( type A = {x: number} type B = {} - local t: A & B - local r = t.x + local function f(t: A & B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("r"))); + CHECK_EQ("(A & B) -> number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_type_any") @@ -243,13 +236,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_t CheckResult result = check(R"( type A = {y: number} type B = {x: any} - local t: A & B - local r = t.x + local function f(t: A & B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("r")); + CHECK_EQ("(A & B) -> any", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_all_parts_missing_the_property") @@ -276,8 +270,9 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write") type X = { x: number } type XY = X & { y: number } - local a : XY = { x = 1, y = 2 } - a.x = 10 + function f(t: XY) + t.x = 10 + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -286,8 +281,9 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write") type X = {} type XY = X & { x: number, y: number } - local a : XY = { x = 1, y = 2 } - a.x = 10 + function f(t: XY) + t.x = 10 + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -297,8 +293,9 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write") type Y = { y: number } type XY = X & Y - local a : XY = { x = 1, y = 2 } - a.x = 10 + function f(t: XY) + t.x = 10 + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -306,10 +303,11 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write") result = check(R"( type A = { x: {y: number} } type B = { x: {y: number} } - local t : A & B = { x = { y = 1 } } - t.x = { y = 4 } - t.x.y = 40 + function f(t: A & B) + t.x = { y = 4 } + t.x.y = 40 + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -322,42 +320,59 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") type Y = { y: number } type XY = X & Y - local a : XY = { x = 1, y = 2 } - a.z = 10 + function f(t: XY) + t.z = 10 + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + auto e = toString(result.errors[0]); + CHECK_EQ("Cannot add property 'z' to table 'X & Y'", e); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { + ScopedFastFlag dcr{ + FFlag::LuauSolverV2, false + }; // CLI-116476 Subtyping between type alias and an equivalent but not named type isn't working. CheckResult result = check(R"( type X = { x: (number) -> number } type Y = { y: (string) -> string } type XY = X & Y - local xy : XY = { - x = function(a: number) return -a end, - y = function(a: string) return a .. "b" end - } - function xy.z(a:number) return a * 10 end - function xy:y(a:number) return a * 10 end - function xy:w(a:number) return a * 10 end + function f(t: XY) + function t.z(a:number) return a * 10 end + function t:y(a:number) return a * 10 end + function t:w(a:number) return a * 10 end + end )"); - LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'w' to table 'X & Y'"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(4, result); + const std::string expected = R"(Type + '(string, number) -> string' +could not be converted into + '(string) -> string' caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); - CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"; + CHECK_EQ(expected, toString(result.errors[0])); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); + } } TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") { + ScopedFastFlag dcr{FFlag::LuauSolverV2, false}; // CLI- // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one CheckResult result = check(R"( type XY = { x: (number) -> number, y: (string) -> string } @@ -372,9 +387,14 @@ TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") )"); LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' + const std::string expected = R"(Type + '(string, number) -> string' +could not be converted into + '(string) -> string' caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"; + CHECK_EQ(expected, toString(result.errors[0])); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'XY'"); CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); @@ -383,8 +403,9 @@ caused by: TEST_CASE_FIXTURE(BuiltinsFixture, "table_intersection_setmetatable") { CheckResult result = check(R"( - local t: {} & {} - setmetatable(t, {}) + function f(t: {} & {}) + setmetatable(t, {}) + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -401,12 +422,18 @@ local a: XYZ = 3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into '{| x: number, y: number, z: number |}')"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' + const std::string expected = R"(Type 'number' could not be converted into 'X & Y & Z' caused by: - Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); + Not all intersection parts are compatible. +Type 'number' could not be converted into 'X')"; + const std::string dcrExprected = + R"(Type 'number' could not be converted into 'X & Y & Z'; type number (number) is not a subtype of X & Y & Z[0] (X) + type number (number) is not a subtype of X & Y & Z[1] (Y) + type number (number) is not a subtype of X & Y & Z[2] (Z))"; + if (FFlag::LuauSolverV2) + CHECK_EQ(dcrExprected, toString(result.errors[0])); + else + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") @@ -416,16 +443,26 @@ type X = { x: number } type Y = { y: number } type Z = { z: number } type XYZ = X & Y & Z -local a: XYZ -local b: number = a + +function f(a: XYZ): number + return a +end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), R"(Type '{| x: number, y: number, z: number |}' could not be converted into 'number')"); + if (FFlag::LuauSolverV2) + { + CHECK_EQ( + R"(Type pack 'X & Y & Z' could not be converted into 'number'; type X & Y & Z[0][0] (X) is not a subtype of number[0] (number) + type X & Y & Z[0][1] (Y) is not a subtype of number[0] (number) + type X & Y & Z[0][2] (Z) is not a subtype of number[0] (number))", + toString(result.errors[0]) + ); + } else CHECK_EQ( - toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); + toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)" + ); } TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") @@ -459,139 +496,174 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") TEST_CASE_FIXTURE(Fixture, "intersect_bool_and_false") { CheckResult result = check(R"( - local x : (boolean & false) - local y : false = x -- OK - local z : true = x -- Not OK + function f(x: boolean & false) + local y : false = x -- OK + local z : true = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); + if (FFlag::LuauSolverV2) + { + CHECK_EQ( + R"(Type 'boolean & false' could not be converted into 'true'; type boolean & false[0] (boolean) is not a subtype of true (true) + type boolean & false[1] (false) is not a subtype of true (true))", + toString(result.errors[0]) + ); + } else CHECK_EQ( - toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); + toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible" + ); } TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") { CheckResult result = check(R"( - local x : false & (boolean & false) - local y : false = x -- OK - local z : true = x -- Not OK + function f(x: false & (boolean & false)) + local y : false = x -- OK + local z : true = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); + // TODO: odd stringification of `false & (boolean & false)`.) + if (FFlag::LuauSolverV2) + CHECK_EQ( + R"(Type 'boolean & false & false' could not be converted into 'true'; type boolean & false & false[0] (false) is not a subtype of true (true) + type boolean & false & false[1] (boolean) is not a subtype of true (true) + type boolean & false & false[2] (false) is not a subtype of true (true))", + toString(result.errors[0]) + ); else - { - // TODO: odd stringification of `false & (boolean & false)`.) - CHECK_EQ(toString(result.errors[0]), - "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); - } + CHECK_EQ( + toString(result.errors[0]), + "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible" + ); } TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") { CheckResult result = check(R"( - local x : ((number?) -> number?) & ((string?) -> string?) - local y : (nil) -> nil = x -- OK - local z : (number) -> number = x -- Not OK + function foo(x: ((number?) -> number?) & ((string?) -> string?)) + local y : (nil) -> nil = x -- Not OK (fixed in DCR) + local z : (number) -> number = x -- Not OK + end )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> number?) & ((string?) -> string?)' could not be converted into '(number) -> number'; " - "none of the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + const std::string expected1 = R"(Type + '((number?) -> number?) & ((string?) -> string?)' +could not be converted into + '(nil) -> nil'; type ((number?) -> number?) & ((string?) -> string?)[0].returns()[0][0] (number) is not a subtype of (nil) -> nil.returns()[0] (nil) + type ((number?) -> number?) & ((string?) -> string?)[1].returns()[0][0] (string) is not a subtype of (nil) -> nil.returns()[0] (nil))"; + const std::string expected2 = R"(Type + '((number?) -> number?) & ((string?) -> string?)' +could not be converted into + '(number) -> number'; type ((number?) -> number?) & ((string?) -> string?)[0].returns()[0][1] (nil) is not a subtype of (number) -> number.returns()[0] (number) + type ((number?) -> number?) & ((string?) -> string?)[1].arguments()[0] (string?) is not a supertype of (number) -> number.arguments()[0] (number) + type ((number?) -> number?) & ((string?) -> string?)[1].returns()[0][0] (string) is not a subtype of (number) -> number.returns()[0] (number) + type ((number?) -> number?) & ((string?) -> string?)[1].returns()[0][1] (nil) is not a subtype of (number) -> number.returns()[0] (number))"; + CHECK_EQ(expected1, toString(result.errors[0])); + CHECK_EQ(expected2, toString(result.errors[1])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '((number?) -> number?) & ((string?) -> string?)' +could not be converted into + '(number) -> number'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") { + ScopedFastFlag dcr{ + FFlag::LuauSolverV2, false + }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions CheckResult result = check(R"( - local x : ((number) -> number) & ((string) -> string) - local y : ((number | string) -> (number | string)) = x -- OK - local z : ((number | boolean) -> (number | boolean)) = x -- Not OK + function f(x: ((number) -> number) & ((string) -> string)) + local y : ((number | string) -> (number | string)) = x -- OK + local z : ((number | boolean) -> (number | boolean)) = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number) & ((string) -> string)' could not be converted into '(boolean | number) -> " - "boolean | number'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number) -> number) & ((string) -> string)' +could not be converted into + '(boolean | number) -> boolean | number'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") { CheckResult result = check(R"( - local x : { p : number?, q : string? } & { p : number?, q : number?, r : number? } - local y : { p : number?, q : nil, r : number? } = x -- OK - local z : { p : nil } = x -- Not OK + function f(x: { p : number?, q : string? } & { p : number?, q : number?, r : number? }) + local y : { p : number?, q : nil, r : number? } = x -- OK + local z : { p : nil } = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: nil, r: number? |}' could not be converted into '{| p: nil |}'\n" - "caused by:\n" - " Property 'p' is not compatible. Type 'number?' could not be converted into 'nil'\n" - "caused by:\n" - " Not all union options are compatible. Type 'number' could not be converted into 'nil' in an invariant context"); - } - else - { - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " - "'{| p: nil |}'; none of the intersection parts are compatible"); - } + const std::string expected = + (FFlag::LuauSolverV2) + ? R"(Type '{ p: number?, q: number?, r: number? } & { p: number?, q: string? }' could not be converted into '{ p: nil }'; type { p: number?, q: number?, r: number? } & { p: number?, q: string? }[0][read "p"][0] (number) is not exactly { p: nil }[read "p"] (nil) + type { p: number?, q: number?, r: number? } & { p: number?, q: string? }[1][read "p"][0] (number) is not exactly { p: nil }[read "p"] (nil))" + : + R"(Type + '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' +could not be converted into + '{| p: nil |}'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") { CheckResult result = check(R"( - local x : { p : number?, q : any } & { p : unknown, q : string? } - local y : { p : number?, q : string? } = x -- OK - local z : { p : string?, q : number? } = x -- Not OK + function f(x : { p : number?, q : any } & { p : unknown, q : string? }) + local y : { p : number?, q : string? } = x -- OK + local z : { p : string?, q : number? } = x -- Not OK + end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" - "caused by:\n" - " Property 'p' is not compatible. Type 'number?' could not be converted into 'string?'\n" - "caused by:\n" - " Not all union options are compatible. Type 'number' could not be converted into 'string?'\n" - "caused by:\n" - " None of the union options are compatible. For example: Type 'number' could not be converted into 'string' in an invariant context"); - - CHECK_EQ(toString(result.errors[1]), - "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" - "caused by:\n" - " Property 'q' is not compatible. Type 'string?' could not be converted into 'number?'\n" - "caused by:\n" - " Not all union options are compatible. Type 'string' could not be converted into 'number?'\n" - "caused by:\n" - " None of the union options are compatible. For example: Type 'string' could not be converted into 'number' in an invariant context"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + R"(Type + '{ p: number?, q: any } & { p: unknown, q: string? }' +could not be converted into + '{ p: string?, q: number? }'; type { p: number?, q: any } & { p: unknown, q: string? }[0][read "p"] (number?) is not exactly { p: string?, q: number? }[read "p"][0] (string) + type { p: number?, q: any } & { p: unknown, q: string? }[0][read "p"][0] (number) is not exactly { p: string?, q: number? }[read "p"] (string?) + type { p: number?, q: any } & { p: unknown, q: string? }[0][read "q"] (any) is not exactly { p: string?, q: number? }[read "q"] (number?) + type { p: number?, q: any } & { p: unknown, q: string? }[1][read "p"] (unknown) is not exactly { p: string?, q: number? }[read "p"] (string?) + type { p: number?, q: any } & { p: unknown, q: string? }[1][read "q"] (string?) is not exactly { p: string?, q: number? }[read "q"][0] (number) + type { p: number?, q: any } & { p: unknown, q: string? }[1][read "q"][0] (string) is not exactly { p: string?, q: number? }[read "q"] (number?))", + toString(result.errors[0]) + ); } else { LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " - "q: number? |}'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '{| p: number?, q: any |} & {| p: unknown, q: string? |}' +could not be converted into + '{| p: string?, q: number? |}'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") { - ScopedFastFlag sffs[]{ - {"LuauUninhabitedSubAnything2", true}, - }; - CheckResult result = check(R"( - local x : { p : number?, q : never } & { p : never, q : string? } -- OK - local y : { p : never, q : never } = x -- OK - local z : never = x -- OK + function f(x : { p : number?, q : never } & { p : never, q : string? }) + local y : { p : never, q : never } = x -- OK + local z : never = x -- OK + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -600,261 +672,491 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") { CheckResult result = check(R"( - local x : ((number?) -> ({ p : number } & { q : number })) & ((string?) -> ({ p : number } & { r : number })) - local y : (nil) -> { p : number, q : number, r : number} = x -- OK - local z : (number?) -> { p : number, q : number, r : number} = x -- Not OK + function f(x : ((number?) -> ({ p : number } & { q : number })) & ((string?) -> ({ p : number } & { r : number }))) + local y : (nil) -> { p : number, q : number, r : number} = x -- OK + local z : (number?) -> { p : number, q : number, r : number} = x -- Not OK + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number, q: number |}) & ((string?) -> {| p: number, r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ( + R"(Type + '((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })' +could not be converted into + '(nil) -> { p: number, q: number, r: number }'; type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[0].returns()[0][0] ({ p: number }) is not a subtype of (nil) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }) + type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[0].returns()[0][1] ({ q: number }) is not a subtype of (nil) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }) + type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[1].returns()[0][0] ({ p: number }) is not a subtype of (nil) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }) + type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[1].returns()[0][1] ({ r: number }) is not a subtype of (nil) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }))", + toString(result.errors[0]) + ); + CHECK_EQ( + R"(Type + '((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })' +could not be converted into + '(number?) -> { p: number, q: number, r: number }'; type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[0].returns()[0][0] ({ p: number }) is not a subtype of (number?) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }) + type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[0].returns()[0][1] ({ q: number }) is not a subtype of (number?) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }) + type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[1].arguments()[0] (string?) is not a supertype of (number?) -> { p: number, q: number, r: number }.arguments()[0][0] (number) + type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[1].returns()[0][0] ({ p: number }) is not a subtype of (number?) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }) + type ((number?) -> { p: number } & { q: number }) & ((string?) -> { p: number } & { r: number })[1].returns()[0][1] ({ r: number }) is not a subtype of (number?) -> { p: number, q: number, r: number }.returns()[0] ({ p: number, q: number, r: number }))", + toString(result.errors[1]) + ); } else { - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + R"(Type + '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' +could not be converted into + '(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible)", + toString(result.errors[0]) + ); } } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") { CheckResult result = check(R"( - function f() - local x : ((number?) -> (a | number)) & ((string?) -> (a | string)) - local y : (nil) -> a = x -- OK - local z : (number?) -> a = x -- Not OK - end + function f() + function g(x : ((number?) -> (a | number)) & ((string?) -> (a | string))) + local y : (nil) -> a = x -- OK + local z : (number?) -> a = x -- Not OK + end + end )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> a | number) & ((string?) -> a | string)' could not be converted into '(number?) -> a'; " - "none of the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '((number?) -> a | number) & ((string?) -> a | string)' +could not be converted into + '(number?) -> a'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") { CheckResult result = check(R"( - function f() - local x : ((a?) -> (a | b)) & ((c?) -> (b | c)) - local y : (nil) -> ((a & c) | b) = x -- OK - local z : (a?) -> ((a & c) | b) = x -- Not OK - end + function f() + function g(x : ((a?) -> (a | b)) & ((c?) -> (b | c))) + local y : (nil) -> ((a & c) | b) = x -- OK + local z : (a?) -> ((a & c) | b) = x -- Not OK + end + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '((a?) -> a | b) & ((c?) -> b | c)' could not be converted into '(a?) -> (a & c) | b'; none of the intersection parts are compatible"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '((a?) -> a | b) & ((c?) -> b | c)' +could not be converted into + '(a?) -> (a & c) | b'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") { CheckResult result = check(R"( - function f() - local x : ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...)) - local y : ((nil, a...) -> (nil, b...)) = x -- OK - local z : ((nil, b...) -> (nil, a...)) = x -- Not OK - end + function f() + function g(x : ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))) + local y : ((nil, a...) -> (nil, b...)) = x -- OK + local z : ((nil, b...) -> (nil, a...)) = x -- Not OK + end + end )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' could not be converted " - "into '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ( + R"(Type + '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' +could not be converted into + '(nil, a...) -> (nil, b...)'; type ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))[0].returns()[0][0] (number) is not a subtype of (nil, a...) -> (nil, b...).returns()[0] (nil) + type ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))[1].returns()[0][0] (string) is not a subtype of (nil, a...) -> (nil, b...).returns()[0] (nil))", + toString(result.errors[0]) + ); + CHECK_EQ( + R"(Type + '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' +could not be converted into + '(nil, b...) -> (nil, a...)'; type ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))[0].returns()[0][0] (number) is not a subtype of (nil, b...) -> (nil, a...).returns()[0] (nil) + type ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))[1].returns()[0][0] (string) is not a subtype of (nil, b...) -> (nil, a...).returns()[0] (nil))", + toString(result.errors[1]) + ); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' +could not be converted into + '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") { + ScopedFastFlag dcr{ + FFlag::LuauSolverV2, false + }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions CheckResult result = check(R"( - function f() - local x : ((number) -> number) & ((nil) -> unknown) - local y : (number?) -> unknown = x -- OK - local z : (number?) -> number? = x -- Not OK - end + function f() + function g(x : ((number) -> number) & ((nil) -> unknown)) + local y : (number?) -> unknown = x -- OK + local z : (number?) -> number? = x -- Not OK + end + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> unknown) & ((number) -> number)' could not be converted into '(number?) -> number?'; none " - "of the intersection parts are compatible"); + const std::string expected = R"(Type + '((nil) -> unknown) & ((number) -> number)' +could not be converted into + '(number?) -> number?'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") { + ScopedFastFlag dcr{ + FFlag::LuauSolverV2, false + }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions CheckResult result = check(R"( - function f() - local x : ((number) -> number?) & ((unknown) -> string?) - local y : (number) -> nil = x -- OK - local z : (number?) -> nil = x -- Not OK - end + function f() + function g(x : ((number) -> number?) & ((unknown) -> string?)) + local y : (number) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number?) & ((unknown) -> string?)' could not be converted into '(number?) -> nil'; none " - "of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number) -> number?) & ((unknown) -> string?)' +could not be converted into + '(number?) -> nil'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") { CheckResult result = check(R"( - function f() - local x : ((number) -> number) & ((nil) -> never) - local y : (number?) -> number = x -- OK - local z : (number?) -> never = x -- Not OK - end + function f() + function g(x : ((number) -> number) & ((nil) -> never)) + local y : (number?) -> number = x -- OK + local z : (number?) -> never = x -- Not OK + end + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> never) & ((number) -> number)' could not be converted into '(number?) -> never'; none of " - "the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ( + R"(Type + '((nil) -> never) & ((number) -> number)' +could not be converted into + '(number?) -> number'; type ((nil) -> never) & ((number) -> number)[0].arguments()[0] (number) is not a supertype of (number?) -> number.arguments()[0][1] (nil) + type ((nil) -> never) & ((number) -> number)[1].arguments()[0] (nil) is not a supertype of (number?) -> number.arguments()[0][0] (number))", + toString(result.errors[0]) + ); + CHECK_EQ( + R"(Type + '((nil) -> never) & ((number) -> number)' +could not be converted into + '(number?) -> never'; type ((nil) -> never) & ((number) -> number)[0].arguments()[0] (number) is not a supertype of (number?) -> never.arguments()[0][1] (nil) + type ((nil) -> never) & ((number) -> number)[0].returns()[0] (number) is not a subtype of (number?) -> never.returns()[0] (never) + type ((nil) -> never) & ((number) -> number)[1].arguments()[0] (nil) is not a supertype of (number?) -> never.arguments()[0][0] (number))", + toString(result.errors[1]) + ); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '((nil) -> never) & ((number) -> number)' +could not be converted into + '(number?) -> never'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") { CheckResult result = check(R"( - function f() - local x : ((number) -> number?) & ((never) -> string?) - local y : (never) -> nil = x -- OK - local z : (number?) -> nil = x -- Not OK - end + function f() + function g(x : ((number) -> number?) & ((never) -> string?)) + local y : (never) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((never) -> string?) & ((number) -> number?)' could not be converted into '(number?) -> nil'; none " - "of the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + const std::string expected1 = R"(Type + '((never) -> string?) & ((number) -> number?)' +could not be converted into + '(never) -> nil'; type ((never) -> string?) & ((number) -> number?)[0].returns()[0][0] (number) is not a subtype of (never) -> nil.returns()[0] (nil) + type ((never) -> string?) & ((number) -> number?)[1].returns()[0][0] (string) is not a subtype of (never) -> nil.returns()[0] (nil))"; + const std::string expected2 = R"(Type + '((never) -> string?) & ((number) -> number?)' +could not be converted into + '(number?) -> nil'; type ((never) -> string?) & ((number) -> number?)[0].arguments()[0] (number) is not a supertype of (number?) -> nil.arguments()[0][1] (nil) + type ((never) -> string?) & ((number) -> number?)[0].returns()[0][0] (number) is not a subtype of (number?) -> nil.returns()[0] (nil) + type ((never) -> string?) & ((number) -> number?)[1].arguments()[0] (never) is not a supertype of (number?) -> nil.arguments()[0][0] (number) + type ((never) -> string?) & ((number) -> number?)[1].arguments()[0] (never) is not a supertype of (number?) -> nil.arguments()[0][1] (nil) + type ((never) -> string?) & ((number) -> number?)[1].returns()[0][0] (string) is not a subtype of (number?) -> nil.returns()[0] (nil))"; + CHECK_EQ(expected1, toString(result.errors[0])); + CHECK_EQ(expected2, toString(result.errors[1])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '((never) -> string?) & ((number) -> number?)' +could not be converted into + '(number?) -> nil'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_variadics") { + ScopedFastFlag dcr{ + FFlag::LuauSolverV2, false + }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions CheckResult result = check(R"( - local x : ((string?) -> (string | number)) & ((number?) -> ...number) - local y : ((nil) -> (number, number?)) = x -- OK - local z : ((string | number) -> (number, number?)) = x -- Not OK + function f(x : ((string?) -> (string | number)) & ((number?) -> ...number)) + local y : ((nil) -> (number, number?)) = x -- OK + local z : ((string | number) -> (number, number?)) = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> (...number)) & ((string?) -> number | string)' could not be converted into '(number | " - "string) -> (number, number?)'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number?) -> (...number)) & ((string?) -> number | string)' +could not be converted into + '(number | string) -> (number, number?)'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_1") { CheckResult result = check(R"( - function f() - local x : (() -> a...) & (() -> b...) - local y : (() -> b...) & (() -> a...) = x -- OK - local z : () -> () = x -- Not OK - end + function f() + function g(x : (() -> a...) & (() -> b...)) + local y : (() -> b...) & (() -> a...) = x -- OK + local z : () -> () = x -- Not OK + end + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '(() -> (a...)) & (() -> (b...))' could not be converted into '() -> ()'; none of the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + toString(result.errors[0]), + "Type '(() -> (a...)) & (() -> (b...))' could not be converted into '() -> ()'; none of the intersection parts are compatible" + ); + } } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_2") { CheckResult result = check(R"( - function f() - local x : ((a...) -> ()) & ((b...) -> ()) - local y : ((b...) -> ()) & ((a...) -> ()) = x -- OK - local z : () -> () = x -- Not OK - end + function f() + function g(x : ((a...) -> ()) & ((b...) -> ())) + local y : ((b...) -> ()) & ((a...) -> ()) = x -- OK + local z : () -> () = x -- Not OK + end + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '((a...) -> ()) & ((b...) -> ())' could not be converted into '() -> ()'; none of the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + toString(result.errors[0]), + "Type '((a...) -> ()) & ((b...) -> ())' could not be converted into '() -> ()'; none of the intersection parts are compatible" + ); + } } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_3") { CheckResult result = check(R"( - function f() - local x : (() -> a...) & (() -> (number?,a...)) - local y : (() -> (number?,a...)) & (() -> a...) = x -- OK - local z : () -> (number) = x -- Not OK - end + function f() + function g(x : (() -> a...) & (() -> (number?,a...))) + local y : (() -> (number?,a...)) & (() -> a...) = x -- OK + local z : () -> (number) = x -- Not OK + end + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '(() -> (a...)) & (() -> (number?, a...))' could not be converted into '() -> number'; none of the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '(() -> (a...)) & (() -> (number?, a...))' +could not be converted into + '() -> number'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") { CheckResult result = check(R"( - function f() - local x : ((a...) -> ()) & ((number,a...) -> number) - local y : ((number,a...) -> number) & ((a...) -> ()) = x -- OK - local z : (number?) -> () = x -- Not OK - end + function f() + function g(x : ((a...) -> ()) & ((number,a...) -> number)) + local y : ((number,a...) -> number) & ((a...) -> ()) = x -- OK + local z : (number?) -> () = x -- Not OK + end + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((number, a...) -> number)' could not be converted into '(number?) -> ()'; none of " - "the intersection parts are compatible"); + if (FFlag::LuauSolverV2) + { + CHECK_EQ( + R"(Type + '((a...) -> ()) & ((number, a...) -> number)' +could not be converted into + '((a...) -> ()) & ((number, a...) -> number)'; at [0].returns(), is not a subtype of number + type ((a...) -> ()) & ((number, a...) -> number)[1].arguments().tail() (a...) is not a supertype of ((a...) -> ()) & ((number, a...) -> number)[0].arguments().tail() (a...))", + toString(result.errors[0]) + ); + } + else + { + CHECK_EQ( + R"(Type + '((a...) -> ()) & ((number, a...) -> number)' +could not be converted into + '(number?) -> ()'; none of the intersection parts are compatible)", + toString(result.errors[0]) + ); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") { - CheckResult result = check(R"( - local a : string? = nil - local b : number? = nil + // CLI-117121 - Intersection of types are not compatible with the equivalent alias + if (FFlag::LuauSolverV2) + return; - local x = setmetatable({}, { p = 5, q = a }); - local y = setmetatable({}, { q = b, r = "hi" }); - local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); + if (FFlag::LuauSolverV2) + { + CheckResult result = check(R"( + function f(a: string?, b: string?) + local x = setmetatable({}, { p = 5, q = a }) + local y = setmetatable({}, { q = b, r = "hi" }) + local z = setmetatable({}, { p = 5, q = nil, r = "hi" }) - type X = typeof(x) - type Y = typeof(y) - type Z = typeof(z) + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) - local xy : X&Y = z; - local yx : Y&X = z; - z = xy; - z = yx; - )"); + function g(xy: X&Y, yx: Y&X): (Z, Z) + return xy, yx + end - LUAU_REQUIRE_NO_ERRORS(result); + g(z, z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + CheckResult result = check(R"( + local a : string? = nil + local b : number? = nil + + local x = setmetatable({}, { p = 5, q = a }); + local y = setmetatable({}, { q = b, r = "hi" }); + local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + local yx : Y&X = z; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") { CheckResult result = check(R"( - local x = setmetatable({ a = 5 }, { p = 5 }); - local y = setmetatable({ b = "hi" }, { p = 5, q = "hi" }); - local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + local x = setmetatable({ a = 5 }, { p = 5 }) + local y = setmetatable({ b = "hi" }, { p = 5, q = "hi" }) + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }) type X = typeof(x) type Y = typeof(y) type Z = typeof(z) - local xy : X&Y = z; - local yx : Y&X = z; - z = xy; - z = yx; + function f(xy: X&Y, yx: Y&X): (Z, Z) + return xy, yx + end + + f(z, z) )"); LUAU_REQUIRE_NO_ERRORS(result); } - TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") { CheckResult result = check(R"( - local x = setmetatable({ a = 5 }, { p = 5 }); - local y = setmetatable({ b = "hi" }, { q = "hi" }); - local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + local x = setmetatable({ a = 5 }, { p = 5 }) + local y = setmetatable({ b = "hi" }, { q = "hi" }) + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }) type X = typeof(x) type Y = typeof(y) type Z = typeof(z) - local xy : X&Y = z; - z = xy; + function f(xy: X&Y): Z + return xy + end + + f(z) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -862,22 +1164,44 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with_table") { - CheckResult result = check(R"( - local x = setmetatable({ a = 5 }, { p = 5 }); - local z = setmetatable({ a = 5, b = "hi" }, { p = 5 }); + if (FFlag::LuauSolverV2) + { + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }) + local z = setmetatable({ a = 5, b = "hi" }, { p = 5 }) - type X = typeof(x) - type Y = { b : string } - type Z = typeof(z) + type X = typeof(x) + type Y = { b : string } + type Z = typeof(z) - -- TODO: once we have shape types, we should be able to initialize these with z - local xy : X&Y; - local yx : Y&X; - z = xy; - z = yx; - )"); + function f(xy: X&Y, yx: Y&X): (Z, Z) + return xy, yx + end - LUAU_REQUIRE_NO_ERRORS(result); + f(z, z) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5 }); + + type X = typeof(x) + type Y = { b : string } + type Z = typeof(z) + + -- TODO: once we have shape types, we should be able to initialize these with z + local xy : X&Y; + local yx : Y&X; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "CLI-44817") @@ -890,11 +1214,11 @@ TEST_CASE_FIXTURE(Fixture, "CLI-44817") type XY = {x: number, y: number} type XYZ = {x:number, y: number, z: number} - local xy: XY = {x = 0, y = 0} - local xyz: XYZ = {x = 0, y = 0, z = 0} + function f(xy: XY, xyz: XYZ): (X&Y, X&Y&Z) + return xy, xyz + end - local xNy: X&Y = xy - local xNyNz: X&Y&Z = xyz + local xNy, xNyNz = f({x = 0, y = 0}, {x = 0, y = 0, z = 0}) local t1: XY = xNy -- Type 'X & Y' could not be converted into 'XY' local t2: XY = xNyNz -- Type 'X & Y & Z' could not be converted into 'XY' @@ -906,7 +1230,7 @@ TEST_CASE_FIXTURE(Fixture, "CLI-44817") TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + if (!FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -916,14 +1240,14 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ("(never) -> never", toString(requireType("f"))); + CHECK_EQ("(never) -> { x: number } & { x: string }", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + if (!FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -934,7 +1258,85 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(never) -> never", toString(requireType("f"))); + CHECK_EQ("({ x: number } & { x: string }) -> never", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_property_table_intersection_1") +{ + CheckResult result = check(R"( +type Foo = { + Bar: string, +} & { Baz: number } + +function f(x: Foo) + return x.Bar +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_property_table_intersection_2") +{ + CheckResult result = check(R"( + type Foo = { + Bar: string, + } & { Baz: number } + + function f(x: Foo) + return x["Bar"] + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_degenerate_intersections") +{ + ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type A = { + x: number?, + } + + type B = { + x: number?, + } + + type C = A & B + + function f(obj: C): number + return obj.x or 3 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_more_realistic_intersections") +{ + ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type A = { + x: number?, + y: string?, + } + + type B = { + x: number?, + z: string?, + } + + type C = A & B + + function f(obj: C): number + return obj.x or 3 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 30cbe1d5b..0498437a9 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Frontend.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Type.h" @@ -13,7 +14,7 @@ using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauSolverV2) TEST_SUITE_BEGIN("TypeInferLoops"); @@ -28,7 +29,99 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("q")); + if (FFlag::LuauSolverV2) + { + // Luau cannot see that the loop must always run at least once, so we + // think that q could be nil. + CHECK("number?" == toString(requireType("q"))); + } + else + CHECK_EQ(*builtinTypes->numberType, *requireType("q")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_no_table_passed") +{ + // This test may block CI if forced to run outside of DCR. + if (!FFlag::LuauSolverV2) + return; + + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + CheckResult result = check(R"( + +type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> (any, number) -> (number, string) + } +)) + +local t: Iterable + +for a, b in t do end +)"); + + + LUAU_REQUIRE_ERROR_COUNT(1, result); + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> () -> (number, string) + } + )) + + local t: Iterable + + for a, b in t do end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967_alt") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> () -> (number, string) + } + )) + + local t: Iterable + local x, y + + for a, b in t do + x = a + y = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + { + // It's possible for the loop body to execute 0 times. + CHECK("number?" == toString(requireType("x"))); + CHECK("string?" == toString(requireType("y"))); + } + else + { + CHECK_EQ("number", toString(requireType("x"))); + CHECK_EQ("string", toString(requireType("y"))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") @@ -44,12 +137,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + if (FFlag::LuauSolverV2) + { + CHECK("number?" == toString(requireType("n"))); + CHECK("string?" == toString(requireType("s"))); + } + else + { + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") { + // CLI-116494 The generics K and V are leaking out of the next() function somehow. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local n local s @@ -61,8 +165,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") @@ -87,7 +191,8 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); + + CHECK_EQ("Cannot call a value of type string", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") @@ -111,14 +216,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators_dcr") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; CheckResult result = check(R"( function no_iter() end for key in no_iter() do end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_a_custom_iterator_should_type_check") @@ -156,11 +261,17 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") LUAU_REQUIRE_ERROR_COUNT(2, result); TypeId p = requireType("p"); - CHECK_EQ("*error-type*", toString(p)); + if (FFlag::LuauSolverV2) + CHECK_EQ("*error-type*?", toString(p)); + else + CHECK_EQ("*error-type*", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") { + // We report a spuriouus duplicate error here. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local bad_iter = 5 @@ -175,6 +286,9 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") { + // Spurious duplicate errors + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function hasDivisors(value: number, table) return false @@ -218,12 +332,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_t TypeMismatch* tm = get(result.errors[1]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") { + // CLI-116496 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function prime_iter(state, index) return 1 @@ -281,8 +398,8 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "while_loop") @@ -296,7 +413,10 @@ TEST_CASE_FIXTURE(Fixture, "while_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("i")); + if (FFlag::LuauSolverV2) + CHECK("number?" == toString(requireType("i"))); + else + CHECK_EQ(*builtinTypes->numberType, *requireType("i")); } TEST_CASE_FIXTURE(Fixture, "repeat_loop") @@ -310,7 +430,10 @@ TEST_CASE_FIXTURE(Fixture, "repeat_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("i")); + if (FFlag::LuauSolverV2) + CHECK("string?" == toString(requireType("i"))); + else + CHECK_EQ(*builtinTypes->stringType, *requireType("i")); } TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") @@ -349,9 +472,53 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "varlist_declared_by_for_in_loop_should_be_fr end )"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto err = get(result.errors[0]); + CHECK(err != nullptr); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iter_constraint_before_loop_body") +{ + CheckResult result = check(R"( + local T = { + fields = {}, + } + + function f() + for u, v in pairs(T.fields) do + T.fields[u] = nil + end + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "rbxl_place_file_crash_for_wrong_constraints") +{ + CheckResult result = check(R"( +local VehicleParameters = { + -- These are default values in the case the package structure is broken + StrutSpringStiffnessFront = 28000, +} + +local function updateFromConfiguration() + for property, value in pairs(VehicleParameters) do + VehicleParameters[property] = value + end +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + + TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") { // In this case, we cannot know the element type of the table {}. It could be anything. @@ -362,7 +529,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + { + // In the new solver, we infer iter: unknown and so we warn on use of its properties. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + CHECK(Location{{2, 12}, {2, 18}} == result.errors[0].location); + } + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while") @@ -382,6 +557,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "trivial_ipairs_usage") +{ + CheckResult result = check(R"( + local next, t, s = ipairs({1, 2, 3}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("({number}, number) -> (number?, number)", toString(requireType("next"))); + REQUIRE_EQ("{number}", toString(requireType("t"))); + REQUIRE_EQ("number", toString(requireType("s"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( @@ -391,7 +579,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("number", toString(requireType("key"))); + if (FFlag::LuauSolverV2) + CHECK("number?" == toString(requireType("key"))); + else + REQUIRE_EQ("number", toString(requireType("key"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") @@ -498,6 +689,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "unreachable_code_after_infinite_loop") TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") { + // CLI-116498 Sometimes you can iterate over tables with no indexers. + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( local t = {} for _ in t do @@ -506,7 +701,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") @@ -547,11 +742,23 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("key")); + + // The old solver just infers the wrong type here. + // The right type for `key` is `number?` + if (FFlag::LuauSolverV2) + { + TypeId keyTy = requireType("key"); + CHECK("number?" == toString(keyTy)); + } + else + CHECK_EQ(*builtinTypes->numberType, *requireType("key")); } TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") { + // CLI-116498 Sometimes you can iterate over tables with no indexers. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local t: {string} = {} local extra @@ -561,22 +768,23 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") )"); LUAU_REQUIRE_ERROR_COUNT(0, result); - CHECK_EQ(*typeChecker.nilType, *requireType("extra")); + CHECK_EQ(*builtinTypes->nilType, *requireType("extra")); } TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") { + // CLI-116498 Sometimes you can iterate over tables with no indexers. + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + CheckResult result = check(R"( local t = {} for k, v in t do end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("Cannot iterate over a table without indexer", ge->message); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_nonstrict") @@ -592,7 +800,8 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_nonstrict") TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_nil") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + // CLI-116499 Free types persisting until typechecking time. + if (1 || !FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -607,7 +816,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_nil") TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_not_enough_returns") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + // CLI-116500 + if (1 || !FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -617,15 +827,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_not_enough_returns") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(result.errors[0] == TypeError{ - Location{{2, 36}, {2, 37}}, - GenericError{"__iter must return at least one value"}, - }); + CHECK( + result.errors[0] == + TypeError{ + Location{{2, 36}, {2, 37}}, + GenericError{"__iter must return at least one value"}, + } + ); } TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + // CLI-116500 + if (1 || !FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -641,7 +855,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok") TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok_with_inference") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + // CLI-116500 + if (1 || !FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -689,4 +904,357 @@ TEST_CASE_FIXTURE(Fixture, "for_loop_lower_bound_is_string_3") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") +{ + // CLI-116500 + if (FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local function makeEnum(members) + local enum = {} + for _, memberName in ipairs(members) do + enum[memberName] = memberName + end + return enum + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // HACK (CLI-68453): We name this inner table `enum`. For now, use the + // exhaustive switch to see past it. + CHECK(toString(requireType("makeEnum"), {true}) == "({a}) -> {| [a]: a |}"); +} + +TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") +{ + CheckResult result = check(R"( + function print(x) end + + function dump(tbl) + print(tbl.whatever) + for k, v in tbl do + print(k) + print(v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_explore_raycast_minimization") +{ + CheckResult result = check(R"( + local testResults = {} + for _, testData in pairs(testResults) do + end + + table.insert(testResults, {}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_minimized_fragmented_keys_1") +{ + CheckResult result = check(R"( + local function rawpairs(t) + return next, t, nil + end + + local function getFragmentedKeys(tbl) + local _ = rawget(tbl, 0) + for _ in rawpairs(tbl) do + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_minimized_fragmented_keys_2") +{ + CheckResult result = check(R"( + local function getFragmentedKeys(tbl) + local _ = rawget(tbl, 0) + for _ in next, tbl, nil do + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_minimized_fragmented_keys_3") +{ + CheckResult result = check(R"( + local function getFragmentedKeys(tbl) + local _ = rawget(tbl, 0) + for _ in pairs(tbl) do + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_fragmented_keys") +{ + CheckResult result = check(R"( + local function isIndexKey(k, contiguousLength) + return true + end + + local function getTableLength(tbl) + local length = 1 + local value = rawget(tbl, length) + while value ~= nil do + length += 1 + value = rawget(tbl, length) + end + return length - 1 + end + + local function rawpairs(t) + return next, t, nil + end + + local function getFragmentedKeys(tbl) + local keys = {} + local keysLength = 0 + local tableLength = getTableLength(tbl) + for key, _ in rawpairs(tbl) do + if not isIndexKey(key, tableLength) then + keysLength = keysLength + 1 + keys[keysLength] = key + end + end + return keys, keysLength, tableLength + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_xpath_candidates") +{ + // CLI-116500 + if (FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type Instance = {} + local function findCandidates(instances: { Instance }, path: { string }) + for _, name in ipairs(path) do + end + return {} + end + + local canditates = findCandidates({}, {}) + for _, canditate in ipairs(canditates) do end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_on_never_gives_never") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local iter: never + local ans + for xs in iter do + ans = xs + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + CHECK("never?" == toString(requireType("ans"))); // CLI-114134 egraph simplification. Should just be nil. + else + CHECK(toString(requireType("ans")) == "never"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties") +{ + // CLI-116498 - Sometimes you can iterate over tables with no indexer. + ScopedFastFlag sff0{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + local function f() + local t = { p = 5, q = "hello" } + for k, v in t do + return k, v + end + + error("") + end + + local k, v = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("unknown", toString(requireType("k"))); + CHECK_EQ("unknown", toString(requireType("v"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + local function f() + local t = { p = 5, q = "hello" } + for k, v in t do + return k, v + end + + error("") + end + + local k, v = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_retroactively_add_an_indexer") +{ + CheckResult result = check(R"( + --!strict + local prices = { + hat = 1, + bat = 2, + } + print(prices.wwwww) + for _, _ in pairs(prices) do + end + print(prices.wwwww) + )"); + + if (FFlag::LuauSolverV2) + { + // We regress a little here: The old solver would typecheck the first + // access to prices.wwwww on a table that had no indexer, and the second + // on a table that does. + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + else + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "lti_fuzzer_uninitialized_loop_crash") +{ + CheckResult result = check(R"( + for l0=_,_ do + return _() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_array_of_singletons") +{ + CheckResult result = check(R"( + --!strict + type Direction = "Left" | "Right" | "Up" | "Down" + local Instructions: { Direction } = { "Left", "Down" } + + for _, step in Instructions do + local dir: Direction = step + print(dir) + end + )"); + + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iter_mm_results_are_lvalue") +{ + CheckResult result = check(R"( + local foo = setmetatable({}, { + __iter = function() + return pairs({1, 2, 3}) + end, + }) + + for k, v in foo do + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "forin_metatable_no_iter_mm") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local t = setmetatable({1, 2, 3}, {}) + + for i, v in t do + print(i, v) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({4, 18}))); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 21}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "forin_metatable_iter_mm") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type Iterable = typeof(setmetatable({}, {} :: { + __iter: (Iterable) -> () -> T... + })) + + for i, v in {} :: Iterable<...number> do + print(i, v) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({6, 18}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 21}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression") +{ + CheckResult result = check(R"( + function first(x: any) + for k, v in pairs(x) do + print(k, v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("any" == toString(requireTypeAtPosition({3, 22}))); + CHECK("any" == toString(requireTypeAtPosition({3, 25}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constrained_loop_should_not_assert") +{ + CheckResult result = check(R"( +local function foo(Instance) + for _, Child in next, Instance:GetChildren() do + end +end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index ed3af11b4..c31f3d8cf 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -11,7 +11,8 @@ #include "doctest.h" LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauTypestateBuiltins) using namespace Luau; @@ -39,7 +40,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_require_basic") CheckResult bResult = frontend.check("game/B"); LUAU_REQUIRE_NO_ERRORS(bResult); - ModulePtr b = frontend.moduleResolver.modules["game/B"]; + ModulePtr b = frontend.moduleResolver.getModule("game/B"); REQUIRE(b != nullptr); std::optional bType = requireType(b, "b"); REQUIRE(bType); @@ -56,12 +57,24 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require") return {hooty=hooty} )"; - fileResolver.source["game/B"] = R"( - local Hooty = require(game.A) + if (FFlag::LuauSolverV2) + { + fileResolver.source["game/B"] = R"( + local Hooty = require(game.A) - local h -- free! - local i = Hooty.hooty(h) - )"; + local h = 4 + local i = Hooty.hooty(h) + )"; + } + else + { + fileResolver.source["game/B"] = R"( + local Hooty = require(game.A) + + local h -- free! + local i = Hooty.hooty(h) + )"; + } CheckResult aResult = frontend.check("game/A"); dumpErrors(aResult); @@ -71,7 +84,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require") dumpErrors(bResult); LUAU_REQUIRE_NO_ERRORS(bResult); - ModulePtr b = frontend.moduleResolver.modules["game/B"]; + ModulePtr b = frontend.moduleResolver.getModule("game/B"); REQUIRE(b != nullptr); @@ -101,7 +114,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_types") CheckResult bResult = frontend.check("workspace/B"); LUAU_REQUIRE_NO_ERRORS(bResult); - ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; + ModulePtr b = frontend.moduleResolver.getModule("workspace/B"); REQUIRE(b != nullptr); TypeId hType = requireType(b, "h"); @@ -140,6 +153,45 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") CHECK(get(*iter.tail())); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cross_module_table_freeze") +{ + fileResolver.source["game/A"] = R"( + --!strict + return { + a = 1, + } + )"; + + fileResolver.source["game/B"] = R"( + --!strict + return table.freeze(require(game.A)) + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr a = frontend.moduleResolver.getModule("game/A"); + REQUIRE(a != nullptr); + // confirm that no cross-module mutation happened here! + if (FFlag::LuauSolverV2) + CHECK(toString(a->returnType) == "{ a: number }"); + else + CHECK(toString(a->returnType) == "{| a: number |}"); + + ModulePtr b = frontend.moduleResolver.getModule("game/B"); + REQUIRE(b != nullptr); + // confirm that no cross-module mutation happened here! + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + CHECK(toString(b->returnType) == "{ read a: number }"); + else if (FFlag::LuauSolverV2) + CHECK(toString(b->returnType) == "{ a: number }"); + else + CHECK(toString(b->returnType) == "{| a: number |}"); +} + TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") { CheckResult result = check(R"( @@ -166,8 +218,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") frontend.check("game/Workspace/A"); frontend.check("game/Workspace/B"); - ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; - ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; + ModulePtr aModule = frontend.moduleResolver.getModule("game/Workspace/A"); + ModulePtr bModule = frontend.moduleResolver.getModule("game/Workspace/B"); CHECK(aModule->errors.empty()); REQUIRE_EQ(1, bModule->errors.size()); @@ -225,7 +277,10 @@ local tbl: string = require(game.A) CheckResult result = frontend.check("game/B"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); + if (FFlag::LuauSolverV2) + CHECK_EQ("Type '{ def: number }' could not be converted into 'string'", toString(result.errors[0])); + else + CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") @@ -409,14 +464,16 @@ local b: B.T = a CheckResult result = frontend.check("game/C"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' -caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); else - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' + { + const std::string expected = R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); + Property 'x' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") @@ -448,14 +505,16 @@ local b: B.T = a CheckResult result = frontend.check("game/D"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' -caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); else - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' + { + const std::string expected = R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); + Property 'x' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_anyification_clone_immutable_types") @@ -482,4 +541,199 @@ return unpack(l0[_]) LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "check_imported_module_names") +{ + fileResolver.source["game/A"] = R"( +return function(...) end + )"; + + fileResolver.source["game/B"] = R"( +local l0 = require(game.A) +return l0 + )"; + + CheckResult result = check(R"( +local l0 = require(game.B) +if true then + local l1 = require(game.A) +end +return l0 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr mod = getMainModule(); + REQUIRE(mod); + + REQUIRE(mod->scopes.size() == 4); + CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); + CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_scope_is_nullptr_after_shallow_copy") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + frontend.options.retainFullTypeGraphs = false; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from ReactTypes.lua +type CoreBinding = {} +type BindingMap = {} +export type Binding = CoreBinding & BindingMap + +return {} + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local Types = require(game.A) +type Binding = Types.Binding + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_free_variables_are_generialized_across_function_boundaries") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from react-shallow-renderer +function createUpdater(renderer) + local updater = { + _renderer = renderer, + } + + function updater.enqueueForceUpdate(publicInstance, callback, _callerName) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueReplaceState( + publicInstance, + completeState, + callback, + _callerName + ) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueSetState(publicInstance, partialState, callback, _callerName) + local currentState = updater._renderer._newState or publicInstance.state + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + return updater +end + +local ReactShallowRenderer = {} + +function ReactShallowRenderer:_reset() + self._updater = createUpdater(self) +end + +return ReactShallowRenderer + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local ReactShallowRenderer = require(game.A); + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "untitled_segfault_number_13") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( + -- minimized from roblox-requests/http/src/response.lua + local Response = {} + Response.__index = Response + function Response.new(content_type) + -- creates response object from original request and roblox http response + local self = setmetatable({}, Response) + self.content_type = content_type + return self + end + + function Response:xml(ignore_content_type) + if ignore_content_type or self.content_type:find("+xml") or self.content_type:find("/xml") then + else + end + end + + --------------- + + return Response + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local _ = require(game.A); + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "spooky_blocked_type_laundered_by_bound_type") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( + local Cache = {} + + Cache.settings = {} + + Cache.data = {} + + function Cache.should_cache(url) + url = url:split("?")[1] + + for key, _ in pairs(Cache.settings) do + if url:match('') then + return key + end + end + + return "" + end + + function Cache.is_cached(url, req_id) + -- check local server cache first + + local setting_key = Cache.should_cache(url) + local settings = Cache.settings[setting_key] + + if not setting_key then + return false + end + + if Cache.data[req_id] ~= nil then + return true + end + + if Cache.settings[setting_key].cache_globally then + return false + else + return true + end + end + + function Cache.get_expire(url) + local setting_key = Cache.should_cache(url) + return Cache.settings[setting_key].expires or math.huge + end + + return Cache + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local _ = require(game.A); + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp index adf036532..a21751ecf 100644 --- a/tests/TypeInfer.negations.test.cpp +++ b/tests/TypeInfer.negations.test.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" +#include "Luau/ToString.h" #include "doctest.h" #include "Luau/Common.h" #include "ScopedFlags.h" @@ -47,4 +48,35 @@ TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "cofinite_strings_can_be_compared_for_equality") +{ + // CLI-117082 Cofinite strings cannot be compared for equality because normalization produces a large type with cycles + if (FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( + function f(e) + if e == 'strictEqual' then + e = 'strictEqualObject' + end + if e == 'deepStrictEqual' or e == 'strictEqual' then + elseif e == 'notDeepStrictEqual' or e == 'notStrictEqual' then + end + return e + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(string) -> string" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(NegationFixture, "compare_cofinite_strings") +{ + CheckResult result = check(R"( +local u : Not<"a"> +local v : "b" +if u == v then +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index cf27518a6..3ccc04c1d 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -2,21 +2,25 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Scope.h" -#include "Luau/TypeInfer.h" #include "Luau/Type.h" #include "Luau/VisitType.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2); + TEST_SUITE_BEGIN("TypeInferOOP"); TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") { + // CLI-116571 method calls are missing arity checking? + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local someTable = {} @@ -27,12 +31,14 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defi )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") { + // CLI-116571 method calls are missing arity checking? + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local someTable = {} @@ -43,7 +49,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_ )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); } @@ -139,7 +144,10 @@ TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocat )"); ModulePtr module = getMainModule(); - CHECK_GE(50, module->internalTypes.types.size()); + if (FFlag::LuauSolverV2) + CHECK_GE(80, module->internalTypes.types.size()); + else + CHECK_GE(50, module->internalTypes.types.size()); } TEST_CASE_FIXTURE(BuiltinsFixture, "object_constructor_can_refer_to_method_of_self") @@ -217,7 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_ check(R"( function Base64FileReader(data) local reader = {} - local index: number + local index: number = 0 function reader:PeekByte() return data:byte(index) @@ -290,4 +298,257 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "set_prop_of_intersection_containing_metatabl )"); } +// DCR once had a bug in the following code where it would erroneously bind the 'self' table to itself. +TEST_CASE_FIXTURE(Fixture, "dont_bind_free_tables_to_themselves") +{ + CheckResult result = check(R"( + local T = {} + local b: any + + function T:m() + local a = b[i] + if a then + self:n() + if self:p(a) then + self:n() + end + end + end + )"); +} + +// We should probably flag an error on this. See CLI-68672 +TEST_CASE_FIXTURE(BuiltinsFixture, "flag_when_index_metamethod_returns_0_values") +{ + CheckResult result = check(R"( + local T = {} + function T.__index() + end + + local a = setmetatable({}, T) + local p = a.prop + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("nil" == toString(requireType("p"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "augmenting_an_unsealed_table_with_a_metatable") +{ + CheckResult result = check(R"( + local A = {number = 8} + + local B = setmetatable({}, A) + + function B:method() + return "hello!!" + end + )"); + + if (FFlag::LuauSolverV2) + CHECK("{ @metatable { number: number }, { method: (unknown) -> string } }" == toString(requireType("B"), {true})); + else + CHECK("{ @metatable { number: number }, { method: (a) -> string } }" == toString(requireType("B"), {true})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "react_style_oo") +{ + CheckResult result = check(R"( + local Prototype = {} + + local ClassMetatable = { + __index = Prototype + } + + local BaseClass = (setmetatable({}, ClassMetatable)) + + function BaseClass:extend(name) + local class = { + name=name + } + + class.__index = class + + function class.ctor(props) + return setmetatable({props=props}, class) + end + + return setmetatable(class, getmetatable(self)) + end + + local C = BaseClass:extend('C') + local i = C.ctor({hello='world'}) + + local iName = i.name + local cName = C.name + local hello = i.props.hello + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("string" == toString(requireType("iName"))); + CHECK("string" == toString(requireType("cName"))); + CHECK("string" == toString(requireType("hello"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "cycle_between_object_constructor_and_alias") +{ + CheckResult result = check(R"( + local T = {} + T.__index = T + + function T.new(): T + return setmetatable({}, T) + end + + export type T = typeof(T.new()) + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto module = getMainModule(); + + REQUIRE(module->exportedTypeBindings.count("T")); + + TypeId aliasType = module->exportedTypeBindings["T"].type; + CHECK_MESSAGE(get(follow(aliasType)), "Expected metatable type but got: " << toString(aliasType)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "promise_type_error_too_complex" * doctest::timeout(2)) +{ + frontend.options.retainFullTypeGraphs = false; + + // Used `luau-reduce` tool to extract a minimal reproduction. + // Credit: https://github.com/evaera/roblox-lua-promise/blob/v4.0.0/lib/init.lua + CheckResult result = check(R"( + --!strict + + local Promise = {} + Promise.prototype = {} + Promise.__index = Promise.prototype + + function Promise._new(traceback, callback, parent) + if parent ~= nil and not Promise.is(parent)then + end + + local self = { + _parent = parent, + } + + parent._consumers[self] = true + setmetatable(self, Promise) + self:_reject() + + return self + end + + function Promise.resolve(...) + return Promise._new(debug.traceback(nil, 2), function(resolve) + end) + end + + function Promise.reject(...) + return Promise._new(debug.traceback(nil, 2), function(_, reject) + end) + end + + function Promise._try(traceback, callback, ...) + return Promise._new(traceback, function(resolve) + end) + end + + function Promise.try(callback, ...) + return Promise._try(debug.traceback(nil, 2), callback, ...) + end + + function Promise._all(traceback, promises, amount) + if #promises == 0 or amount == 0 then + return Promise.resolve({}) + end + return Promise._new(traceback, function(resolve, reject, onCancel) + end) + end + + function Promise.all(promises) + return Promise._all(debug.traceback(nil, 2), promises) + end + + function Promise.allSettled(promises) + return Promise.resolve({}) + end + + function Promise.race(promises) + return Promise._new(debug.traceback(nil, 2), function(resolve, reject, onCancel) + end) + end + + function Promise.each(list, predicate) + return Promise._new(debug.traceback(nil, 2), function(resolve, reject, onCancel) + local predicatePromise = Promise.resolve(predicate(value, index)) + local success, result = predicatePromise:await() + end) + end + + function Promise.is(object) + end + + function Promise.prototype:_reject(...) + self:_finalize() + end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "method_should_not_create_cyclic_type") +{ + ScopedFastFlag sff(FFlag::LuauSolverV2, true); + + CheckResult result = check(R"( + local Component = {} + + function Component:__resolveUpdate(incomingState) + local oldState = self.state + incomingState = oldState + self.state = incomingState + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "cross_module_metatable") +{ + fileResolver.source["game/A"] = R"( + --!strict + local cls = {} + cls.__index = cls + function cls:abc() return 4 end + return cls + )"; + + fileResolver.source["game/B"] = R"( + --!strict + local cls = require(game.A) + local tbl = {} + setmetatable(tbl, cls) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr b = frontend.moduleResolver.getModule("game/B"); + REQUIRE(b); + + std::optional clsBinding = b->getModuleScope()->linearSearchForBinding("tbl"); + REQUIRE(clsBinding); + + TypeId clsType = clsBinding->typeId; + + CHECK("{ @metatable cls, tbl }" == toString(clsType)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index d75f00a2d..d0715669d 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -12,9 +12,11 @@ #include "doctest.h" +#include "ScopedFlags.h" + using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauSolverV2) TEST_SUITE_BEGIN("TypeInferOperators"); @@ -48,15 +50,11 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") local x:string = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("s"), *typeChecker.stringType); + CHECK_EQ(*requireType("s"), *builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - }; - CheckResult result = check(R"( local s = "a" and 10 local x:boolean|number = s @@ -72,7 +70,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") local x:boolean = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("x"), *typeChecker.booleanType); + CHECK_EQ(*requireType("x"), *builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "and_or_ternary") @@ -99,9 +97,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") std::optional retType = first(functionType->retTypes); REQUIRE(retType.has_value()); - CHECK_EQ(typeChecker.numberType, follow(*retType)); - CHECK_EQ(requireType("n"), typeChecker.numberType); - CHECK_EQ(requireType("s"), typeChecker.stringType); + CHECK_EQ(builtinTypes->numberType, follow(*retType)); + CHECK_EQ(requireType("n"), builtinTypes->numberType); + CHECK_EQ(requireType("s"), builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") @@ -112,7 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); + CHECK_EQ(requireType("SOLAR_MASS"), builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") @@ -147,6 +145,22 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") CHECK_EQ("number", toString(requireType("c"))); } +TEST_CASE_FIXTURE(Fixture, "floor_division_binary_op") +{ + CheckResult result = check(R"( + local a = 4 // 8 + local b = -4 // 9 + local c = 9 + c //= -6.5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection") { CheckResult result = check(R"( @@ -174,11 +188,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_int LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK_EQ("Vec3", toString(requireType("e"))); + CHECK("Vec3" == toString(requireType("a"))); + CHECK("Vec3" == toString(requireType("b"))); + CHECK("Vec3" == toString(requireType("c"))); + CHECK("Vec3" == toString(requireType("d"))); + + if (FFlag::LuauSolverV2) + CHECK("mul" == toString(requireType("e"))); + else + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") @@ -208,11 +226,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_int LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK_EQ("Vec3", toString(requireType("e"))); + CHECK("Vec3" == toString(requireType("a"))); + CHECK("Vec3" == toString(requireType("b"))); + CHECK("Vec3" == toString(requireType("c"))); + CHECK("Vec3" == toString(requireType("d"))); + + if (FFlag::LuauSolverV2) + CHECK("mul" == toString(requireType("e"))); + else + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "compare_numbers") @@ -247,9 +269,18 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_m LUAU_REQUIRE_ERROR_COUNT(1, result); - GenericError* gen = get(result.errors[0]); - - REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); + if (FFlag::LuauSolverV2) + { + UninhabitedTypeFunction* utf = get(result.errors[0]); + REQUIRE(utf); + REQUIRE_EQ(toString(utf->ty), "lt"); + } + else + { + GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); + REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") @@ -268,9 +299,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_ LUAU_REQUIRE_ERROR_COUNT(1, result); - GenericError* gen = get(result.errors[0]); - REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); + if (FFlag::LuauSolverV2) + { + UninhabitedTypeFunction* utf = get(result.errors[0]); + REQUIRE(utf); + REQUIRE_EQ(toString(utf->ty), "lt"); + } + else + { + GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); + REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") @@ -353,7 +393,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") s += true )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{builtinTypes->numberType, builtinTypes->booleanType}})); } TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") @@ -363,9 +403,17 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") s += 10 )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") @@ -390,6 +438,33 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable_with_changing_return_type") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + --!strict + type T = { x: number } + local MT = {} + + function MT:__add(other): number + return 112 + end + + local t = setmetatable({x = 2}, MT) + local u = t + 3 + t += 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK("t" == toString(tm->wantedType)); + CHECK("number" == toString(tm->givenType)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_result_must_be_compatible_with_var") { CheckResult result = check(R"( @@ -409,7 +484,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_result_must_be_compatible_wi )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(result.errors[0] == TypeError{Location{{13, 8}, {13, 14}}, TypeMismatch{requireType("x"), builtinTypes->numberType}}); + + CHECK(Location{{13, 8}, {13, 14}} == result.errors[0].location); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK("x" == toString(tm->wantedType)); + CHECK("number" == toString(tm->givenType)); } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable") @@ -482,17 +563,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") local c = -bar -- disallowed )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("string", toString(requireType("a"))); CHECK_EQ("number", toString(requireType("b"))); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - CHECK(toString(result.errors[0]) == "Type 'bar' could not be converted into 'number'"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + + UninhabitedTypeFunction* utf = get(result.errors[0]); + REQUIRE(utf); + CHECK_EQ(toString(utf->ty), "unm"); + + TypeMismatch* tm = get(result.errors[1]); + REQUIRE(tm); + CHECK_EQ(toString(tm->givenType), "bar"); + CHECK_EQ(*tm->wantedType, *builtinTypes->numberType); } else { + LUAU_REQUIRE_ERROR_COUNT(1, result); + GenericError* gen = get(result.errors[0]); REQUIRE(gen); REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); @@ -516,17 +606,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") local a = -foo )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("string", toString(requireType("a"))); + CHECK(get(result.errors[0])); - TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType); - // given type is the typeof(foo) which is complex to compare against + // This second error is spurious. We should not be reporting it. + CHECK(get(result.errors[1])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *builtinTypes->booleanType); + // given type is the typeof(foo) which is complex to compare against + } } TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") { + // CLI-116463 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict local mt = {} @@ -547,8 +652,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") CHECK_EQ("number", toString(requireType("a"))); TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + REQUIRE_MESSAGE(tm, "Expected a TypeMismatch but got " << result.errors[0]); + + REQUIRE_EQ(*tm->wantedType, *builtinTypes->numberType); + REQUIRE_EQ(*tm->givenType, *builtinTypes->stringType); } TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") @@ -594,22 +701,34 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables LUAU_REQUIRE_ERROR_COUNT(3, result); - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(*tm->wantedType, *typeChecker.numberType); - CHECK_EQ(*tm->givenType, *typeChecker.stringType); - - GenericError* gen1 = get(result.errors[1]); - REQUIRE(gen1); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(gen1->message, "Operator + is not applicable for '{ value: number }' and 'number' because neither type has a metatable"); + if (FFlag::LuauSolverV2) + { + CHECK(get(result.errors[0])); + CHECK(Location{{2, 18}, {2, 30}} == result.errors[0].location); + CHECK(get(result.errors[1])); + CHECK(Location{{8, 18}, {8, 25}} == result.errors[1].location); + CHECK(get(result.errors[2])); + CHECK(Location{{24, 18}, {24, 27}} == result.errors[2].location); + } else - CHECK_EQ(gen1->message, "Binary operator '+' not supported by types 'foo' and 'number'"); - - TypeMismatch* tm2 = get(result.errors[2]); - REQUIRE(tm2); - CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); - CHECK_EQ(*tm2->givenType, *requireType("foo")); + { + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_MESSAGE(tm, "Expected a TypeMismatch but got " << result.errors[0]); + CHECK_EQ(*tm->wantedType, *builtinTypes->numberType); + CHECK_EQ(*tm->givenType, *builtinTypes->stringType); + + GenericError* gen1 = get(result.errors[1]); + REQUIRE(gen1); + if (FFlag::LuauSolverV2) + CHECK_EQ(gen1->message, "Operator + is not applicable for '{ value: number }' and 'number' because neither type has a metatable"); + else + CHECK_EQ(gen1->message, "Binary operator '+' not supported by types 'foo' and 'number'"); + + TypeMismatch* tm2 = get(result.errors[2]); + REQUIRE(tm2); + CHECK_EQ(*tm2->wantedType, *builtinTypes->numberType); + CHECK_EQ(*tm2->givenType, *requireType("foo")); + } } // CLI-29033 @@ -633,8 +752,16 @@ TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(a) -> concat" == toString(requireType("f"))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") @@ -647,7 +774,10 @@ TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(string) -> string", toString(requireType("f"))); + if (FFlag::LuauSolverV2) + CHECK("(a) -> concat" == toString(requireType("f"))); + else + CHECK_EQ("(string) -> string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") @@ -662,9 +792,21 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") src += "end"; CheckResult result = check(src); - LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); - CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + CHECK_EQ( + "Operator '+' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __add", + toString(result.errors[0]) + ); + CHECK_EQ("Operator '-' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __sub", toString(result.errors[1])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") @@ -677,7 +819,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // This infers a type for `t` of `{unknown}`, and so it makes sense that `t[1].test` would error. + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_ERROR_COUNT(1, result); + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") @@ -690,13 +836,18 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato LUAU_REQUIRE_ERROR_COUNT(1, result); - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Types 'boolean' and 'boolean' cannot be compared with relational operator <", ge->message); + if (FFlag::LuauSolverV2) + { + UninhabitedTypeFunction* utf = get(result.errors[0]); + REQUIRE(utf); + REQUIRE_EQ(toString(utf->ty), "lt"); + } else + { + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); + } } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") @@ -707,18 +858,34 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato local foo = a < b )"); + // If DCR is off and the flag to remove this check in the old solver is on, the expected behavior is no errors. + if (!FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + return; + } + LUAU_REQUIRE_ERROR_COUNT(1, result); - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Types 'number | string' and 'number | string' cannot be compared with relational operator <", ge->message); + if (FFlag::LuauSolverV2) + { + UninhabitedTypeFunction* utf = get(result.errors[0]); + REQUIRE(utf); + REQUIRE_EQ(toString(utf->ty), "lt"); + } else + { + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); + } } TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") { + // There's an extra spurious warning here when the new solver is enabled. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict local _ @@ -726,21 +893,24 @@ TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to '_'", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed { - CheckResult result = check(R"( - --!nonstrict - a = a + 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + if (!FFlag::LuauSolverV2) + { + CheckResult result = check(R"( + --!nonstrict + a = a + 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } } // In strict mode we no longer generate two errors from lhs @@ -757,14 +927,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "UnknownGlobalCompoundAssign") // In non-strict mode, compound assignment is not a definition, it's a modification { - CheckResult result = check(R"( - --!nonstrict - a += 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + if (!FFlag::LuauSolverV2) + { + CheckResult result = check(R"( + --!nonstrict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } } } @@ -802,7 +975,7 @@ local b: number = 1 or a TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); CHECK_EQ("number?", toString(tm->givenType)); } @@ -842,8 +1015,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local a: string | number = "hi" local b: {x: string}? = {x = "bye"} @@ -875,8 +1046,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> add"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + } result = check(Mode::Nonstrict, R"( local function f(x, y) @@ -891,6 +1070,146 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") // the case right now, though. } +TEST_CASE_FIXTURE(Fixture, "infer_type_for_generic_subtraction") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x - y + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> sub"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in - operation; consider adding a type annotation to 'x'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_for_generic_multiplication") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x * y + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> mul"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in * operation; consider adding a type annotation to 'x'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_for_generic_division") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x / y + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> div"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in / operation; consider adding a type annotation to 'x'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_for_generic_floor_division") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x // y + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> idiv"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in // operation; consider adding a type annotation to 'x'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_for_generic_exponentiation") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x ^ y + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> pow"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in ^ operation; consider adding a type annotation to 'x'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_for_generic_modulo") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x % y + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> mod"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in % operation; consider adding a type annotation to 'x'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_for_generic_concat") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x .. y + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> concat"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in .. operation; consider adding a type annotation to 'x'"); + } +} + TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") { CheckResult result = check(R"( @@ -952,8 +1271,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") TEST_CASE_FIXTURE(ClassFixture, "unrelated_classes_cannot_be_compared") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local a = BaseClass.New() local b = UnrelatedClass.New() @@ -966,8 +1283,6 @@ TEST_CASE_FIXTURE(ClassFixture, "unrelated_classes_cannot_be_compared") TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local c = 5 == true )"); @@ -975,34 +1290,10 @@ TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") -{ - if (!FFlag::DebugLuauDeferredConstraintResolution) - return; - - CheckResult result = check(R"( - local mm = { - __add = function(self, other) - return - end, - } - - local x = setmetatable({}, mm) - local y = x + 123 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK(requireType("y") == builtinTypes->errorRecoveryType()); - - const GenericError* ge = get(result.errors[1]); - REQUIRE(ge); - CHECK(ge->message == "Metamethod '__add' must return a value"); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + // CLI-115687 + if (1 || !FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -1036,10 +1327,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_and") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - }; - CheckResult result = check(R"( local a: number? = 5 local b: boolean = (a or 1) > 10 @@ -1053,22 +1340,19 @@ local w = c and 1 CHECK("number?" == toString(requireType("x"))); CHECK("number" == toString(requireType("y"))); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) CHECK("false | number" == toString(requireType("z"))); else CHECK("boolean | number" == toString(requireType("z"))); // 'false' widened to boolean - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK("((false?) & a) | number" == toString(requireType("w"))); + + if (FFlag::LuauSolverV2) + CHECK("number?" == toString(requireType("w"))); else CHECK("(boolean | number)?" == toString(requireType("w"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_or") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - }; - CheckResult result = check(R"( local a: number | false = 5 local b: number? = 6 @@ -1087,7 +1371,7 @@ local f1 = f or 'f' CHECK("number | string" == toString(requireType("a1"))); CHECK("number" == toString(requireType("b1"))); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK("string | true" == toString(requireType("c1"))); CHECK("string | true" == toString(requireType("d1"))); @@ -1101,4 +1385,230 @@ local f1 = f or 'f' CHECK("string" == toString(requireType("f1"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "reducing_and") +{ + CheckResult result = check(R"( +type Foo = { name: string?, flag: boolean? } +local arr: {Foo} = {} + +local function foo(arg: {name: string}?) + local name = if arg and arg.name then arg.name else nil + + table.insert(arr, { + name = name or "", + flag = name ~= nil and name ~= "", + }) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_is_array_simplified") +{ + CheckResult result = check(R"( + --!strict + return function(value: any) : boolean + if typeof(value) ~= "number" then + return false + end + if value % 1 ~= 0 or value < 1 then + return false + end + return true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_is_array") +{ + // CLI-116480 Subtyping bug: table should probably be a subtype of {[unknown]: unknown} + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( +--!strict +return function(value: any): boolean + if typeof(value) ~= "table" then + return false + end + if next(value) == nil then + -- an empty table is an empty array + return true + end + + local length = #value + + if length == 0 then + return false + end + + local count = 0 + local sum = 0 + for key in pairs(value) do + if typeof(key) ~= "number" then + return false + end + if key % 1 ~= 0 or key < 1 then + return false + end + count += 1 + sum += key + end + + return sum == (count * (count + 1) / 2) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.String.slice") +{ + + CheckResult result = check(R"( +--!strict +local function slice(str: string, startIndexStr: string | number, lastIndexStr: (string | number)?): string + local strLen, invalidBytePosition = utf8.len(str) + assert(strLen ~= nil, ("string `%s` has an invalid byte at position %s"):format(str, tostring(invalidBytePosition))) + local startIndex = tonumber(startIndexStr) + + + -- if no last index length set, go to str length + 1 + local lastIndex = strLen + 1 + + assert(typeof(lastIndex) == "number", "lastIndexStr should convert to number") + + if lastIndex > strLen then + lastIndex = strLen + 1 + end + + local startIndexByte = utf8.offset(str, startIndex) + + return string.sub(str, startIndexByte, startIndexByte) +end + +return slice + + + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.startswith") +{ + // This test also exercises whether the binary operator == passes the correct expected type + // to it's l,r operands + CheckResult result = check(R"( +--!strict +local function startsWith(value: string, substring: string, position: number?): boolean + -- Luau FIXME: we have to use a tmp variable, as Luau doesn't understand the logic below narrow position to `number` + local position_ + if position == nil or position < 1 then + position_ = 1 + else + position_ = position + end + + return value:find(substring, position_, true) == position_ +end + +return startsWith + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "add_type_function_works") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local function add(x, y) + return x + y + end + + local a = add(1, 2) + local b = add("foo", "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "add"); + CHECK( + toString(result.errors[0]) == + "Operator '+' could not be applied to operands of types string and string; there is no corresponding overload for __add" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "normalize_strings_comparison") +{ + CheckResult result = check(R"( +local function sortKeysForPrinting(a: any, b) + local typeofA = type(a) + local typeofB = type(b) + -- strings and numbers are sorted numerically/alphabetically + if typeofA == typeofB and (typeofA == "number" or typeofA == "string") then + return a < b + end + -- sort the rest by type name + return typeofA < typeofB +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "compare_singleton_string_to_string") +{ + CheckResult result = check(R"( + local function test(a: string, b: string) + if a == "Pet" and b == "Pet" then + return true + elseif a ~= b then + return a < b + else + return false + end + end +)"); + + // There is a flag to gate turning this off, and this warning is not + // implemented in the new solver, so assert there are no errors. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_infinite_expansion_of_free_type" * doctest::timeout(1.0)) +{ + ScopedFastFlag sff(FFlag::LuauSolverV2, true); + check(R"( + local tooltip = {} + + function tooltip:Show() + local playerGui = self.Player:FindFirstChild("PlayerGui") + for _,c in ipairs(playerGui:GetChildren()) do + if c:IsA("ScreenGui") and c.DisplayOrder > self.Gui.DisplayOrder then + end + end + end + )"); + + // just type-checking this code is enough +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_operator_on_upvalue") +{ + CheckResult result = check(R"( + local byteCursor: number = 0 + + local function advance(bytes: number) + byteCursor += bytes + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 02fdfa36e..c3cce9dfb 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -8,6 +8,7 @@ #include "Luau/VisitType.h" #include "Fixture.h" +#include "DiffAsserts.h" #include "doctest.h" @@ -31,7 +32,7 @@ TEST_CASE_FIXTURE(Fixture, "string_length") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("t"))); + CHECK_EQ_DIFF(builtinTypes->numberType, requireType("t")); } TEST_CASE_FIXTURE(Fixture, "string_index") @@ -57,7 +58,7 @@ TEST_CASE_FIXTURE(Fixture, "string_method") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("p"), *typeChecker.numberType); + CHECK_EQ(*requireType("p"), *builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "string_function_indirect") @@ -69,21 +70,30 @@ TEST_CASE_FIXTURE(Fixture, "string_function_indirect") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("p"), *typeChecker.stringType); + CHECK_EQ(*requireType("p"), *builtinTypes->stringType); } -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") +TEST_CASE_FIXTURE(Fixture, "check_methods_of_number") { CheckResult result = check(R"( -local x: number = 9999 -function x:y(z: number) - local s: string = z -end -)"); + local x: number = 9999 + function x:y(z: number) + local s: string = z + end + )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add method to non-table type 'number'"); - CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'string'"); + + if (FFlag::LuauSolverV2) + { + CHECK("Expected type table, got 'number' instead" == toString(result.errors[0])); + CHECK("Type 'number' could not be converted into 'string'" == toString(result.errors[1])); + } + else + { + CHECK_EQ(toString(result.errors[0]), "Cannot add method to non-table type 'number'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'string'"); + } } TEST_CASE("singleton_types") @@ -100,4 +110,14 @@ TEST_CASE("singleton_types") CHECK(result.errors.empty()); } +TEST_CASE_FIXTURE(BuiltinsFixture, "property_of_buffers") +{ + CheckResult result = check(R"( + local b = buffer.create(100) + print(b.foo) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 0aacb8aec..514c31c82 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" +#include "Luau/RecursionCounter.h" #include "Fixture.h" @@ -9,7 +10,12 @@ using namespace Luau; -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTINT(LuauNormalizeCacheLimit); +LUAU_FASTINT(LuauTarjanChildLimit); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); TEST_SUITE_BEGIN("ProvisionalTests"); @@ -49,7 +55,59 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - CHECK_EQ(expected, decorateWithTypes(code)); + const std::string expectedWithNewSolver = R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean'then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn()then + local a2:{fn:()->(unknown,...unknown)}&(class|function|nil|number|string|thread|buffer|table)=a + end + end + )"; + + if (FFlag::LuauSolverV2) + CHECK_EQ(expectedWithNewSolver, decorateWithTypes(code)); + else + CHECK_EQ(expected, decorateWithTypes(code)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.filter") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + // This test exercises the fact that we should reduce sealed/unsealed/free tables + // res is a unsealed table with type {((T & ~nil)?) & any} + // Because we do not reduce it fully, we cannot unify it with `Array = { [number] : T} + // TLDR; reduction needs to reduce the indexer on res so it unifies with Array + CheckResult result = check(R"( +--!strict +-- Implements Javascript's `Array.prototype.filter` as defined below +-- https://developer.cmozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/filter +type Array = { [number]: T } +type callbackFn = (element: T, index: number, array: Array) -> boolean +type callbackFnWithThisArg = (thisArg: U, element: T, index: number, array: Array) -> boolean +type Object = { [string]: any } +return function(t: Array, callback: callbackFn | callbackFnWithThisArg, thisArg: U?): Array + + local len = #t + local res = {} + if thisArg == nil then + for i = 1, len do + local kValue = t[i] + if kValue ~= nil then + if (callback :: callbackFn)(kValue, i, t) then + res[i] = kValue + end + end + end + else + end + + return res +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") @@ -81,7 +139,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") TEST_CASE_FIXTURE(Fixture, "weirditer_should_not_loop_forever") { // this flag is intentionally here doing nothing to demonstrate that we exit early via case detection - ScopedFastInt sfis{"LuauTypeInferTypePackLoopLimit", 50}; + ScopedFastInt sfis{FInt::LuauTypeInferTypePackLoopLimit, 50}; CheckResult result = check(R"( local function toVertexList(vertices, x, y, ...) @@ -114,6 +172,8 @@ TEST_CASE_FIXTURE(Fixture, "it_should_be_agnostic_of_actual_size") // For now, infer it as just a free table. TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_table") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local a = {} local b @@ -132,6 +192,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_ // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type Node = { value: T, child: Node? } @@ -155,6 +217,8 @@ TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") // We should be type checking the metamethod at the call site of setmetatable. TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local tab = {a = 1} setmetatable(tab, {__eq = function(a, b): number @@ -176,8 +240,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_othe // We need refine both operands as `never` in the `==` branch. TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") { - ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; - CheckResult result = check(R"( local function f(a: string, b: boolean?) if a == b then @@ -215,16 +277,24 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: string, y: number |}", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::LuauSolverV2) + { + CHECK_EQ("{ x: string, y: number }", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("{ x: nil, y: nil }", toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ("{| x: string, y: number |}", toString(requireTypeAtPosition({5, 28}))); - // Should be {| x: nil, y: nil |} - CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); + // Should be {| x: nil, y: nil |} + CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { - ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; - ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; + ScopedFastInt sffi{FInt::LuauTarjanChildLimit, 1}; + ScopedFastInt sffi2{FInt::LuauTypeInferIterationLimit, 1}; CheckResult result = check(R"LUA( local Result @@ -249,9 +319,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated end )LUA"); - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& a) { - return nullptr != get(a); - }); + auto it = std::find_if( + result.errors.begin(), + result.errors.end(), + [](TypeError& a) + { + return nullptr != get(a); + } + ); if (it == result.errors.end()) { dumpErrors(result); @@ -259,13 +334,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated } } -// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { - ScopedFastFlag sff[]{ - {"LuauReturnAnyInsteadOfICE", true}, - }; - // In-place quantification causes these types to have the wrong types but only because of nasty interaction with prototyping. // The type of f is initially () -> free1... // Then the prototype iterator advances, and checks the function expression assigned to g, which has the type () -> free2... @@ -289,10 +359,19 @@ TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type LUAU_REQUIRE_NO_ERRORS(result); - // f and g should have the type () -> () - CHECK_EQ("() -> (a...)", toString(requireType("f"))); - CHECK_EQ("() -> (a...)", toString(requireType("g"))); - CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + if (FFlag::LuauSolverV2) + { + CHECK("() -> ()" == toString(requireType("f"))); + CHECK("() -> ()" == toString(requireType("g"))); + CHECK("nil" == toString(requireType("x"))); + } + else + { + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + } } TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") @@ -303,7 +382,10 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") local s2s: (string) -> string = id )"); - LUAU_REQUIRE_ERRORS(result); // Should not have any errors. + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") @@ -311,7 +393,7 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") ScopedFastFlag sff[] = { // I'm not sure why this is broken without DCR, but it seems to be fixed // when DCR is enabled. - {"DebugLuauDeferredConstraintResolution", false}, + {FFlag::LuauSolverV2, false}, }; CheckResult result = check(R"( @@ -322,23 +404,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } -TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") -{ - ScopedFastFlag sff[] = { - // I'm not sure why this is broken without DCR, but it seems to be fixed - // when DCR is enabled. - {"DebugLuauDeferredConstraintResolution", false}, - }; - - CheckResult result = check(R"( - --!strict - local function f(...) return ... end - local g = function(...) return f(...) end - )"); - - LUAU_REQUIRE_ERRORS(result); // Should not have any errors. -} - // Belongs in TypeInfer.builtins.test.cpp. TEST_CASE_FIXTURE(BuiltinsFixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") { @@ -409,32 +474,10 @@ TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_any") CHECK_EQ("((any) -> (), any) -> ()", toString(requireType("foo"))); } -TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") -{ - ScopedFastFlag sff{"DebugLuauSharedSelf", true}; - - CheckResult result = check(R"( - local T = {} - T.__index = T - - function T.new() - local self = setmetatable({}, T) - return self:ctor() or self - end - - function T:ctor() - -- oops, no return! - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); -} - TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") { ScopedFastFlag sff[] = { - {"DebugLuauDeferredConstraintResolution", true}, + {FFlag::LuauSolverV2, true}, }; CheckResult result = check(R"( @@ -465,30 +508,39 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") // would append a new constraint number <: *blocked* to the constraint set // to be solved later. This should be faster and theoretically less prone // to cyclic constraint dependencies. - CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); + + if (FFlag::LuauSolverV2) + CHECK("(unknown, number) -> ()" == toString(requireType("prime_iter"))); + else + CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); } TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + TypeArena arena; TypeId nilType = builtinTypes->nilType; std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeTypePack{scope.get()}); + TypeId free1 = arena.addType(FreeType{scope.get()}); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeTypePack{scope.get()}); + TypeId free2 = arena.addType(FreeType{scope.get()}); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; UnifierSharedState sharedState{&iceHandler}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; + Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; + + if (FFlag::LuauSolverV2) + u.enableNewSolver(); u.tryUnify(option1, option2); - CHECK(u.errors.empty()); + CHECK(!u.failure); u.log.commit(); @@ -501,7 +553,7 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", false}; + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( function no_iter() end @@ -543,12 +595,15 @@ return wrapStrictTable(Constants, "Constants") frontend.check("game/B"); - ModulePtr m = frontend.moduleResolver.modules["game/B"]; + ModulePtr m = frontend.moduleResolver.getModule("game/B"); REQUIRE(m); std::optional result = first(m->returnType); REQUIRE(result); - CHECK(get(*result)); + if (FFlag::LuauSolverV2) + CHECK_EQ("unknown", toString(*result)); + else + CHECK_MESSAGE(get(*result), *result); } TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") @@ -582,12 +637,16 @@ return wrapStrictTable(Constants, "Constants") frontend.check("game/B"); - ModulePtr m = frontend.moduleResolver.modules["game/B"]; + ModulePtr m = frontend.moduleResolver.getModule("game/B"); REQUIRE(m); std::optional result = first(m->returnType); REQUIRE(result); - CHECK(get(*result)); + + if (FFlag::LuauSolverV2) + CHECK("unknown" == toString(*result)); + else + CHECK("any" == toString(*result)); } namespace @@ -770,26 +829,26 @@ TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_ty end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauTypeMismatchInvarianceInError) - { - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' -caused by: - Property 'x' is not compatible. Type 'number?' could not be converted into 'number' in an invariant context)", - toString(result.errors[0])); - } + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); // This is wrong. We should be rejecting this assignment. else { - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' + LUAU_REQUIRE_ERROR_COUNT(1, result); + const std::string expected = R"(Type + '{| x: number? |}' +could not be converted into + '{| x: number |}' caused by: - Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", - toString(result.errors[0])); + Property 'x' is not compatible. +Type 'number?' could not be converted into 'number' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function foo(t, x) if x == "hi" or x == "bye" then @@ -805,7 +864,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) CHECK_EQ("{string}", toString(requireType("t"))); else { @@ -814,4 +873,452 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") } } +// We really should be warning on this. We have no guarantee that T has any properties. +TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions_of_tables_that_have_the_prop") +{ + CheckResult result = check(R"( + local function mergeOptions(options: T & ({variable: string} | {variable: number})) + return options.variable + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // LUAU_REQUIRE_ERROR_COUNT(1, result); + + // const UnknownProperty* unknownProp = get(result.errors[0]); + // REQUIRE(unknownProp); + + // CHECK("variable" == unknownProp->key); +} + +TEST_CASE_FIXTURE(Fixture, "expected_type_should_be_a_helpful_deduction_guide_for_function_calls") +{ + CheckResult result = check(R"( + type Ref = { val: T } + + local function useRef(x: T): Ref + return { val = x } + end + + local x: Ref = useRef(nil) + )"); + + if (FFlag::LuauSolverV2) + { + // This bug is fixed in the new solver. + LUAU_REQUIRE_ERROR_COUNT(1, result); + } + else + { + // This is actually wrong! Sort of. It's doing the wrong thing, it's actually asking whether + // `{| val: number? |} <: {| val: nil |}` + // instead of the correct way, which is + // `{| val: nil |} <: {| val: number? |}` + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(Fixture, "floating_generics_should_not_be_allowed") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + local assign : (target: T, source0: U?, source1: V?, source2: W?, ...any) -> T & U & V & W = (nil :: any) + + -- We have a big problem here: The generics U, V, and W are not bound to anything! + -- Things get strange because of this. + local benchmark = assign({}) + local options = benchmark.options + do + local resolve2: any = nil + options.fn({ + resolve = function(...) + resolve2(...) + end, + }) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + TypeArena arena; + TypeId nilType = builtinTypes->nilType; + + std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); + + TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId option1 = arena.addType(UnionType{{nilType, free1}}); + + TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId option2 = arena.addType(UnionType{{nilType, free2}}); + + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; + + if (FFlag::LuauSolverV2) + u.enableNewSolver(); + + u.tryUnify(option1, option2); + + CHECK(!u.failure); + + u.log.commit(); + + ToStringOptions opts; + CHECK("a?" == toString(option1, opts)); + CHECK("b?" == toString(option2, opts)); // should be `a?`. +} + +TEST_CASE_FIXTURE(Fixture, "unify_more_complex_unions_that_include_nil") +{ + CheckResult result = check(R"( + type Record = {prop: (string | boolean)?} + + function concatPagination(prop: (string | boolean | nil)?): Record + return {prop = prop} + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant_old_solver") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + createSomeClasses(&frontend); + + CheckResult result = check(R"( + function foo(ref: {current: Parent?}) + end + + function bar(ref: {current: Child?}) + foo(ref) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant_new_solver") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + createSomeClasses(&frontend); + + CheckResult result = check(R"( + function foo(ref: {read current: Parent?}) + end + + function bar(ref: {read current: Child?}) + foo(ref) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Map.entries") +{ + + fileResolver.source["Module/Map"] = R"( +--!strict + +type Object = { [any]: any } +type Array = { [number]: T } +type Table = { [T]: V } +type Tuple = Array + +local Map = {} + +export type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} + +function Map:entries() + return {} +end + +local function coerceToTable(mapLike: Map | Table): Array> + local e = mapLike:entries(); + return e +end + + )"; + + CheckResult result = frontend.check("Module/Map"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// We would prefer this unification to be able to complete, but at least it should not crash +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_infinite_recursion") +{ + // The new solver doesn't recurse as heavily in this situation. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + +#if defined(_NOOPT) || defined(_DEBUG) + ScopedFastInt LuauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 100}; +#endif + + fileResolver.source["game/A"] = R"( +local tbl = {} + +function tbl:f1(state) + self.someNonExistentvalue2 = state +end + +function tbl:f2() + self.someNonExistentvalue:Dc() +end + +function tbl:f3() + self:f2() + self:f1(false) +end +return tbl + )"; + + fileResolver.source["game/B"] = R"( +local tbl = require(game.A) +tbl:f3() + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +// Ideally, unification with any will not cause a 2^n normalization of a function overload +TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_limit_in_unify_with_any") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + }; + + // With default limit, this test will take 10 seconds in NoOpt + ScopedFastInt luauNormalizeCacheLimit{FInt::LuauNormalizeCacheLimit, 1000}; + + // Build a function type with a large overload set + const int parts = 100; + std::string source; + + for (int i = 0; i < parts; i++) + formatAppend(source, "type T%d = { f%d: number }\n", i, i); + + source += "type Instance = { new: (('s0', extra: Instance?) -> T0)"; + + for (int i = 1; i < parts; i++) + formatAppend(source, " & (('s%d', extra: Instance?) -> T%d)", i, i); + + source += " }\n"; + + source += R"( +local Instance: Instance = {} :: any + +local function foo(a: typeof(Instance.new)) return if a then 2 else 3 end + +foo(1 :: any) +)"; + + CheckResult result = check(source); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "luau_roact_useState_nilable_state_1") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type Dispatch = (A) -> () + type BasicStateAction = ((S) -> S) | S + + type ScriptConnection = { Disconnect: (ScriptConnection) -> () } + + local blah = nil :: any + + local function useState( + initialState: (() -> S) | S, + ... + ): (S, Dispatch>) + return blah, blah + end + + local a, b = useState(nil :: ScriptConnection?) + + if a then + a:Disconnect() + b(nil :: ScriptConnection?) + end + )"); + + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + // This is a known bug in the old solver. + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(Location{{19, 14}, {19, 41}} == result.errors[0].location); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_roact_useState_minimization") +{ + // We don't expect this test to work on the old solver, but it also does not yet work on the new solver. + // So, we can't just put a scoped fast flag here, or it would block CI. + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type BasicStateAction = ((S) -> S) | S + type Dispatch = (A) -> () + + local function useState( + initialState: (() -> S) | S + ): (S, Dispatch>) + -- fake impl that obeys types + local val = if type(initialState) == "function" then initialState() else initialState + return val, function(value) + return value + end + end + + local test, setTest = useState(nil :: string?) + + setTest(nil) -- this line causes the type to be narrowed in the old solver!!! + + local function update(value: string) + print(test) + setTest(value) + end + + update("hello") + )"); + + // We actually expect this code to be fine. + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "bin_prov") +{ + CheckResult result = check(R"( + local Bin = {} + + function Bin:add(item) + self.head = { item = item} + return item + end + + function Bin:destroy() + while self.head do + local item = self.head.item + if type(item) == "function" then + item() + elseif item.Destroy ~= nil then + end + self.head = self.head.next + end + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "update_phonemes_minimized") +{ + CheckResult result = check(R"( + local video + function(response) + for index = 1, #response do + video = video + end + return video + end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_containing_non_final_type_is_erroneously_cached") +{ + TypeArena arena; + Scope globalScope(builtinTypes->anyTypePack); + UnifierSharedState sharedState{&ice}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + + TypeId tableTy = arena.addType(TableType{}); + TableType* table = getMutable(tableTy); + REQUIRE(table); + + TypeId freeTy = arena.freshType(&globalScope); + + table->props["foo"] = Property::rw(freeTy); + + std::shared_ptr n1 = normalizer.normalize(tableTy); + std::shared_ptr n2 = normalizer.normalize(tableTy); + + // This should not hold + CHECK(n1 == n2); +} + +// This is doable with the new solver, but there are some problems we have to work out first. +// CLI-111113 +TEST_CASE_FIXTURE(Fixture, "we_cannot_infer_functions_that_return_inconsistently") +{ + CheckResult result = check(R"( + function find_first(tbl: {T}, el) + for i, e in tbl do + if e == el then + return i + end + end + return nil + end + )"); + +#if 0 + // This #if block describes what should happen. + LUAU_CHECK_NO_ERRORS(result); + + // The second argument has type unknown because the == operator does not + // constrain the type of el. + CHECK("({T}, unknown) -> number?" == toString(requireType("find_first"))); +#else + // This is what actually happens right now. + + if (FFlag::LuauSolverV2) + { + LUAU_CHECK_ERROR_COUNT(2, result); + CHECK("({T}, unknown) -> number" == toString(requireType("find_first"))); + } + else + { + LUAU_CHECK_ERROR_COUNT(1, result); + + CHECK("({T}, b) -> number" == toString(requireType("find_first"))); + } +#endif +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 570cf278e..615bebcdf 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,15 +7,18 @@ #include "doctest.h" -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauNegatedClassTypes) +LUAU_FASTFLAG(LuauSolverV2) using namespace Luau; namespace { std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typeChecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { if (expr.args.size != 1) return std::nullopt; @@ -30,9 +33,8 @@ std::optional> magicFunctionInstanceIsA( if (!lvalue || !tfun) return std::nullopt; - unfreeze(typeChecker.globalTypes); - TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); - freeze(typeChecker.globalTypes); + ModulePtr module = typeChecker.currentModule; + TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } @@ -62,47 +64,47 @@ struct RefinementClassFixture : BuiltinsFixture { RefinementClassFixture() { - TypeArena& arena = typeChecker.globalTypes; - NotNull scope{typeChecker.globalScope.get()}; + TypeArena& arena = frontend.globals.globalTypes; + NotNull scope{frontend.globals.globalScope.get()}; - std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(typeChecker.builtinTypes->classType) : std::nullopt; + std::optional rootSuper = std::make_optional(builtinTypes->classType); unfreeze(arena); - TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); + TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test", {}}); getMutable(vec3)->props = { - {"X", Property{typeChecker.numberType}}, - {"Y", Property{typeChecker.numberType}}, - {"Z", Property{typeChecker.numberType}}, + {"X", Property{builtinTypes->numberType}}, + {"Y", Property{builtinTypes->numberType}}, + {"Z", Property{builtinTypes->numberType}}, }; - TypeId inst = arena.addType(ClassType{"Instance", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); + TypeId inst = arena.addType(ClassType{"Instance", {}, rootSuper, std::nullopt, {}, nullptr, "Test", {}}); - TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); - TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); + TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); TypeId isA = arena.addType(FunctionType{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; getMutable(inst)->props = { - {"Name", Property{typeChecker.stringType}}, + {"Name", Property{builtinTypes->stringType}}, {"IsA", Property{isA}}, }; - TypeId folder = typeChecker.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - TypeId part = typeChecker.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId folder = frontend.globals.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test", {}}); + TypeId part = frontend.globals.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test", {}}); getMutable(part)->props = { {"Position", Property{vec3}}, }; - typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; - typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; - typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + frontend.globals.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + frontend.globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + frontend.globals.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + frontend.globals.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; - for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, ty] : frontend.globals.globalScope->exportedTypeBindings) persist(ty.type); - freeze(typeChecker.globalTypes); + freeze(frontend.globals.globalTypes); } }; } // namespace @@ -253,7 +255,7 @@ TEST_CASE_FIXTURE(Fixture, "a_and_b_or_a_and_c") CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) CHECK_EQ("boolean", toString(requireTypeAtPosition({5, 28}))); else CHECK_EQ("true", toString(requireTypeAtPosition({5, 28}))); // oh no! :( @@ -276,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK_EQ("number?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("string?", toString(requireTypeAtPosition({4, 26}))); @@ -292,29 +294,135 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") { CheckResult result = check(R"( - function f(s: any) + function f(s: any, t: unknown) if type(s) == "number" then local n = s end + if type(t) == "number" then + local n = t + end end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + // DCR changes refinements to preserve error suppression. + if (FFlag::LuauSolverV2) + CHECK_EQ("*error-type* | number", toString(requireTypeAtPosition({3, 26}))); + else + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 26}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") { CheckResult result = check(R"( - local a - assert(type(a) == "number") - local b = a + function f(a) + assert(type(a) == "number") + local b = a + return b + end )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("number", toString(requireType("b"))); + if (FFlag::LuauSolverV2) + CHECK("(a) -> a & number" == toString(requireType("f"))); + else + CHECK("(a) -> number" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_test_a_prop") +{ + CheckResult result = check(R"( + local function f(x: unknown): string? + if typeof(x) == "table" then + if typeof(x.foo) == "string" then + return x.foo + end + end + + return nil + end + )"); + + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + for (size_t i = 0; i < result.errors.size(); i++) + { + const UnknownProperty* up = get(result.errors[i]); + REQUIRE_EQ("foo", up->key); + REQUIRE_EQ("unknown", toString(up->table)); + } + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_test_a_nested_prop") +{ + CheckResult result = check(R"( + local function f(x: unknown): string? + if typeof(x) == "table" then + -- this should error, `x.foo` is an unknown property + if typeof(x.foo.bar) == "string" then + return x.foo.bar + end + end + + return nil + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + const UnknownProperty* up = get(result.errors[0]); + REQUIRE_EQ("bar", up->key); + REQUIRE_EQ("unknown", toString(up->table)); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + for (size_t i = 0; i < result.errors.size(); i++) + { + const UnknownProperty* up = get(result.errors[i]); + REQUIRE_EQ("foo", up->key); + REQUIRE_EQ("unknown", toString(up->table)); + } + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_test_a_tested_nested_prop") +{ + CheckResult result = check(R"( + local function f(x: unknown): string? + if typeof(x) == "table" then + if typeof(x.foo) == "table" and typeof(x.foo.bar) == "string" then + return x.foo.bar + end + end + + return nil + end + )"); + + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(3, result); + + for (size_t i = 0; i < result.errors.size(); i++) + { + const UnknownProperty* up = get(result.errors[i]); + REQUIRE_EQ("foo", up->key); + REQUIRE_EQ("unknown", toString(up->table)); + } + } } TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_typeguard") @@ -324,16 +432,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_ty return x end - local function g(x: any) + local function g(x: unknown) + if type(x) == "string" then + f(x) + end + end + + local function h(x: any) if type(x) == "string" then f(x) end end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK("Type 'string' could not be converted into 'number'" == toString(result.errors[0])); + CHECK(Location{{ 7, 18}, {7, 19}} == result.errors[0].location); - CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + CHECK("Type 'string' could not be converted into 'number'" == toString(result.errors[1])); + CHECK(Location{{ 13, 18}, {13, 19}} == result.errors[1].location); } TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") @@ -368,9 +486,10 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - CHECK("{| x: number |}" == toString(requireTypeAtPosition({4, 23}))); + // CLI-115281 - Types produced by refinements don't always get simplified + CHECK("{ x: number? } & { x: ~(false?) }" == toString(requireTypeAtPosition({4, 23}))); CHECK("number" == toString(requireTypeAtPosition({5, 26}))); } @@ -454,7 +573,7 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"(((string & ~"hello") | number)?)"); // a ~= "hello" @@ -480,8 +599,11 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "nil"); // a == nil :) + else + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil } TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") @@ -496,7 +618,11 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "unknown"); // a == b + else + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b } @@ -512,8 +638,11 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{ x: number }?"); // a ~= b + else + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") @@ -535,7 +664,7 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string?"); // a == b CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b @@ -574,7 +703,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_narrow_to_vector") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::LuauSolverV2) + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + else + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") @@ -597,11 +729,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_ LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" - CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" + if (FFlag::LuauSolverV2) + { + // CLI-115281 Types produced by refinements do not consistently get simplified + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" + CHECK_EQ("(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" + + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" + CHECK_EQ("(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" - CHECK_EQ("nil", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" - CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" + CHECK_EQ("nil", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" + CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" + } } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_not_to_be_string") @@ -636,8 +780,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_table") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: number |} | {| y: boolean |}", toString(requireTypeAtPosition({3, 28}))); // type(x) == "table" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "table" + if (FFlag::LuauSolverV2) + CHECK_EQ("{ x: number } | { y: boolean }", toString(requireTypeAtPosition({3, 28}))); // type(x) == "table" + else + CHECK_EQ("{| x: number |} | {| y: boolean |}", toString(requireTypeAtPosition({3, 28}))); // type(x) == "table" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "table" } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_functions") @@ -675,7 +822,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_intersection_of_ta ToStringOptions opts; opts.exhaustive = true; - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}), opts)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ x: number } & { y: number }", toString(requireTypeAtPosition({4, 28}), opts)); + else + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}), opts)); CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } @@ -713,7 +863,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::LuauSolverV2) + { + // CLI-115281 Types produced by refinements do not consistently get simplified + CHECK_EQ("{ x: number } & ~table", toString(requireTypeAtPosition({3, 28}))); + } + else + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") @@ -787,16 +943,23 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(BuiltinsFixture, "either_number_or_string") { CheckResult result = check(R"( - local function f(x: any) + local function f(x: any, y: unknown) if type(x) == "number" or type(x) == "string" then local foo = x end + if type(y) == "number" or type(y) == "string" then + local foo = y + end end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::LuauSolverV2) + CHECK_EQ("*error-type* | number | string", toString(requireTypeAtPosition({3, 28}))); + else + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({6, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") @@ -811,8 +974,11 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("{| x: true |}?", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::LuauSolverV2) + { + // CLI-115281 Types produced by refinements do not consistently get simplified + CHECK_EQ("({ x: boolean } & { x: ~(false?) })?", toString(requireTypeAtPosition({3, 28}))); + } else CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); } @@ -908,15 +1074,76 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") function f(v:any) return if typeof(v) == "number" then v else returnOne(v) end + + function g(v:unknown) + return if typeof(v) == "number" then v else returnOne(v) + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("~number", toString(requireTypeAtPosition({6, 66}))); + if (FFlag::LuauSolverV2) + { + CHECK_EQ("*error-type* | number", toString(requireTypeAtPosition({6, 49}))); + CHECK_EQ("*error-type* | ~number", toString(requireTypeAtPosition({6, 66}))); + } else + { + CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); + } + + CHECK_EQ("number", toString(requireTypeAtPosition({10, 49}))); + if (FFlag::LuauSolverV2) + CHECK_EQ("~number", toString(requireTypeAtPosition({10, 66}))); + else + CHECK_EQ("unknown", toString(requireTypeAtPosition({10, 66}))); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_while_expression") +{ + CheckResult result = check(R"( + function f(v:string?) + while v do + local foo = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_while_expression") +{ + CheckResult result = check(R"( + function f(v:string?) + while not v do + local foo = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_while_a_is_not_number_or_string") +{ + CheckResult result = check(R"( + local function f(a: string | number | boolean) + while type(a) ~= "number" and type(a) ~= "string" do + local foo = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") @@ -997,11 +1224,17 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + if (FFlag::LuauSolverV2) + { + // CLI-115281 Types produced by refinements do not consistently get simplified + CHECK("{ tag: \"exists\", x: string } & { x: ~(false?) }" == toString(requireTypeAtPosition({5, 28}))); + CHECK("({ tag: \"exists\", x: string } & { x: ~~(false?) }) | { tag: \"missing\", x: nil }" == toString(requireTypeAtPosition({7, 28}))); + } else + { + CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "discriminate_tag") @@ -1022,16 +1255,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); - } - else - { - CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); - } + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") @@ -1052,16 +1277,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); - } - else - { - CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); - } + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") @@ -1142,12 +1359,24 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); + + if (FFlag::LuauSolverV2) + { + CHECK(R"({ tag: "Part", x: Part })" == toString(requireTypeAtPosition({5, 28}))); + CHECK(R"({ tag: "Folder", x: Folder })" == toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { + // CLI-115286 - Refining via type(x) == 'vector' does not work in the new solver + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function f(vec) local X, Y, Z = vec.X, vec.Y, vec.Z @@ -1221,7 +1450,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_but_the_discriminant_type LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); CHECK_EQ("Instance | Vector3 | number | string", toString(requireTypeAtPosition({5, 28}))); @@ -1271,6 +1500,9 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_from_subclasses_of_instance_or TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { + // CLI-117136 - this code doesn't finish constraint solving and has blocked types in the output + if (FFlag::LuauSolverV2) + return; CheckResult result = check(R"( --!nonstrict @@ -1285,7 +1517,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK_EQ("Folder & Instance & {- -}", toString(requireTypeAtPosition({5, 28}))); CHECK_EQ("(~Folder | ~Instance) & {- -} & never", toString(requireTypeAtPosition({7, 28}))); @@ -1335,6 +1567,10 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_folder_or_part_w TEST_CASE_FIXTURE(RefinementClassFixture, "isa_type_refinement_must_be_known_ahead_of_time") { + // CLI-115087 - The new solver does not consistently combine tables with + // class types when they appear in the upper bounds of a free type. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function f(x): Instance if x:IsA("Folder") then @@ -1355,6 +1591,9 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "isa_type_refinement_must_be_known_ahe TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { + // CLI-117135 - RefinementTests.x_is_not_instance_or_else_not_part not correctly applying refinements to a function parameter + if (FFlag::LuauSolverV2) + return; CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) ~= "Instance" or not x:IsA("Part") then @@ -1402,7 +1641,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); CHECK_EQ("~string", toString(requireTypeAtPosition({5, 28}))); @@ -1448,6 +1687,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_thread") CHECK_EQ("number", toString(requireTypeAtPosition({5, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_buffer") +{ + CheckResult result = check(R"( + local function f(x: number | buffer) + if typeof(x) == "buffer" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("buffer", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({5, 28}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") { CheckResult result = check(R"( @@ -1510,14 +1766,7 @@ local _ = _ ~= _ or _ or _ end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Without a realistic motivating case, it's hard to tell if it's important for this to work without errors. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); - } - else - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length") @@ -1530,7 +1779,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("table", toString(requireTypeAtPosition({3, 29}))); @@ -1552,7 +1801,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_clone_it") end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { LUAU_REQUIRE_NO_ERRORS(result); } @@ -1564,6 +1813,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_clone_it") TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage") { + // CLI-117134 - Applying a refinement causes an optional value access error. + if (FFlag::LuauSolverV2) + return; CheckResult result = check(R"( type Id = T @@ -1597,7 +1849,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) CHECK_EQ("Instance & ~Part", toString(requireTypeAtPosition({7, 28}))); else CHECK_EQ("Instance", toString(requireTypeAtPosition({7, 28}))); @@ -1605,8 +1857,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; - CheckResult result = check(R"( foo = { bar = 5 :: number? } @@ -1615,8 +1865,12 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ("*error-type*", toString(requireTypeAtPosition({4, 30}))); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ("*error-type* | buffer | class | function | number | string | table | thread | true", toString(requireTypeAtPosition({4, 30}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") @@ -1639,9 +1893,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_ end if typeof(s) == "nil" then - local foo = s + local foo = s -- line 18 else - local foo = s + local foo = s -- line 20 end end )"); @@ -1654,9 +1908,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_ CHECK_EQ("nil", toString(requireTypeAtPosition({12, 28}))); CHECK_EQ("string", toString(requireTypeAtPosition({14, 28}))); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { - CHECK_EQ("never", toString(requireTypeAtPosition({18, 28}))); + // CLI-115281 - Types produced by refinements don't always get simplified + CHECK_EQ("nil & string", toString(requireTypeAtPosition({18, 28}))); CHECK_EQ("string", toString(requireTypeAtPosition({20, 28}))); } else @@ -1743,10 +1998,427 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_annotations_arent_relevant_when_doing_d LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("nil", toString(requireTypeAtPosition({8, 28}))); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("never", toString(requireTypeAtPosition({9, 28}))); + if (FFlag::LuauSolverV2) + { + // CLI-115478 - This should be never + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 28}))); + } else CHECK_EQ("nil", toString(requireTypeAtPosition({9, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "function_call_with_colon_after_refining_not_to_be_nil") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + --!strict + export type Observer = { + read complete: ((self: Observer) -> ())?, + } + + local function _f(handler: Observer) + assert(handler.complete ~= nil) + handler:complete() -- incorrectly gives Value of type '((Observer) -> ())?' could be nil + handler.complete(handler) -- works fine, both forms should avoid the error + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "refinements_should_not_affect_assignment") +{ + CheckResult result = check(R"( + local a: unknown = true + if a == true then + a = 'not even remotely similar to a boolean' + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refinements_should_preserve_error_suppression") +{ + CheckResult result = check(R"( + local a: any = {} + local b + if typeof(a) == "table" then + b = a.field + end + )"); + + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "many_refinements_on_val") +{ + CheckResult result = check(R"( + local function is_nan(val: any): boolean + return type(val) == "number" and val ~= val + end + + local function is_js_boolean(val: any): boolean + return not not val and val ~= 0 and val ~= "" and not is_nan(val) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(any) -> boolean", toString(requireType("is_nan"))); + CHECK_EQ("(any) -> boolean", toString(requireType("is_js_boolean"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + // this test is DCR-only as an instance of DCR fixing a bug in the old solver + + CheckResult result = check(R"( + local function f(a: unknown) + if typeof(a) == "table" then + for i, v in a do + return i, v + end + end + + error("") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(unknown) -> (unknown, unknown)", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "conditional_refinement_should_stay_error_suppressing") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local function test(element: any?) + if element then + local owner = element._owner + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "globals_can_be_narrowed_too") +{ + CheckResult result = check(R"( + if typeof(string) == 'string' then + local foo = string + end + )"); + + if (FFlag::LuauSolverV2) + { + // CLI-114134 + CHECK("string & typeof(string)" == toString(requireTypeAtPosition(Position{2, 24}))); + } + else + CHECK("never" == toString(requireTypeAtPosition(Position{2, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_isindexkey_refine_conjunction") +{ + CheckResult result = check(R"( + local function isIndexKey(k, contiguousLength) + return type(k) == "number" + and k <= contiguousLength -- nothing out of bounds + and 1 <= k -- nothing illegal for array indices + and math.floor(k) == k -- no float keys + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_isindexkey_refine_conjunction_variant") +{ + CheckResult result = check(R"( + local function isIndexKey(k, contiguousLength: number) + return type(k) == "number" + and k <= contiguousLength -- nothing out of bounds + and 1 <= k -- nothing illegal for array indices + and math.floor(k) == k -- no float keys + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ex") +{ + CheckResult result = check(R"( +local function f(x: string | number) + if typeof((x)) == "string" then + local y = x + end +end +)"); + TypeId t = requireTypeAtPosition({3, 18}); + CHECK("string" == toString(t)); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "mutate_prop_of_some_refined_symbol") +{ + CheckResult result = check(R"( + local function instances(): {Instance} error("") end + local function vec3(x, y, z): Vector3 error("") end + + for _, object in ipairs(instances()) do + if object:IsA("Part") then + object.Position = vec3(1, 2, 3) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "mutate_prop_of_some_refined_symbol_2") +{ + CheckResult result = check(R"( + type Result = never + | { tag: "ok", value: T } + | { tag: "err", error: E } + + local function results(): {Result} error("") end + + for _, res in ipairs(results()) do + if res.tag == "ok" then + res.value = 7 + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_t_after_return_references_all_reachable_points") +{ + CheckResult result = check(R"( + local t = {} + + local function f(k: string) + if t[k] ~= nil then + return + end + + t[k] = 5 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{ [string]: number }", toString(requireTypeAtPosition({8, 12}), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "long_disjunction_of_refinements_should_not_trip_recursion_counter") +{ + CHECK_NOTHROW(check(R"( +function(obj) + if script.Parent.SeatNumber.Value == "1D" or + script.Parent.SeatNumber.Value == "2D" or + script.Parent.SeatNumber.Value == "3D" or + script.Parent.SeatNumber.Value == "4D" or + script.Parent.SeatNumber.Value == "5D" or + script.Parent.SeatNumber.Value == "6D" or + script.Parent.SeatNumber.Value == "7D" or + script.Parent.SeatNumber.Value == "8D" or + script.Parent.SeatNumber.Value == "9D" or + script.Parent.SeatNumber.Value == "10D" or + script.Parent.SeatNumber.Value == "11D" or + script.Parent.SeatNumber.Value == "12D" or + script.Parent.SeatNumber.Value == "13D" or + script.Parent.SeatNumber.Value == "14D" or + script.Parent.SeatNumber.Value == "15D" or + script.Parent.SeatNumber.Value == "16D" or + script.Parent.SeatNumber.Value == "1C" or + script.Parent.SeatNumber.Value == "2C" or + script.Parent.SeatNumber.Value == "3C" or + script.Parent.SeatNumber.Value == "4C" or + script.Parent.SeatNumber.Value == "5C" or + script.Parent.SeatNumber.Value == "6C" or + script.Parent.SeatNumber.Value == "7C" or + script.Parent.SeatNumber.Value == "8C" or + script.Parent.SeatNumber.Value == "9C" or + script.Parent.SeatNumber.Value == "10C" or + script.Parent.SeatNumber.Value == "11C" or + script.Parent.SeatNumber.Value == "12C" or + script.Parent.SeatNumber.Value == "13C" or + script.Parent.SeatNumber.Value == "14C" or + script.Parent.SeatNumber.Value == "15C" or + script.Parent.SeatNumber.Value == "16C" then +end +)")); +} + +TEST_CASE_FIXTURE(Fixture, "more_complex_long_disjunction_of_refinements_shouldnt_trip_ice") +{ + CHECK_NOTHROW(check(R"( +script:connect(function(obj) + if script.Parent.SeatNumber.Value == "1D" or + script.Parent.SeatNumber.Value == "2D" or + script.Parent.SeatNumber.Value == "3D" or + script.Parent.SeatNumber.Value == "4D" or + script.Parent.SeatNumber.Value == "5D" or + script.Parent.SeatNumber.Value == "6D" or + script.Parent.SeatNumber.Value == "7D" or + script.Parent.SeatNumber.Value == "8D" or + script.Parent.SeatNumber.Value == "9D" or + script.Parent.SeatNumber.Value == "10D" or + script.Parent.SeatNumber.Value == "11D" or + script.Parent.SeatNumber.Value == "12D" or + script.Parent.SeatNumber.Value == "13D" or + script.Parent.SeatNumber.Value == "14D" or + script.Parent.SeatNumber.Value == "15D" or + script.Parent.SeatNumber.Value == "16D" or + script.Parent.SeatNumber.Value == "1C" or + script.Parent.SeatNumber.Value == "2C" or + script.Parent.SeatNumber.Value == "3C" or + script.Parent.SeatNumber.Value == "4C" or + script.Parent.SeatNumber.Value == "5C" or + script.Parent.SeatNumber.Value == "6C" or + script.Parent.SeatNumber.Value == "7C" or + script.Parent.SeatNumber.Value == "8C" or + script.Parent.SeatNumber.Value == "9C" or + script.Parent.SeatNumber.Value == "10C" or + script.Parent.SeatNumber.Value == "11C" or + script.Parent.SeatNumber.Value == "12C" or + script.Parent.SeatNumber.Value == "13C" or + script.Parent.SeatNumber.Value == "14C" or + script.Parent.SeatNumber.Value == "15C" or + script.Parent.SeatNumber.Value == "16C" then + end) +)")); +} + +TEST_CASE_FIXTURE(Fixture, "refinements_should_avoid_building_up_big_intersect_families") +{ + CHECK_NOTHROW(check(R"( +script:connect(function(obj) + if script.Parent.SeatNumber.Value == "1D" or script.Parent.SeatNumber.Value == "2D" or script.Parent.SeatNumber.Value == "3D" or script.Parent.SeatNumber.Value == "4D" or script.Parent.SeatNumber.Value == "5D" or script.Parent.SeatNumber.Value == "6D" or script.Parent.SeatNumber.Value == "7D" or script.Parent.SeatNumber.Value == "8D" or script.Parent.SeatNumber.Value == "9D" or script.Parent.SeatNumber.Value == "10D" or script.Parent.SeatNumber.Value == "11D" or script.Parent.SeatNumber.Value == "12D" or script.Parent.SeatNumber.Value == "13D" or script.Parent.SeatNumber.Value == "14D" or script.Parent.SeatNumber.Value == "15D" or script.Parent.SeatNumber.Value == "16D" or script.Parent.SeatNumber.Value == "1C" or script.Parent.SeatNumber.Value == "2C" or script.Parent.SeatNumber.Value == "3C" or script.Parent.SeatNumber.Value == "4C" or script.Parent.SeatNumber.Value == "5C" or script.Parent.SeatNumber.Value == "6C" or script.Parent.SeatNumber.Value == "7C" or script.Parent.SeatNumber.Value == "8C" or script.Parent.SeatNumber.Value == "9C" or script.Parent.SeatNumber.Value == "10C" or script.Parent.SeatNumber.Value == "11C" or script.Parent.SeatNumber.Value == "12C" or script.Parent.SeatNumber.Value == "13C" or script.Parent.SeatNumber.Value == "14C" or script.Parent.SeatNumber.Value == "15C" or script.Parent.SeatNumber.Value == "16C" then + if p.Name == script.Parent.Parent.Parent.Parent.Parent.Parent.MainParts.CD.SurfaceGui[script.Parent.SeatNumber.Value].Player.Value or script.Parent.Parent.Parent.Parent.Parent.Parent.MainParts.CD.SurfaceGui[script.Parent.SeatNumber.Value].Player.Value == "" then + else + if script.Parent:FindFirstChild("SeatWeld") then + end + end + else + if p.Name == script.Parent.Parent.Parent.Parent.Parent.Parent.MainParts.AB.SurfaceGui[script.Parent.SeatNumber.Value].Player.Value or script.Parent.Parent.Parent.Parent.Parent.Parent.MainParts.AB.SurfaceGui[script.Parent.SeatNumber.Value].Player.Value == "" then + print("Allowed") + else + if script.Parent:FindFirstChild("SeatWeld") then + end + end + end +end) +)")); +} + +TEST_CASE_FIXTURE(Fixture, "refinements_table_intersection_limits" * doctest::timeout(0.5)) +{ + CheckResult result = check(R"( +--!strict +type Dir = { + a: number?, b: number?, c: number?, d: number?, e: number?, f: number?, + g: number?, h: number?, i: number?, j: number?, k: number?, l: number?, + m: number?, n: number?, o: number?, p: number?, q: number?, r: number?, +} + +local function test(dirs: {Dir}) + for k, dir in dirs + local success, message = pcall(function() + assert(dir.a == nil or type(dir.a) == "number") + assert(dir.b == nil or type(dir.b) == "number") + assert(dir.c == nil or type(dir.c) == "number") + assert(dir.d == nil or type(dir.d) == "number") + assert(dir.e == nil or type(dir.e) == "number") + assert(dir.f == nil or type(dir.f) == "number") + assert(dir.g == nil or type(dir.g) == "number") + assert(dir.h == nil or type(dir.h) == "number") + assert(dir.i == nil or type(dir.i) == "number") + assert(dir.j == nil or type(dir.j) == "number") + assert(dir.k == nil or type(dir.k) == "number") + assert(dir.l == nil or type(dir.l) == "number") + assert(dir.m == nil or type(dir.m) == "number") + assert(dir.n == nil or type(dir.n) == "number") + assert(dir.o == nil or type(dir.o) == "number") + assert(dir.p == nil or type(dir.p) == "number") + assert(dir.q == nil or type(dir.q) == "number") + assert(dir.r == nil or type(dir.r) == "number") + assert(dir.t == nil or type(dir.t) == "number") + assert(dir.u == nil or type(dir.u) == "number") + assert(dir.v == nil or type(dir.v) == "number") + local checkpoint = dir + + checkpoint.w = 1 + end) + assert(success) + end +end + )"); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_refinement") +{ + CheckResult result = check(R"( + local function f(x: Instance | Vector3) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_error") +{ + CheckResult result = check(R"( + local function f(x: Part) + if typeof(x) == "Instance" then + local foo : Folder = x + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_isa_refinement") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) == "Instance" then + local foo = x + if foo:IsA("Folder") then + local bar = foo + end + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 32}))); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 23e49f581..e2efc8fbd 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -3,12 +3,27 @@ #include "Fixture.h" #include "doctest.h" -#include "Luau/BuiltinDefinitions.h" using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2); + TEST_SUITE_BEGIN("TypeSingletons"); +TEST_CASE_FIXTURE(Fixture, "function_args_infer_singletons") +{ + CheckResult result = check(R"( +--!strict +type Phase = "A" | "B" | "C" +local function f(e : Phase) : number + return 0 +end +local e = f("B") +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "bool_singletons") { CheckResult result = check(R"( @@ -29,6 +44,20 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "string_singleton_function_call") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local x = "a" + function f(x: "a") end + f(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") { CheckResult result = check(R"( @@ -124,15 +153,25 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - function f(a, b) end - local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) - g(true, 37) + function f(g: ((true, string) -> ()) & ((false, number) -> ())) + g(true, 37) + end )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); - CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); + if (FFlag::LuauSolverV2) + { + CHECK_EQ("None of the overloads for function that accept 2 arguments are compatible.", toString(result.errors[0])); + CHECK_EQ("Available overloads: (true, string) -> (); and (false, number) -> ()", toString(result.errors[1])); + } + else + { + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") @@ -155,8 +194,14 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", - toString(result.errors[0])); + + if (FFlag::LuauSolverV2) + CHECK("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'" == toString(result.errors[0])); + else + CHECK_EQ( + "Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", + toString(result.errors[0]) + ); } TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") @@ -205,11 +250,23 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } type Animal = Dog | Cat - local a : Animal = { tag = "Cat", meows = true } + local a: Animal = { tag = "Cat", meows = true } a.tag = "Dog" )"); LUAU_REQUIRE_ERRORS(result); + + if (FFlag::LuauSolverV2) + { + CannotAssignToNever* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK(builtinTypes->stringType == tm->rhsType); + CHECK(CannotAssignToNever::Reason::PropertyNarrowed == tm->reason); + REQUIRE(tm->cause.size() == 2); + CHECK("\"Dog\"" == toString(tm->cause[0])); + CHECK("\"Cat\"" == toString(tm->cause[1])); + } } TEST_CASE_FIXTURE(Fixture, "table_has_a_boolean") @@ -277,6 +334,27 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "indexer_can_be_union_of_singletons") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type Target = "A" | "B" + + type Test = {[Target]: number} + + local test: Test = {} + + test.A = 2 + test.C = 4 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(8 == result.errors[0].location.begin.line); +} + TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { CheckResult result = check(R"( @@ -286,8 +364,18 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", - toString(result.errors[0])); + if (FFlag::LuauSolverV2) + CHECK( + "Type\n" + " '{ [\"\\n\"]: number }'\n" + "could not be converted into\n" + " '{ [\"<>\"]: number }'" == toString(result.errors[0]) + ); + else + CHECK_EQ( + R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + toString(result.errors[0]) + ); } TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") @@ -301,10 +389,16 @@ local a: Animal = { tag = 'cat', cafood = 'something' } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 'a' could not be converted into 'Cat | Dog' + if (FFlag::LuauSolverV2) + CHECK("Type '{ cafood: string, tag: \"cat\" }' could not be converted into 'Cat | Dog'" == toString(result.errors[0])); + else + { + const std::string expected = R"(Type 'a' could not be converted into 'Cat | Dog' caused by: - None of the union options are compatible. For example: Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')", - toString(result.errors[0])); + None of the union options are compatible. For example: +Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") @@ -318,33 +412,38 @@ local a: Result = { success = false, result = 'something' } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 'a' could not be converted into 'Bad | Good' + if (FFlag::LuauSolverV2) + CHECK("Type '{ result: string, success: boolean }' could not be converted into 'Bad | Good'" == toString(result.errors[0])); + else + { + const std::string expected = R"(Type 'a' could not be converted into 'Bad | Good' caused by: - None of the union options are compatible. For example: Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')", - toString(result.errors[0])); + None of the union options are compatible. For example: +Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") { ScopedFastFlag sff[] = { - {"DebugLuauDeferredConstraintResolution", true}, + {FFlag::LuauSolverV2, true}, }; - CheckResult result = check(R"( type Ok = {success: true, result: T} type Err = {success: false, error: T} type Result = Ok | Err local a : Result = {success = false, result = "hotdogs"} - local b : Result = {success = true, result = "hotdogs"} + -- local b : Result = {success = true, result = "hotdogs"} )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - const std::string expectedError = "Type 'a' could not be converted into 'Err | Ok'\n" - "caused by:\n" - " None of the union options are compatible. For example: Table type 'a'" - " not compatible with type 'Err' because the former is missing field 'error'"; + const std::string expectedError = R"(Type + '{ result: string, success: boolean }' +could not be converted into + 'Err | Ok')"; CHECK(toString(result.errors[0]) == expectedError); } @@ -364,6 +463,8 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function foo(f, x) if x == "hi" then @@ -373,7 +474,7 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si end )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_CHECK_NO_ERRORS(result); CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 18}))); // should be ((string) -> a..., string) -> () but needs lower bounds calculation @@ -382,6 +483,8 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function foo(f, x): "hello"? -- anyone there? return if x == "hi" @@ -405,7 +508,11 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireType("copy"))); + + if (FFlag::LuauSolverV2) + CHECK_EQ(R"("foo")", toString(requireType("copy"))); + else + CHECK_EQ("string", toString(requireType("copy"))); } TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") @@ -513,4 +620,24 @@ TEST_CASE_FIXTURE(Fixture, "no_widening_from_callsites") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "singletons_stick_around_under_assignment") +{ + CheckResult result = check(R"( + type Foo = { + kind: "Foo", + } + + local foo = (nil :: any) :: Foo + + print(foo.kind == "Bar") -- type of equality refines to `false` + local kind = foo.kind + print(kind == "Bar") -- type of equality refines to `false` + )"); + + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2a87f0e3b..b2510fe04 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -8,20 +8,58 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" #include using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) -LUAU_FASTFLAG(LuauDontExtendUnsealedRValueTables) +LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering) +LUAU_FASTFLAG(LuauAcceptIndexingTableUnionsIntersections) + +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) TEST_SUITE_BEGIN("TableTests"); +TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_shouldnt_seal_table_in_len_function_fn") +{ + if (!FFlag::LuauSolverV2) + return; + CheckResult result = check(R"( +local t = {} +for i = #t, 2, -1 do + t[i] = t[i + 1] +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + const TableType* tType = get(requireType("t")); + REQUIRE(tType != nullptr); + REQUIRE(tType->indexer); + CHECK_EQ(tType->indexer->indexType, builtinTypes->numberType); + CHECK_EQ(follow(tType->indexer->indexResultType), builtinTypes->unknownType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "LUAU_ASSERT_arg_exprs_doesnt_trigger_assert") +{ + CheckResult result = check(R"( +local FadeValue = {} +function FadeValue.new(finalCallback) + local self = setmetatable({}, FadeValue) + self.finalCallback = finalCallback + return self +end + +function FadeValue:destroy() + self.finalCallback() + self.finalCallback = nil +end +)"); +} + TEST_CASE_FIXTURE(Fixture, "basic") { CheckResult result = check("local t = {foo = \"bar\", baz = 9, quux = nil}"); @@ -32,20 +70,23 @@ TEST_CASE_FIXTURE(Fixture, "basic") std::optional fooProp = get(tType->props, "foo"); REQUIRE(bool(fooProp)); - CHECK_EQ(PrimitiveType::String, getPrimitiveType(fooProp->type)); + CHECK_EQ(PrimitiveType::String, getPrimitiveType(fooProp->type())); std::optional bazProp = get(tType->props, "baz"); REQUIRE(bool(bazProp)); - CHECK_EQ(PrimitiveType::Number, getPrimitiveType(bazProp->type)); + CHECK_EQ(PrimitiveType::Number, getPrimitiveType(bazProp->type())); std::optional quuxProp = get(tType->props, "quux"); REQUIRE(bool(quuxProp)); - CHECK_EQ(PrimitiveType::NilType, getPrimitiveType(quuxProp->type)); + CHECK_EQ(PrimitiveType::NilType, getPrimitiveType(quuxProp->type())); } TEST_CASE_FIXTURE(Fixture, "augment_table") { - CheckResult result = check("local t = {} t.foo = 'bar'"); + CheckResult result = check(R"( + local t = {} + t.foo = 'bar' + )"); LUAU_REQUIRE_NO_ERRORS(result); const TableType* tType = get(requireType("t")); @@ -66,12 +107,41 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") REQUIRE(tType != nullptr); REQUIRE(tType->props.find("p") != tType->props.end()); - const TableType* pType = get(tType->props["p"].type); + const TableType* pType = get(tType->props["p"].type()); REQUIRE(pType != nullptr); CHECK("{ p: { foo: string } }" == toString(requireType("t"), {true})); } +TEST_CASE_FIXTURE(Fixture, "assign_key_at_index_expr") +{ + CheckResult result = check(R"( + function f(t: {[string]: number}) + t["hello"] = 1 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We had a bug where we forgot to record the astType of this particular node. + CHECK("string" == toString(requireTypeAtPosition({2, 19}))); +} + +TEST_CASE_FIXTURE(Fixture, "index_expression_is_checked_against_the_indexer_type") +{ + CheckResult result = check(R"( + function f(t: {[boolean]: number}) + t["hello"] = 15 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::LuauSolverV2) + CHECK_MESSAGE(get(result.errors[0]), "Expected CannotExtendTable but got " << toString(result.errors[0])); + else + CHECK(get(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { CheckResult result = check(R"( @@ -93,7 +163,10 @@ TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") // TODO: better, more robust comparison of type vars auto s = toString(error->tableType, ToStringOptions{/*exhaustive*/ true}); - CHECK_EQ(s, "{| prop: number |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(s, "{ prop: number }"); + else + CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->context, CannotExtendTable::Property); } @@ -134,6 +207,24 @@ TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_table_prop") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "report_sensible_error_when_adding_a_value_to_a_nonexistent_prop") +{ + CheckResult result = check(R"( + local t = {} + t.foo[1] = 'one' + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + INFO(result.errors[0]); + + UnknownProperty* err = get(result.errors[0]); + REQUIRE(err); + + CHECK("t" == toString(err->table)); + CHECK("foo" == err->key); +} + TEST_CASE_FIXTURE(Fixture, "function_calls_can_produce_tables") { CheckResult result = check("function get_table() return {prop=999} end get_table().prop = 0"); @@ -160,7 +251,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function") std::optional fooProp = get(tableType->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type())); REQUIRE(methodType != nullptr); } @@ -174,7 +265,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") std::optional uProp = get(tableType->props, "U"); REQUIRE(bool(uProp)); - TypeId uType = uProp->type; + TypeId uType = uProp->type(); const TableType* uTable = get(uType); REQUIRE(uTable != nullptr); @@ -182,7 +273,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") std::optional fooProp = get(uTable->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type())); REQUIRE(methodType != nullptr); std::vector methodArgs = flatten(methodType->argTypes).first; @@ -195,20 +286,40 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") TEST_CASE_FIXTURE(Fixture, "call_method") { - CheckResult result = check("local T = {} T.x = 0 function T:method() return self.x end local a = T:method()"); + CheckResult result = check(R"( + local T = {} + T.x = 0 + function T:method() + return self.x + end + local a = T:method() + )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); } TEST_CASE_FIXTURE(Fixture, "call_method_with_explicit_self_argument") { - CheckResult result = check("local T = {} T.x = 0 function T:method() return self.x end local a = T.method(T)"); + CheckResult result = check(R"( + local T = {} + T.x = 0 + + function T:method() + return self.x + end + + local a = T.method(T) + )"); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") { + // CLI-114792 Dot vs colon warnings aren't in the new solver yet. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local T = {} T.x = 0 @@ -218,9 +329,14 @@ TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") local a = T.method() )"); - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return nullptr != get(e); - }); + auto it = std::find_if( + result.errors.begin(), + result.errors.end(), + [](const TypeError& e) + { + return nullptr != get(e); + } + ); REQUIRE(it != result.errors.end()); } @@ -254,6 +370,9 @@ TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon_but_correctly") TEST_CASE_FIXTURE(Fixture, "used_colon_instead_of_dot") { + // CLI-114792 Dot vs colon warnings aren't in the new solver yet. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local T = {} T.x = 0 @@ -263,42 +382,22 @@ TEST_CASE_FIXTURE(Fixture, "used_colon_instead_of_dot") local a = T:method() )"); - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return nullptr != get(e); - }); + auto it = std::find_if( + result.errors.begin(), + result.errors.end(), + [](const TypeError& e) + { + return nullptr != get(e); + } + ); REQUIRE(it != result.errors.end()); } -#if 0 -TEST_CASE_FIXTURE(Fixture, "open_table_unification") -{ - CheckResult result = check(R"( - function foo(o) - print(o.foo) - print(o.bar) - end - - local a = {} - a.foo = 9 - - local b = {} - b.foo = 0 - - if random() then - b = a - end - - b.bar = '99' - - foo(a) - foo(b) - )"); - LUAU_REQUIRE_NO_ERRORS(result); -} -#endif - TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { + // CLI-114792 We don't report MissingProperties in many places where the old solver does. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local a = {} a.x = 99 @@ -312,7 +411,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; MissingProperties* error = get(err); - REQUIRE(error != nullptr); + REQUIRE_MESSAGE(error != nullptr, "Expected MissingProperties but got " << toString(err)); REQUIRE(error->properties.size() == 1); CHECK_EQ("y", error->properties[0]); @@ -351,7 +450,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_3") CHECK(arg0Table->props.count("baz")); } -TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") +TEST_CASE_FIXTURE(Fixture, "table_param_width_subtyping_1") { CheckResult result = check(R"( function foo(o) @@ -366,13 +465,13 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_param_width_subtyping_2") { CheckResult result = check(R"( --!strict function foo(o) - local a = o.bar - local b = o.baz + string.lower(o.bar) + string.lower(o.baz) end foo({bar='bar'}) @@ -380,14 +479,26 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* error = get(result.errors[0]); - REQUIRE(error != nullptr); - REQUIRE(error->properties.size() == 1); + // CLI 114792 We don't report MissingProperties in many places where the old solver does + if (FFlag::LuauSolverV2) + { + TypeMismatch* error = get(result.errors[0]); + REQUIRE_MESSAGE(error != nullptr, "Expected TypeMismatch but got " << toString(result.errors[0])); + + CHECK("{ read bar: string }" == toString(error->givenType)); + CHECK("{ read bar: string, read baz: string }" == toString(error->wantedType)); + } + else + { + MissingProperties* error = get(result.errors[0]); + REQUIRE_MESSAGE(error != nullptr, "Expected MissingProperties but got " << toString(result.errors[0])); + REQUIRE(error->properties.size() == 1); - CHECK_EQ("baz", error->properties[0]); + CHECK_EQ("baz", error->properties[0]); + } } -TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") +TEST_CASE_FIXTURE(Fixture, "table_param_width_subtyping_3") { CheckResult result = check(R"( local T = {} @@ -399,73 +510,34 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeError& err = result.errors[0]; - MissingProperties* error = get(err); - REQUIRE(error != nullptr); - REQUIRE(error->properties.size() == 1); - - CHECK_EQ("baz", error->properties[0]); - - // TODO(rblanckaert): Revist when we can bind self at function creation time - /* - CHECK_EQ(err->location, - (Location{ Position{4, 22}, Position{4, 30} }) - ); - */ - - CHECK_EQ(err.location, (Location{Position{6, 8}, Position{6, 9}})); -} - -#if 0 -TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") -{ - CheckResult result = check(R"( - function id(x) - return x - end - - function foo(o) - id(o.x) - id(o.y) - return o - end - - local a = {x=55, y=nil, w=3.14159} - local b = {} - b.x = 1 - b.y = 'hello' - b.z = 'something extra!' - - local q = foo(a) -- line 17 - local w = foo(b) -- line 18 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - for (const auto& e : result.errors) - std::cout << "Error: " << e << std::endl; - TypeId qType = requireType("q"); - const TableType* qTable = get(qType); - REQUIRE(qType != nullptr); + CHECK(result.errors[0].location == Location{Position{6, 8}, Position{6, 9}}); - CHECK(qTable->props.find("x") != qTable->props.end()); - CHECK(qTable->props.find("y") != qTable->props.end()); - CHECK(qTable->props.find("z") == qTable->props.end()); - CHECK(qTable->props.find("w") != qTable->props.end()); - - TypeId wType = requireType("w"); - const TableType* wTable = get(wType); - REQUIRE(wTable != nullptr); - - CHECK(wTable->props.find("x") != wTable->props.end()); - CHECK(wTable->props.find("y") != wTable->props.end()); - CHECK(wTable->props.find("z") != wTable->props.end()); - CHECK(wTable->props.find("w") == wTable->props.end()); + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors[0]) == "Type 'T' could not be converted into '{ read baz: unknown }'"); + else + { + TypeError& err = result.errors[0]; + MissingProperties* error = get(err); + REQUIRE_MESSAGE(error != nullptr, "Expected MissingProperties but got " << toString(err)); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("baz", error->properties[0]); + + // TODO(rblanckaert): Revist when we can bind self at function creation time + /* + CHECK_EQ(err->location, + (Location{ Position{4, 22}, Position{4, 30} }) + ); + */ + } } -#endif TEST_CASE_FIXTURE(Fixture, "table_unification_4") { + // CLI-114134 - Use egraphs to simplify types better. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function foo(o) if o.prop then @@ -494,6 +566,9 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { + // CLI-114872 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict local t = { u = {} } @@ -510,6 +585,9 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignmen TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call") { + // CLI-114873 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict function get(x) return x.opts["MYOPT"] end @@ -519,16 +597,8 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ local x = get(t) )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("x"))); - } - else - { - LUAU_REQUIRE_ERRORS(result); - // CHECK_EQ("number?", toString(requireType("x"))); - } + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); } TEST_CASE_FIXTURE(Fixture, "width_subtyping") @@ -576,8 +646,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_array") REQUIRE(bool(ttv->indexer)); - CHECK_EQ(*ttv->indexer->indexType, *typeChecker.numberType); - CHECK_EQ(*ttv->indexer->indexResultType, *typeChecker.stringType); + CHECK_EQ(*ttv->indexer->indexType, *builtinTypes->numberType); + CHECK_EQ(*ttv->indexer->indexResultType, *builtinTypes->stringType); } /* This is a bit weird. @@ -621,23 +691,29 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionType* ftv = get(requireType("swap")); - REQUIRE(ftv != nullptr); + if (FFlag::LuauSolverV2) + CHECK("({unknown}) -> ()" == toString(requireType("swap"))); + else + { + const FunctionType* ftv = get(requireType("swap")); + REQUIRE(ftv != nullptr); - std::vector argVec = flatten(ftv->argTypes).first; + std::vector argVec = flatten(ftv->argTypes).first; - REQUIRE_EQ(1, argVec.size()); + REQUIRE_EQ(1, argVec.size()); - const TableType* ttv = get(follow(argVec[0])); - REQUIRE(ttv != nullptr); + const TableType* ttv = get(follow(argVec[0])); + REQUIRE(ttv != nullptr); - REQUIRE(bool(ttv->indexer)); + REQUIRE(bool(ttv->indexer)); - const TableIndexer& indexer = *ttv->indexer; + const TableIndexer& indexer = *ttv->indexer; - REQUIRE("number" == toString(indexer.indexType)); + REQUIRE("number" == toString(indexer.indexType)); - REQUIRE(nullptr != get(follow(indexer.indexResultType))); + TypeId indexResultType = follow(indexer.indexResultType); + REQUIRE_MESSAGE(get(indexResultType), "Expected generic but got " << toString(indexResultType)); + } } TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") @@ -685,8 +761,15 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_array_like_table") REQUIRE(bool(ttv->indexer)); const TableIndexer& indexer = *ttv->indexer; - CHECK_EQ(*typeChecker.numberType, *indexer.indexType); - CHECK_EQ(*typeChecker.stringType, *indexer.indexResultType); + CHECK_EQ(*builtinTypes->numberType, *indexer.indexType); + + if (FFlag::LuauSolverV2) + { + // CLI-114134 - Use egraphs to simplify types + CHECK("string | string | string" == toString(indexer.indexResultType)); + } + else + CHECK_EQ(*builtinTypes->stringType, *indexer.indexResultType); } TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") @@ -719,11 +802,17 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") CHECK(bool(retType->indexer)); const TableIndexer& indexer = *retType->indexer; - CHECK_EQ("{| __name: string |}", toString(indexer.indexType)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ __name: string }", toString(indexer.indexType)); + else + CHECK_EQ("{| __name: string |}", toString(indexer.indexType)); } TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_variable_type_and_unifiable") { + // This code is totally different in the new solver. We instead create a new type state for t2. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local t1: { [string]: string } = {} local t2 = { "bar" } @@ -736,12 +825,13 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_variable_type_and_unifiable") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm != nullptr); - const TableType* tTy = get(requireType("t2")); - REQUIRE(tTy != nullptr); + TypeId t2Ty = requireType("t2"); + const TableType* tTy = get(t2Ty); + REQUIRE_MESSAGE(tTy != nullptr, "Expected a table but got " << toString(t2Ty)); REQUIRE(tTy->indexer); - CHECK_EQ(*typeChecker.numberType, *tTy->indexer->indexType); - CHECK_EQ(*typeChecker.stringType, *tTy->indexer->indexResultType); + CHECK_EQ(*builtinTypes->numberType, *tTy->indexer->indexType); + CHECK_EQ(*builtinTypes->stringType, *tTy->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "indexer_mismatch") @@ -761,7 +851,10 @@ TEST_CASE_FIXTURE(Fixture, "indexer_mismatch") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm != nullptr); CHECK(toString(tm->wantedType) == "{number}"); - CHECK(toString(tm->givenType) == "{| [string]: string |}"); + if (FFlag::LuauSolverV2) + CHECK(toString(tm->givenType) == "{ [string]: string }"); + else + CHECK(toString(tm->givenType) == "{| [string]: string |}"); CHECK_NE(*t1, *t2); } @@ -800,6 +893,8 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function empty() return {} end local array: {string} = empty() @@ -811,24 +906,36 @@ TEST_CASE_FIXTURE(Fixture, "array_factory_function") TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify") { CheckResult result = check(R"( - local A = { 5, 7, 8 } - local B = { "one", "two", "three" } - - B = A + function f(a: {number}): {string} + return a + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); + if (FFlag::LuauSolverV2) + { + // CLI-114879 - Error path reporting is not great + CHECK( + toString(result.errors[0]) == + "Type pack '{number}' could not be converted into '{string}'; at [0].indexResult(), number is not exactly string" + ); + } + else + CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } TEST_CASE_FIXTURE(Fixture, "indexer_on_sealed_table_must_unify_with_free_table") { + // CLI-114134 What should be happening here is that the type of `t` should + // be reduced from `{number} & {string}` to `never`, but that's not + // happening. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - local A = { 1, 2, 3 } - function F(t) + function F(t): {number} t[4] = "hi" - A = t + return t end )"); @@ -838,19 +945,27 @@ TEST_CASE_FIXTURE(Fixture, "indexer_on_sealed_table_must_unify_with_free_table") TEST_CASE_FIXTURE(Fixture, "infer_type_when_indexing_from_a_table_indexer") { CheckResult result = check(R"( - local t: { [number]: string } - local s = t[1] + function f(t: {string}) + return t[1] + end + + local s = f({}) )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_possible") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_from_a_table_should_prefer_properties_when_possible") { CheckResult result = check(R"( - local t: { a: string, [string]: number } + function f(): { a: string, [string]: number } + error("e") + end + + local t = f() + local a1 = t.a local a2 = t["a"] @@ -865,19 +980,21 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.stringType, *requireType("a1")); - CHECK_EQ(*typeChecker.stringType, *requireType("a2")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a1")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a2")); - CHECK_EQ(*typeChecker.numberType, *requireType("b1")); - CHECK_EQ(*typeChecker.numberType, *requireType("b2")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b1")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b2")); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } TEST_CASE_FIXTURE(Fixture, "any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!nonstrict @@ -913,10 +1030,10 @@ TEST_CASE_FIXTURE(Fixture, "disallow_indexing_into_an_unsealed_table_with_no_ind local k1 = getConstant("key1") )"); - if (FFlag::LuauDontExtendUnsealedRValueTables) - CHECK("any" == toString(requireType("k1"))); + if (FFlag::LuauSolverV2) + CHECK("unknown" == toString(requireType("k1"))); else - CHECK("a" == toString(requireType("k1"))); + CHECK("any" == toString(requireType("k1"))); LUAU_REQUIRE_NO_ERRORS(result); } @@ -932,16 +1049,17 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s LUAU_REQUIRE_NO_ERRORS(result); - CHECK("string" == toString(*typeChecker.stringType)); + CHECK("string" == toString(*builtinTypes->stringType)); - TableType* tableType = getMutable(requireType("t")); - REQUIRE(tableType != nullptr); + TypeId tType = requireType("t"); + TableType* tableType = getMutable(tType); + REQUIRE_MESSAGE(tableType != nullptr, "Expected a table but got " << toString(tType, {true})); REQUIRE(tableType->indexer == std::nullopt); REQUIRE(0 != tableType->props.count("a")); - TypeId propertyA = tableType->props["a"].type; + TypeId propertyA = tableType->props["a"].type(); REQUIRE(propertyA != nullptr); - CHECK_EQ(*typeChecker.stringType, *propertyA); + CHECK_EQ(*builtinTypes->stringType, *propertyA); } TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") @@ -964,7 +1082,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("words")); + CHECK_EQ(*builtinTypes->stringType, *requireType("words")); } TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") @@ -977,7 +1095,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("b")); + CHECK_EQ(*builtinTypes->stringType, *requireType("b")); } TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") @@ -988,7 +1106,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); } TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") @@ -997,9 +1115,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") // We'll want to change this one in particular when we add real syntax for metatables. CheckResult result = check(R"( - local a = setmetatable({}, {__add = function(l, r) return l end}) - type Vector = typeof(a) - local b:Vector + local mt = { + __add = function(l, r) + return l + end + } + local a = setmetatable({}, mt) + local b = setmetatable({}, mt) local c = a + b )"); @@ -1023,6 +1145,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_inferred") TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type VectorMt = { __add: (Vector, number) -> Vector } local vectorMt: VectorMt @@ -1039,6 +1163,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways") CHECK_EQ(*requireType("a"), *requireType("c")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways_lti") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local vectorMt = {} + + function vectorMt.__add(self: Vector, other: number) + return self + end + + type Vector = typeof(setmetatable({}, vectorMt)) + local a: Vector = setmetatable({}, vectorMt) + + local b = a + 2 + local c = 2 + a + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector", toString(requireType("a"))); + CHECK_EQ(*requireType("a"), *requireType("b")); + CHECK_EQ(*requireType("a"), *requireType("c")); +} + // This test exposed a bug where we let go of the "seen" stack while unifying table types // As a result, type inference crashed with a stack overflow. TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_type") @@ -1063,11 +1211,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_ const MetatableType* amtv = get(requireType("a")); REQUIRE(amtv); - CHECK_EQ(amtv->metatable, requireType("amt")); + CHECK_EQ(follow(amtv->metatable), follow(requireType("amt"))); const MetatableType* bmtv = get(requireType("b")); REQUIRE(bmtv); - CHECK_EQ(bmtv->metatable, requireType("bmt")); + CHECK_EQ(follow(bmtv->metatable), follow(requireType("bmt"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") @@ -1102,10 +1250,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.booleanType, *requireType("alive")); - CHECK_EQ(*typeChecker.stringType, *requireType("movement")); - CHECK_EQ(*typeChecker.stringType, *requireType("name")); - CHECK_EQ(*typeChecker.numberType, *requireType("speed")); + CHECK_EQ(*builtinTypes->booleanType, *requireType("alive")); + CHECK_EQ(*builtinTypes->stringType, *requireType("movement")); + CHECK_EQ(*builtinTypes->stringType, *requireType("name")); + CHECK_EQ(*builtinTypes->numberType, *requireType("speed")); } TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") @@ -1113,7 +1261,7 @@ TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") CheckResult result = check(R"( type Vector3 = {x: number, y: number} - local v: Vector3 + local v: Vector3 = {x = 5, y = 7} )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1155,8 +1303,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "result_is_always_any_if_lhs_is_any") TEST_CASE_FIXTURE(Fixture, "result_is_bool_for_equality_operators_if_lhs_is_any") { CheckResult result = check(R"( - local a: any - local b: number + function f(): (any, number) + return 5, 7 + end + + local a: any, b: number = f() local c = a < b )"); @@ -1260,13 +1411,14 @@ TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_c CheckResult result = check(R"( -- must be in this specific order, and with (roughly) those exact properties! type A = {x: number, [any]: any} | {} - local a: A function f(t) t.y = 1 end - f(a) + function g(a: A) + f(a) + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -1283,7 +1435,9 @@ TEST_CASE_FIXTURE(Fixture, "passing_compatible_unions_to_a_generic_table_without t.y = 1 end - f({y = 5} :: A) + function g(a: A) + f(a) + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1368,6 +1522,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "found_multiple_like_keys") TEST_CASE_FIXTURE(BuiltinsFixture, "dont_suggest_exact_match_keys") { + // CLI-114977 Unsealed table writes don't account for order properly + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local t = {} t.foO = 1 @@ -1408,6 +1565,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_pointer_to_metatable") TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_mismatch_should_fail") { + // This test is invalid because we now create a new type state for t1 at the assignment. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local t1 = {x = 1} local mt1 = {__index = {y = 2}} @@ -1449,6 +1609,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "property_lookup_through_tabletypevar_metatab TEST_CASE_FIXTURE(BuiltinsFixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") { + // This test is invalid because we now create a new type state for t at the assignment. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local t = {x = 1} @@ -1501,17 +1664,19 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key") // Could be flaky if the fix has regressed. TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") { - CheckResult result = check(R"( - local lt: { [string]: string, a: string } - local rt: {} + // CLI-114792 We don't report MissingProperties + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; - lt = rt + CheckResult result = check(R"( + function f(t: {}): { [string]: string, a: string } + return t + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); MissingProperties* mp = get(result.errors[0]); - REQUIRE(mp); + REQUIRE_MESSAGE(mp, "Expected MissingProperties but got " << toString(result.errors[0])); CHECK_EQ(mp->context, MissingProperties::Missing); REQUIRE_EQ(1, mp->properties.size()); CHECK_EQ(mp->properties[0], "a"); @@ -1532,10 +1697,20 @@ TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_i ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); - // Should t now have an indexer? - // It would if the assignment to rt was correctly typed. - CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); + + if (FFlag::LuauSolverV2) + { + CHECK_EQ("{ [string]: string }", toString(tm->wantedType, o)); + CHECK_EQ("{ [string]: number }", toString(tm->givenType, o)); + } + else + { + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + + // Should t now have an indexer? + // It would if the assignment to rt was correctly typed. + CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); + } } TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_indexer") @@ -1551,14 +1726,22 @@ TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_ind ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); - CHECK_EQ("{| foo: number |}", toString(tm->givenType, o)); + if (FFlag::LuauSolverV2) + { + CHECK_EQ("{ [string]: string }", toString(tm->wantedType, o)); + CHECK_EQ("{ foo: number }", toString(tm->givenType, o)); + } + else + { + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{| foo: number |}", toString(tm->givenType, o)); + } } TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") { CheckResult result = check(R"( - local function foo(a: {[string]: number, a: string}) end + local function foo(x: {[string]: number, a: string}) end foo({ a = "" }) )"); @@ -1577,8 +1760,17 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); - CHECK_EQ("{ a: number }", toString(tm->givenType, o)); + + if (FFlag::LuauSolverV2) + { + CHECK("string" == toString(tm->wantedType)); + CHECK("number" == toString(tm->givenType)); + } + else + { + CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{ [string]: number, a: number }", toString(tm->givenType, o)); + } } TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") @@ -1597,38 +1789,62 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {x = 1, y = 2, z = 3} - local vec1 = {x = 1} - - vec3 = vec1 + function f(vec1: {x: number}): {x: number, y: number, z: number} + return vec1 + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* mp = get(result.errors[0]); - REQUIRE(mp); - CHECK_EQ(mp->context, MissingProperties::Missing); - REQUIRE_EQ(2, mp->properties.size()); - CHECK_EQ(mp->properties[0], "y"); - CHECK_EQ(mp->properties[1], "z"); - CHECK_EQ("vec3", toString(mp->superType)); - CHECK_EQ("vec1", toString(mp->subType)); + if (FFlag::LuauSolverV2) + { + CHECK_EQ( + "Type pack '{ x: number }' could not be converted into '{ x: number, y: number, z: number }';" + " at [0], { x: number } is not a subtype of { x: number, y: number, z: number }", + toString(result.errors[0]) + ); + } + else + { + MissingProperties* mp = get(result.errors[0]); + REQUIRE_MESSAGE(mp, result.errors[0]); + CHECK_EQ(mp->context, MissingProperties::Missing); + REQUIRE_EQ(2, mp->properties.size()); + CHECK_EQ(mp->properties[0], "y"); + CHECK_EQ(mp->properties[1], "z"); + + CHECK_EQ("{| x: number, y: number, z: number |}", toString(mp->superType)); + CHECK_EQ("{| x: number |}", toString(mp->subType)); + } } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors2") { CheckResult result = check(R"( - type DumbMixedTable = {[number]: number, x: number} - local t: DumbMixedTable = {"fail"} + type MixedTable = {[number]: number, x: number} + local t: MixedTable = {"fail"} )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* mp = get(result.errors[1]); - REQUIRE(mp); - CHECK_EQ(mp->context, MissingProperties::Missing); - REQUIRE_EQ(1, mp->properties.size()); - CHECK_EQ(mp->properties[0], "x"); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK("MixedTable" == toString(tm->wantedType)); + CHECK("{string}" == toString(tm->givenType)); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + MissingProperties* mp = get(result.errors[1]); + REQUIRE(mp); + CHECK_EQ(mp->context, MissingProperties::Missing); + REQUIRE_EQ(1, mp->properties.size()); + CHECK_EQ(mp->properties[0], "x"); + } } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") @@ -1637,8 +1853,8 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl function mkvec3() return {x = 1, y = 2, z = 3} end function mkvec1() return {x = 1} end - local vec3 = {mkvec3()} - local vec1 = {mkvec1()} + local vec3: {{x: number, y: number, z: number}} = {mkvec3()} + local vec1: {{x: number}} = {mkvec1()} vec1 = vec3 )"); @@ -1647,12 +1863,21 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("vec1", toString(tm->wantedType)); - CHECK_EQ("vec3", toString(tm->givenType)); -} -TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") -{ + if (FFlag::LuauSolverV2) + { + CHECK_EQ("{{ x: number }}", toString(tm->wantedType)); + CHECK_EQ("{{ x: number, y: number, z: number }}", toString(tm->givenType)); + } + else + { + CHECK_EQ("{{| x: number |}}", toString(tm->wantedType)); + CHECK_EQ("{{| x: number, y: number, z: number |}}", toString(tm->givenType)); + } +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") +{ CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} local vec1 = {x = 1} @@ -1665,18 +1890,10 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") { - ScopedFastInt sfis{"LuauTableTypeMaximumStringifierLength", 40}; + ScopedFastInt sfis{FInt::LuauTableTypeMaximumStringifierLength, 40}; CheckResult result = check(R"( - local t - t = {} - t.a = 1 - t.b = 1 - t.c = 1 - t.d = 1 - t.e = 1 - t.f = 1 - + local t: {a: number,b: number, c: number, d: number, e: number, f: number} = nil :: any t = 1 )"); @@ -1684,15 +1901,34 @@ TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK("{ a: number, b: number, c: number, d: number, e: number, ... 1 more ... }" == toString(requireType("t"))); - CHECK_EQ("number", toString(tm->givenType)); - CHECK_EQ("Type 'number' could not be converted into '{ a: number, b: number, c: number, d: number, e: number, ... 1 more ... }'", - toString(result.errors[0])); + if (FFlag::LuauSolverV2) + { + CHECK("{ a: number, b: number, c: number, d: number, e: number, ... 1 more ... }" == toString(requireType("t"))); + CHECK_EQ("number", toString(tm->givenType)); + + CHECK_EQ( + "Type 'number' could not be converted into '{ a: number, b: number, c: number, d: number, e: number, ... 1 more ... }'", + toString(result.errors[0]) + ); + } + else + { + CHECK("{| a: number, b: number, c: number, d: number, e: number, ... 1 more ... |}" == toString(requireType("t"))); + CHECK_EQ("number", toString(tm->givenType)); + + CHECK_EQ( + "Type 'number' could not be converted into '{| a: number, b: number, c: number, d: number, e: number, ... 1 more ... |}'", + toString(result.errors[0]) + ); + } } TEST_CASE_FIXTURE(Fixture, "ok_to_set_nil_even_on_non_lvalue_base_expr") { + // CLI-100076 Assigning nil to an indexer should always succeed + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function f(): { [string]: number } return { ["foo"] = 1 } @@ -1712,7 +1948,14 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_provide_a_subtype_during_construction") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{number | string}", toString(requireType("t"), {/*exhaustive*/ true})); + + if (FFlag::LuauSolverV2) + { + // CLI-114134 Use egraphs to simplify types more consistently + CHECK("{number | number | string}" == toString(requireType("t"), {/*exhaustive*/ true})); + } + else + CHECK_EQ("{number | string}", toString(requireType("t"), {/*exhaustive*/ true})); } TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table") @@ -1726,10 +1969,20 @@ TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_ LUAU_REQUIRE_ERROR_COUNT(1, result); - UnknownProperty* up = get(result.errors[0]); - REQUIRE(up != nullptr); + if (FFlag::LuauSolverV2) + { + CannotExtendTable* cet = get(result.errors[0]); + REQUIRE_MESSAGE(cet, "Expected CannotExtendTable but got " << result.errors[0]); - CHECK_EQ("B", up->key); + CHECK("B" == cet->prop); + } + else + { + UnknownProperty* up = get(result.errors[0]); + REQUIRE_MESSAGE(up != nullptr, "Expected an UnknownProperty but got " << result.errors[0]); + + CHECK_EQ("B", up->key); + } } TEST_CASE_FIXTURE(Fixture, "shorter_array_types_actually_work") @@ -1759,7 +2012,11 @@ TEST_CASE_FIXTURE(Fixture, "only_ascribe_synthetic_names_at_module_scope") LUAU_REQUIRE_ERROR_COUNT(0, result); CHECK_EQ("TopLevel", toString(requireType("TopLevel"))); - CHECK_EQ("{number}", toString(requireType("foo"))); + + if (FFlag::LuauSolverV2) + CHECK_EQ("{number}?", toString(requireType("foo"))); + else + CHECK_EQ("{number}", toString(requireType("foo"))); } TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") @@ -1780,8 +2037,16 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Cannot add property 'a' to table '{| x: number |}'", toString(result.errors[0])); - CHECK_EQ("Cannot add property 'b' to table '{| x: number |}'", toString(result.errors[1])); + if (FFlag::LuauSolverV2) + { + CHECK_EQ("Cannot add property 'a' to table '{ x: number }'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'b' to table '{ x: number }'", toString(result.errors[1])); + } + else + { + CHECK_EQ("Cannot add property 'a' to table '{| x: number |}'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'b' to table '{| x: number |}'", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") @@ -1800,7 +2065,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { CheckResult result = check(R"( - --!nonstrict function os:bad() end )"); @@ -1831,16 +2095,19 @@ local Test: {Table} = { TEST_CASE_FIXTURE(Fixture, "common_table_element_general") { + // CLI-115275 - Bidirectional inference does not always propagate indexer types into the expression + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -type Table = { - a: number, - b: number? -} + type Table = { + a: number, + b: number? + } -local Test: {Table} = { - [2] = { a = 1 }, - [5] = { a = 2, b = 3 } -} + local Test: {Table} = { + [2] = { a = 1 }, + [5] = { a = 2, b = 3 } + } )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1918,11 +2185,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") LUAU_REQUIRE_NO_ERRORS(result); TypeId ty = requireType("clazz"); TableType* ttv = getMutable(ty); - REQUIRE(ttv); + REQUIRE_MESSAGE(ttv, "Expected a table but got " << toString(ty, {true})); REQUIRE(ttv->props.count("new")); Property& prop = ttv->props["new"]; - REQUIRE(prop.type); - const FunctionType* ftv = get(follow(prop.type)); + REQUIRE(prop.type()); + const FunctionType* ftv = get(follow(prop.type())); REQUIRE(ftv); const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); @@ -1934,40 +2201,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") REQUIRE_EQ(ttv->state, TableState::Sealed); } -TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") -{ - ScopedFastFlag sff{"DebugLuauSharedSelf", true}; - - CheckResult result = check(R"( - --!strict - - local Foo = setmetatable({}, {}) - Foo.__index = Foo - - function Foo.new() - local self = setmetatable({}, Foo) - return self:constructor() or self - end - function Foo:constructor() end - - function Foo:create() - local foo = Foo.new() - foo:First() - foo:Second() - foo:Third() - return foo - end - function Foo:First() end - function Foo:Second() end - function Foo:Third() end - - local newData = Foo:create() - newData:First() - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call") { CheckResult result = check(R"( @@ -1984,11 +2217,18 @@ foo({ TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call_tail") { + // CLI-115239 - Bidirectional checking does not work for __call metamethods + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -type Foo = {x: number | string} -local function foo(l: {Foo}, ...: {Foo}) end + type Foo = {x: number | string} + local function foo(l: {Foo}, ...: {Foo}) end -foo({{x = 1234567}, {x = "hello"}}, {{x = 1234567}, {x = "hello"}}, {{x = 1234567}, {x = "hello"}}) + foo( + {{x = 1234567}, {x = "hello"}}, + {{x = 1234567}, {x = "hello"}}, + {{x = 1234567}, {x = "hello"}} + ) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -2025,8 +2265,18 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(get(result.errors[0])); + CHECK(Location{{6, 45}, {6, 46}} == result.errors[0].location); + + CHECK(get(result.errors[1])); + } + // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping - if (FFlag::LuauInstantiateInSubtyping) + else if (FFlag::LuauInstantiateInSubtyping) LUAU_REQUIRE_NO_ERRORS(result); else LUAU_REQUIRE_ERRORS(result); @@ -2064,19 +2314,22 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") type A = { x: number, y: number } type B = { x: number, y: string } -local a: A +local a: A = { x = 123, y = 456 } local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == R"(Type 'A' could not be converted into 'B'; at [read "y"], number is not exactly string)"); else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + { + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); + Property 'y' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") @@ -2088,23 +2341,25 @@ type BS = { x: number, y: string } type A = { a: boolean, b: AS } type B = { a: boolean, b: BS } -local a: A +local a: A = { a = false, b = { x = 123, y = 456 } } local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' -caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors.at(0)) == R"(Type 'A' could not be converted into 'B'; at [read "b"][read "y"], number is not exactly string)"); else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + { + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' + Property 'b' is not compatible. +Type 'AS' could not be converted into 'BS' caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); + Property 'y' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") @@ -2119,35 +2374,78 @@ local b2 = setmetatable({ x = 2, y = 4 }, { __call = function(s, t) end }); local c2: typeof(a2) = b2 )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' + const std::string expected1 = R"(Type 'b1' could not be converted into 'a1' caused by: - Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' + Type + '{ x: number, y: string }' +could not be converted into + '{ x: number, y: number }' caused by: - Property 'y' is not compatible. Type 'string' could not be converted into 'number' in an invariant context)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' + Property 'y' is not compatible. +Type 'string' could not be converted into 'number' in an invariant context)"; + const std::string expected2 = R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' + Type + '{ __call: (a, b) -> () }' +could not be converted into + '{ __call: (a) -> () }' caused by: - Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); + Property '__call' is not compatible. +Type + '(a, b) -> ()' +could not be converted into + '(a) -> ()'; different number of generic type parameters)"; - if (FFlag::LuauInstantiateInSubtyping) + if (FFlag::LuauSolverV2) { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + // The assignment of c2 to b2 is, surprisingly, allowed under the new + // solver for two reasons: + // + // First, both of the __call functions have hidden ...any arguments + // because their exact definition is available. + // + // Second, nil <: unknown, so we consider that parameter to be optional. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Type 'b1' could not be converted into 'a1'; at table()[read \"y\"], string is not exactly number" == toString(result.errors[0])); + } + else if (FFlag::LuauInstantiateInSubtyping) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(expected1, toString(result.errors[0])); + + const std::string expected3 = R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' + Type + '{ __call: (a, b) -> () }' +could not be converted into + '{ __call: (a) -> () }' caused by: - Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + Property '__call' is not compatible. +Type + '(a, b) -> ()' +could not be converted into + '(a) -> ()'; different number of generic type parameters)"; + + CHECK_EQ(expected2, toString(result.errors[1])); } else { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(expected1, toString(result.errors[0])); + + std::string expected3 = R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' + Type + '{ __call: (a, b) -> () }' +could not be converted into + '{ __call: (a) -> () }' caused by: - Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + Property '__call' is not compatible. +Type + '(a, b) -> ()' +could not be converted into + '(a) -> ()'; different number of generic type parameters)"; + CHECK_EQ(expected3, toString(result.errors[1])); } } @@ -2162,14 +2460,19 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + + if (FFlag::LuauSolverV2) + { + CHECK("Type 'A' could not be converted into 'B'; at indexer(), number is not exactly string" == toString(result.errors[0])); + } else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + { + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string')"); + Property '[indexer key]' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") @@ -2183,18 +2486,26 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauTypeMismatchInvarianceInError) - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' -caused by: - Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + + if (FFlag::LuauSolverV2) + { + CHECK("Type 'A' could not be converted into 'B'; at indexResult(), number is not exactly string" == toString(result.errors[0])); + } else - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + { + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string')"); + Property '[indexer value]' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { + // Table properties like HasSuper.p must be invariant. The new solver rightly rejects this program. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2224,21 +2535,35 @@ local y: number = tmp.p.y )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' + + if (FFlag::LuauSolverV2) + CHECK( + "Type 'tmp' could not be converted into 'HasSuper'; at [read \"p\"], { x: number, y: number } is not exactly Super" == + toString(result.errors[0]) + ); + else + { + const std::string expected = R"(Type 'tmp' could not be converted into 'HasSuper' caused by: - Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"); + Property 'p' is not compatible. +Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { + // CLI-114791 Bidirectional inference should be able to cause the inference engine to forget that a table literal has some property + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( ---!strict -type Super = { x : number } -type Sub = { x : number, y: number } -type HasSuper = { [string] : Super } -type HasSub = { [string] : Sub } -local a: HasSuper = { p = { x = 5, y = 7 }} -a.p = { x = 9 } + --!strict + type Super = { x : number } + type Sub = { x : number, y: number } + type HasSuper = { [string] : Super } + type HasSub = { [string] : Sub } + local a: HasSuper = { p = { x = 5, y = 7 }} + a.p = { x = 9 } )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -2246,6 +2571,9 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") { + // CLI-114782 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local b b = setmetatable({}, {__call = b}) @@ -2253,7 +2581,8 @@ b() )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); + + CHECK_EQ(toString(result.errors[0]), R"(Cannot call a value of type t1 where t1 = { @metatable { __call: t1 }, { } })"); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") @@ -2321,12 +2650,15 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_union_errors") { + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + CheckResult result = check(R"( local x: {number} | number | string local y = #x )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + // CLI-119936: This shouldn't double error but does under the new solver. + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") @@ -2372,7 +2704,13 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing") local foo = f({p = "string"}) )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + { + // CLI-114781 Bidirectional checking can't see through the intersection + LUAU_REQUIRE_ERROR_COUNT(1, result); + } + else + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("number | string", toString(requireType("foo"))); } @@ -2392,7 +2730,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); + if (FFlag::LuauSolverV2) + REQUIRE_EQ("{ y: number }", toString(requireType("b"))); + else + REQUIRE_EQ("{- y: number -}", toString(requireType("b"))); } TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") @@ -2410,7 +2751,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); + if (FFlag::LuauSolverV2) + REQUIRE_EQ("{ y: number }", toString(requireType("b"))); + else + REQUIRE_EQ("{- y: number -}", toString(requireType("b"))); } TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") @@ -2477,11 +2821,14 @@ TEST_CASE_FIXTURE(Fixture, "table_length") LUAU_REQUIRE_NO_ERRORS(result); CHECK(nullptr != get(requireType("t"))); - CHECK_EQ(*typeChecker.numberType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") { + // CLI-100076 - Assigning a table key to `nil` in the presence of an indexer should always be permitted + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); LUAU_REQUIRE_ERROR_COUNT(0, result); } @@ -2498,8 +2845,8 @@ TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") CHECK((Location{Position{3, 15}, Position{3, 18}}) == result.errors[0].location); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK(tm->wantedType == typeChecker.numberType); - CHECK(tm->givenType == typeChecker.stringType); + CHECK(tm->wantedType == builtinTypes->numberType); + CHECK(tm->givenType == builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") @@ -2509,10 +2856,16 @@ TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") a['a'] = nil )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{2, 17}, Position{2, 20}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.nilType, - }})); + CHECK_EQ( + result.errors[0], + (TypeError{ + Location{Position{2, 17}, Position{2, 20}}, + TypeMismatch{ + builtinTypes->numberType, + builtinTypes->nilType, + } + }) + ); } TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") @@ -2577,6 +2930,25 @@ TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") CHECK_EQ("T", toString(requireType("T"))); } +TEST_CASE_FIXTURE(Fixture, "should_not_unblock_table_type_twice") +{ + // don't run this when the DCR flag isn't set + if (!FFlag::LuauSolverV2) + return; + + check(R"( + local timer = peek(timerQueue) + while timer ~= nil do + if timer.startTime <= currentTime then + timer.isQueued = true + end + timer = peek(timerQueue) + end + )"); + + // Just checking this is enough to satisfy the original bug. +} + TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") { CheckResult result = check(R"( @@ -2591,7 +2963,6 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); const FunctionType* fooType = get(requireType("foo")); REQUIRE(fooType); @@ -2599,10 +2970,13 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") std::optional fooArg1 = first(fooType->argTypes); REQUIRE(fooArg1); - const TableType* fooArg1Table = get(*fooArg1); + const TableType* fooArg1Table = get(follow(*fooArg1)); REQUIRE(fooArg1Table); - CHECK_EQ(fooArg1Table->state, TableState::Generic); + if (FFlag::LuauSolverV2) + CHECK_EQ(fooArg1Table->state, TableState::Sealed); + else + CHECK_EQ(fooArg1Table->state, TableState::Generic); } /* @@ -2643,7 +3017,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc REQUIRE(counterType); REQUIRE(counterType->props.count("new")); - const FunctionType* newType = get(follow(counterType->props["new"].type)); + const FunctionType* newType = get(follow(counterType->props["new"].type())); REQUIRE(newType); std::optional newRetType = *first(newType->retTypes); @@ -2685,7 +3059,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") )"); ModulePtr module = getMainModule(); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) CHECK_GE(500, module->internalTypes.types.size()); else CHECK_GE(100, module->internalTypes.types.size()); @@ -2709,7 +3083,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") Fixture fix; // inherit env from parent fixture checker - fix.typeChecker.globalScope = typeChecker.globalScope; + fix.frontend.globals.globalScope = frontend.globals.globalScope; fix.check(R"( --!nonstrict @@ -2723,7 +3097,7 @@ end // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down // note: it's important for typeck to be destroyed at this point! { - for (auto& p : typeChecker.globalScope->bindings) + for (auto& p : frontend.globals.globalScope->bindings) { toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas } @@ -2744,8 +3118,17 @@ do end TEST_CASE_FIXTURE(BuiltinsFixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") { CheckResult result = check("local x = setmetatable({})"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function 'setmetatable' expects 2 arguments, but only 1 is specified", toString(result.errors[0])); + + if (FFlag::LuauSolverV2) + { + // CLI-114665: Generic parameters should not also be optional. + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Argument count mismatch. Function 'setmetatable' expects 2 arguments, but only 1 is specified", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning") @@ -2805,7 +3188,25 @@ c = b const TableType* ttv = get(*ty); REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); + CHECK(0 == ttv->instantiatedTypeParams.size()); +} + +TEST_CASE_FIXTURE(Fixture, "record_location_of_inserted_table_properties") +{ + CheckResult result = check(R"( + local a = {} + a.foo = 1234 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const TableType* tt = get(requireType("a")); + REQUIRE(tt); + + REQUIRE(tt->props.count("foo")); + + const Property& prop = tt->props.find("foo")->second; + CHECK(Location{{2, 10}, {2, 13}} == prop.location); } TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") @@ -2823,7 +3224,7 @@ local baz = foo[bar] TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_basic") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + if (!FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -2838,7 +3239,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_basic") local foo = a(12) )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); CHECK(requireType("foo") == builtinTypes->numberType); } @@ -2853,10 +3260,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(result.errors[0] == TypeError{ - Location{{5, 20}, {5, 21}}, - CannotCallNonFunction{builtinTypes->numberType}, - }); + + if (FFlag::LuauSolverV2) + { + CHECK("Cannot call a value of type a" == toString(result.errors[0])); + } + else + { + TypeError e{ + Location{{5, 20}, {5, 21}}, + CannotCallNonFunction{builtinTypes->numberType}, + }; + CHECK(result.errors[0] == e); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") @@ -2879,14 +3295,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { + // The new solver can see that this function is safe to oversaturate. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -local a = setmetatable({ x = 2 }, { - __call = function(self) - return (self.x :: number) * 2 -- should work without annotation in the future - end -}) -local b = a() -local c = a(2) -- too many arguments + local a = setmetatable({ x = 2 }, { + __call = function(self) + return (self.x :: number) * 2 -- should work without annotation in the future + end + }) + local b = a() + local c = a(2) -- too many arguments )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -2911,7 +3330,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "access_index_metamethod_that_returns_variadi ToStringOptions o; o.exhaustive = true; - CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ x: string }", toString(requireType("foo"), o)); + else + CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); } TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") @@ -2959,8 +3381,17 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); - CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); + + if (FFlag::LuauSolverV2) + { + CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0])); + CHECK_EQ("number | { x: number }", toString(requireType("u"))); + } + else + { + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); + } } TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") @@ -2971,7 +3402,10 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + if (FFlag::LuauSolverV2) + CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0])); + else + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); CHECK_EQ("boolean", toString(requireType("u"))); } @@ -3047,37 +3481,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); - CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); - CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->returnType)); -} - -TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") -{ - ScopedFastFlag sff[] = { - // {"LuauLowerBoundsCalculation", true}, - {"DebugLuauSharedSelf", true}, - }; - - check(R"( - function Base64FileReader(data) - local reader = {} - local index: number - - function reader:PeekByte() - return data:byte(index) - end - - function reader:Byte() - return data:byte(index - 1) - end - - return reader - end - )"); - - CHECK_EQ("(t1) -> {| Byte: (a) -> (b...), PeekByte: (a) -> (b...) |} where t1 = {+ byte: (t1, number) -> (b...) +}", - toString(requireType("Base64FileReader"))); + if (FFlag::LuauSolverV2) + { + CHECK_EQ("({ read blah: unknown }) -> ()", toString(requireType("a"))); + CHECK_EQ("({ read gwar: unknown }) -> ()", toString(requireType("b"))); + CHECK_EQ("(...any) -> ({ read blah: unknown, read gwar: unknown }) -> ()", toString(getMainModule()->returnType)); + } + else + { + CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); + CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); + CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->returnType)); + } } TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") @@ -3086,71 +3501,22 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") local t: { [string]: number } = { 5, 6, 7 } )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); -} - -TEST_CASE_FIXTURE(Fixture, "shared_selfs") -{ - ScopedFastFlag sff{"DebugLuauSharedSelf", true}; - - CheckResult result = check(R"( - local t = {} - t.x = 5 - - function t:m1() return self.x end - function t:m2() return self.y end - - return t - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions opts; - opts.exhaustive = true; - CHECK_EQ("{| m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b, x: number |}", toString(requireType("t"), opts)); -} - -TEST_CASE_FIXTURE(Fixture, "shared_selfs_from_free_param") -{ - ScopedFastFlag sff{"DebugLuauSharedSelf", true}; - - CheckResult result = check(R"( - local function f(t) - function t:m1() return self.x end - function t:m2() return self.y end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("({+ m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b +}) -> ()", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(BuiltinsFixture, "shared_selfs_through_metatables") -{ - ScopedFastFlag sff{"DebugLuauSharedSelf", true}; - - CheckResult result = check(R"( - local t = {} - t.__index = t - setmetatable({}, t) - - function t:m1() return self.x end - function t:m2() return self.y end - - return t - )"); - - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK( + "Type '{number}' could not be converted into '{ [string]: number }'; at indexer(), number is not exactly string" == + toString(result.errors[0]) + ); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(3, result); - ToStringOptions opts; - opts.exhaustive = true; - CHECK_EQ( - toString(requireType("t"), opts), "t1 where t1 = {| __index: t1, m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b |}"); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); + } } TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") @@ -3202,64 +3568,13 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_ty )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0])); + if (FFlag::LuauSolverV2) + CHECK_EQ("Type '{ [boolean]: number } | {number}' does not have key 'x'", toString(result.errors[0])); + else + CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(BuiltinsFixture, "quantify_metatables_of_metatables_of_table") -{ - ScopedFastFlag sff[]{ - {"DebugLuauSharedSelf", true}, - }; - - CheckResult result = check(R"( - local T = {} - - function T:m() - return self.x, self.y - end - - function T:n() - end - - local U = setmetatable({}, {__index = T}) - - local V = setmetatable({}, {__index = U}) - - return V - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions opts; - opts.exhaustive = true; - CHECK_EQ(toString(requireType("V"), opts), "{ @metatable { __index: { @metatable { __index: {| m: ({+ x: a, y: b +}) -> (a, b), n: ({+ x: a, y: b +}) -> () |} }, { } } }, { } }"); -} - -TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") -{ - ScopedFastFlag sff{"DebugLuauSharedSelf", true}; - - CheckResult result = check(R"( - local T = {} - - function T:m() - return self.x - end - - function T:n() - return self.y - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions opts; - opts.exhaustive = true; - CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); -} - -TEST_CASE_FIXTURE(BuiltinsFixture, "leaking_bad_metatable_errors") +TEST_CASE_FIXTURE(BuiltinsFixture, "leaking_bad_metatable_errors") { CheckResult result = check(R"( local a = setmetatable({}, 1) @@ -3273,6 +3588,9 @@ local b = a.x TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type") { + // CLI-115087 The new solver cannot infer that a table-like type is actually string + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local function f(s) return s:lower() @@ -3298,27 +3616,69 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_ f("baz" :: "bar" | "baz") )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); + if (FFlag::LuauSolverV2) + { + // CLI-115090 Error reporting is quite bad in this case. + + // This should be just 3 + LUAU_REQUIRE_ERROR_COUNT(4, result); + + TypeMismatch* tm1 = get(result.errors[0]); + REQUIRE(tm1); + CHECK("typeof(string)" == toString(tm1->givenType)); + CHECK("t1 where t1 = { read absolutely_no_scalar_has_this_method: (t1) -> (a...) }" == toString(tm1->wantedType)); + + TypeMismatch* tm2 = get(result.errors[1]); + REQUIRE(tm2); + CHECK("typeof(string)" == toString(tm2->givenType)); + CHECK("t1 where t1 = { read absolutely_no_scalar_has_this_method: (t1) -> (a...) }" == toString(tm2->wantedType)); + + TypeMismatch* tm3 = get(result.errors[2]); + REQUIRE(tm3); + CHECK("typeof(string)" == toString(tm3->givenType)); + CHECK("t1 where t1 = { read absolutely_no_scalar_has_this_method: (t1) -> (a...) }" == toString(tm3->wantedType)); + + TypeMismatch* tm4 = get(result.errors[3]); + REQUIRE(tm4); + CHECK("typeof(string)" == toString(tm4->givenType)); + CHECK("t1 where t1 = { read absolutely_no_scalar_has_this_method: (t1) -> (a...) }" == toString(tm4->wantedType)); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + const std::string expected1 = + R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + CHECK_EQ(expected1, toString(result.errors[0])); + + const std::string expected2 = + R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[1])); - CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + CHECK_EQ(expected2, toString(result.errors[1])); + + const std::string expected3 = R"(Type + '"bar" | "baz"' +could not be converted into + 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: - Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + Not all union options are compatible. +Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[2])); + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + CHECK_EQ(expected3, toString(result.errors[2])); + } } TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; // Changes argument from table type to primitive + // CLI-115087 The new solver cannot infer that a table-like type is actually string + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( local function f(s): string @@ -3340,17 +3700,40 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_ end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(4, result); + + CHECK(toString(result.errors[0]) == "Parameter 's' has been reduced to never. This function is not callable with any possible value."); + // FIXME: These free types should have been generalized by now. + CHECK( + toString(result.errors[1]) == + "Parameter 's' is required to be a subtype of '{- read absolutely_no_scalar_has_this_method: ('a <: (never) -> ('b, c...)) -}' here." + ); + CHECK(toString(result.errors[2]) == "Parameter 's' is required to be a subtype of 'string' here."); + CHECK(get(result.errors[3])); + + CHECK_EQ("(never) -> string", toString(requireType("f"))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const std::string expected = + R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + CHECK_EQ(expected, toString(result.errors[0])); + + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + // We need egraphs to simplify the type of `out` here. CLI-114134 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( local function stringByteList(str) @@ -3370,33 +3753,47 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") { ScopedFastFlag sff[]{ - {"LuauInstantiateInSubtyping", true}, + {FFlag::LuauInstantiateInSubtyping, true}, }; CheckResult result = check(R"( --!strict local t = {} - function t.m(x) return x end + function t.m(x: T) return x end local a : string = t.m("hi") local b : number = t.m(5) function f(x : { m : (number)->number }) - x.m = function(x) return 1+x end + x.m = function(x: number) return 1+x end end + f(t) -- This shouldn't typecheck + local c : string = t.m("hi") )"); - LUAU_REQUIRE_NO_ERRORS(result); - // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping - // LUAU_REQUIRE_ERRORS(result); - // CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' - // caused by: - // Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type - // parameters)"); - // // this error message is not great since the underlying issue is that the context is invariant, - // and `(number) -> number` cannot be a subtype of `(a) -> a`. -} + if (FFlag::LuauSolverV2) + { + // FIXME. We really should be reporting just one error in this case. CLI-114509 + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK(get(result.errors[0])); + CHECK(get(result.errors[1])); + CHECK(get(result.errors[2])); + } + else + { + // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping + // LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' + // caused by: + // Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type + // parameters)"); + // // this error message is not great since the underlying issue is that the context is invariant, + // and `(number) -> number` cannot be a subtype of `(a) -> a`. + LUAU_REQUIRE_NO_ERRORS(result); + } +} TEST_CASE_FIXTURE(BuiltinsFixture, "generic_table_instantiation_potential_regression") { @@ -3412,16 +3809,27 @@ local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* error = get(result.errors[0]); - REQUIRE(error != nullptr); - REQUIRE(error->properties.size() == 1); + if (FFlag::LuauSolverV2) + { + const TypeMismatch* error = get(result.errors[0]); + REQUIRE_MESSAGE(error, "Expected TypeMismatch but got " << result.errors[0]); + + CHECK("({ p: number, q: string }) -> { p: number, r: boolean }" == toString(error->wantedType)); + CHECK("({ p: number }) -> { p: number }" == toString(error->givenType)); + } + else + { + const MissingProperties* error = get(result.errors[0]); + REQUIRE_MESSAGE(error != nullptr, "Expected MissingProperties but got " << result.errors[0]); - CHECK_EQ("r", error->properties[0]); + REQUIRE(error->properties.size() == 1); + CHECK_EQ("r", error->properties[0]); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_has_a_side_effect") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + if (!FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -3457,8 +3865,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") TEST_CASE_FIXTURE(Fixture, "fuzz_table_indexer_unification_can_bound_owner_to_string") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( sin,_ = nil _ = _[_.sin][_._][_][_]._ @@ -3470,8 +3876,6 @@ _[_] = _ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_extra_prop_unification_can_bound_owner_to_string") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( l0,_ = nil _ = _,_[_.n5]._[_][_][_]._ @@ -3483,8 +3887,6 @@ _._.foreach[_],_ = _[_],_._ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typelevel_promote_on_changed_table_type") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( _._,_ = nil _ = _.foreach[_]._,_[_.n5]._[_.foreach][_][_]._ @@ -3497,9 +3899,7 @@ _ = _._ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_instantiated_table") { ScopedFastFlag sff[]{ - {"LuauInstantiateInSubtyping", true}, - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, + {FFlag::LuauInstantiateInSubtyping, true}, }; CheckResult result = check(R"( @@ -3516,9 +3916,7 @@ return _[_()()[_]] <= _ TEST_CASE_FIXTURE(Fixture, "fuzz_table_unify_instantiated_table_with_prop_realloc") { ScopedFastFlag sff[]{ - {"LuauInstantiateInSubtyping", true}, - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, + {FFlag::LuauInstantiateInSubtyping, true}, }; CheckResult result = check(R"( @@ -3537,12 +3935,6 @@ end) TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_prop_realloc") { - // For this test, we don't need LuauInstantiateInSubtyping - ScopedFastFlag sff[]{ - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, - }; - CheckResult result = check(R"( n3,_ = nil _ = _[""]._,_[l0][_._][{[_]=_,_=_,}][_G].number @@ -3554,8 +3946,6 @@ _ = {_,} TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type") { - ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; - CheckResult result = check(R"( local events = {} local mockObserveEvent = function(_, key, callback) @@ -3577,15 +3967,16 @@ TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_ap CHECK(tt->props.empty()); REQUIRE(tt->indexer); - CHECK("string" == toString(tt->indexer->indexType)); + if (FFlag::LuauSolverV2) + CHECK("unknown" == toString(tt->indexer->indexType)); + else + CHECK("string" == toString(tt->indexer->indexType)); LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position") { - ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; - CheckResult result = check(R"( local testDictionary = { FruitName = "Lemon", @@ -3604,7 +3995,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position") CHECK(0 == ttv->props.count("")); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_ERROR_COUNT(1, result); + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "extend_unsealed_table_with_metatable") @@ -3645,4 +4039,893 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "top_table_type_is_isomorphic_to_empty_sealed )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.includes") +{ + + CheckResult result = check(R"( +type Array = { [number]: T } + +function indexOf(array: Array, searchElement: any, fromIndex: number?): number + return -1 +end + +return function(array: Array, searchElement: any, fromIndex: number?): boolean + return -1 ~= indexOf(array, searchElement, fromIndex) +end + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "certain_properties_of_table_literal_arguments_can_be_covariant") +{ + CheckResult result = check(R"( + function f(a: {[string]: string | {any} | nil }) + return a + end + + local x = f({ + title = "Feature.VirtualEvents.EnableNotificationsModalTitle", + body = "Feature.VirtualEvents.EnableNotificationsModalBody", + notNow = "Feature.VirtualEvents.NotNowButton", + getNotified = "Feature.VirtualEvents.GetNotifiedButton", + }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "subproperties_can_also_be_covariantly_tested") +{ + CheckResult result = check(R"( + type T = { + [string]: {[string]: (string | number)?} + } + + function f(t: T) + return t + end + + local x = f({ + subprop={x="hello"} + }) + + local y = f({ + subprop={x=41} + }) + + local z = f({ + subprop={} + }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_shifted_tables") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo + + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.foo.foo = almostFoo + -- Shift + almostFoo = almostFoo.foo.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(Fixture, "cli_84607_missing_prop_in_array_or_dict") +{ + ScopedFastFlag sff{FFlag::LuauFixIndexerSubtypingOrdering, true}; + + CheckResult result = check(R"( + type Thing = { name: string, prop: boolean } + + local arrayOfThings : {Thing} = { + { name = "a" } + } + + local dictOfThings : {[string]: Thing} = { + a = { name = "a" } + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + if (FFlag::LuauSolverV2) + { + const TypeMismatch* err1 = get(result.errors[0]); + REQUIRE_MESSAGE(err1, "Expected TypeMismatch but got " << result.errors[0]); + + CHECK("{Thing}" == toString(err1->wantedType)); + CHECK("{{ name: string }}" == toString(err1->givenType)); + + const TypeMismatch* err2 = get(result.errors[1]); + REQUIRE_MESSAGE(err2, "Expected TypeMismatch but got " << result.errors[1]); + + CHECK("{ [string]: Thing }" == toString(err2->wantedType)); + CHECK("{ [string]: { name: string } }" == toString(err2->givenType)); + } + else + { + TypeError& err1 = result.errors[0]; + MissingProperties* error1 = get(err1); + REQUIRE(error1); + REQUIRE(error1->properties.size() == 1); + + CHECK_EQ("prop", error1->properties[0]); + + TypeError& err2 = result.errors[1]; + TypeMismatch* mismatch = get(err2); + REQUIRE(mismatch); + MissingProperties* error2 = get(*mismatch->error); + REQUIRE(error2); + REQUIRE(error2->properties.size() == 1); + + CHECK_EQ("prop", error2->properties[0]); + } +} + +TEST_CASE_FIXTURE(Fixture, "simple_method_definition") +{ + CheckResult result = check(R"( + local T = {} + + function T:m() + return 5 + end + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + CHECK_EQ("{ m: (unknown) -> number }", toString(getMainModule()->returnType, ToStringOptions{true})); + else + CHECK_EQ("{| m: (a) -> number |}", toString(getMainModule()->returnType, ToStringOptions{true})); +} + +TEST_CASE_FIXTURE(Fixture, "identify_all_problematic_table_fields") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type T = { + a: number, + b: string, + c: boolean, + } + + local a: T = { + a = "foo", + b = false, + c = 123, + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + std::string expected = + "Type '{ a: string, b: boolean, c: number }' could not be converted into 'T'; at [read \"a\"], string is not exactly number" + "\n\tat [read \"b\"], boolean is not exactly string" + "\n\tat [read \"c\"], number is not exactly boolean"; + CHECK(toString(result.errors[0]) == expected); +} + +TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + + CheckResult result = check(R"( + type W = {read x: number} + type X = {write x: boolean} + + type Y = {read ["prop"]: boolean} + type Z = {write ["prop"]: string} + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + + CHECK("read keyword is illegal here" == toString(result.errors[0])); + CHECK(Location{{1, 18}, {1, 22}} == result.errors[0].location); + CHECK("write keyword is illegal here" == toString(result.errors[1])); + CHECK(Location{{2, 18}, {2, 23}} == result.errors[1].location); + CHECK("read keyword is illegal here" == toString(result.errors[2])); + CHECK(Location{{4, 18}, {4, 22}} == result.errors[2].location); + CHECK("write keyword is illegal here" == toString(result.errors[3])); + CHECK(Location{{5, 18}, {5, 23}} == result.errors[3].location); +} + +TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") +{ + CheckResult result = check(R"( + type T = {read [string]: number} + type U = {write [string]: boolean} + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK("read keyword is illegal here" == toString(result.errors[0])); + CHECK(Location{{1, 18}, {1, 22}} == result.errors[0].location); + CHECK("write keyword is illegal here" == toString(result.errors[1])); + CHECK(Location{{2, 18}, {2, 23}} == result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "infer_write_property") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function f(t) + t.y = 1 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("({ y: number }) -> ()" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_error_suppression") +{ + CheckResult result = check(R"( + function one(tbl: {x: any}) end + function two(tbl: {x: string}) one(tbl) end -- ok, string <: any and any <: string + + function three(tbl: {x: any, y: string}) end + function four(tbl: {x: string, y: string}) three(tbl) end -- ok, string <: any, any <: string, string <: string + function five(tbl: {x: string, y: number}) three(tbl) end -- error, string <: any, any <: string, but number (result.errors[0]); + REQUIRE(tm); + + + // the new solver reports specifically the inner mismatch, rather than the whole table + // honestly not sure which of these is a better developer experience. + if (FFlag::LuauSolverV2) + { + CHECK_EQ(*tm->wantedType, *builtinTypes->stringType); + CHECK_EQ(*tm->givenType, *builtinTypes->numberType); + } + else + { + CHECK_EQ("{| x: any, y: string |}", toString(tm->wantedType)); + CHECK_EQ("{| x: string, y: number |}", toString(tm->givenType)); + } +} + +TEST_CASE_FIXTURE(Fixture, "write_to_read_only_property") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function f(t: {read x: number}) + t.x = 5 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK("Property x of table '{ read x: number }' is read-only" == toString(result.errors[0])); + + PropertyAccessViolation* pav = get(result.errors[0]); + REQUIRE(pav); + + CHECK("{ read x: number }" == toString(pav->table, {true})); + CHECK("x" == pav->key); + CHECK(PropertyAccessViolation::CannotWrite == pav->context); +} + +TEST_CASE_FIXTURE(Fixture, "write_to_unusually_named_read_only_property") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function f(t: {read ["hello world"]: number}) + t["hello world"] = 5 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK("Property \"hello world\" of table '{ read [\"hello world\"]: number }' is read-only" == toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "write_annotations_are_unsupported_even_with_the_new_solver") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + function f(t: {write foo: number}) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK("write keyword is illegal here" == toString(result.errors[0])); + CHECK(Location{{1, 23}, {1, 28}} == result.errors[0].location); +} + +TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") +{ + ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, false}}; + + CheckResult result = check(R"( + type W = {read x: number} + type X = {write x: boolean} + + type Y = {read ["prop"]: boolean} + type Z = {write ["prop"]: string} + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + + CHECK("read keyword is illegal here" == toString(result.errors[0])); + CHECK(Location{{1, 18}, {1, 22}} == result.errors[0].location); + CHECK("write keyword is illegal here" == toString(result.errors[1])); + CHECK(Location{{2, 18}, {2, 23}} == result.errors[1].location); + CHECK("read keyword is illegal here" == toString(result.errors[2])); + CHECK(Location{{4, 18}, {4, 22}} == result.errors[2].location); + CHECK("write keyword is illegal here" == toString(result.errors[3])); + CHECK(Location{{5, 18}, {5, 23}} == result.errors[3].location); +} + +TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") +{ + ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, false}}; + + CheckResult result = check(R"( + type T = {read [string]: number} + type U = {write [string]: boolean} + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK("read keyword is illegal here" == toString(result.errors[0])); + CHECK(Location{{1, 18}, {1, 22}} == result.errors[0].location); + CHECK("write keyword is illegal here" == toString(result.errors[1])); + CHECK(Location{{2, 18}, {2, 23}} == result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties") +{ + if (!FFlag::LuauSolverV2) + return; + + ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, true}}; + + CheckResult result = check(R"( + function oc(player, speaker) + local head = speaker.Character:FindFirstChild('Head') + speaker.Character = player[1].Character + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK( + "({{ read Character: t1 }}, { Character: t1 }) -> () " + "where " + "t1 = { read FindFirstChild: (t1, string) -> (a, b...) }" == toString(requireType("oc")) + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tables_can_have_both_metatables_and_indexers") +{ + CheckResult result = check(R"( + local a = {} + a[1] = 5 + a[2] = 17 + + local t = {} + setmetatable(a, t) + + local c = a[1] + print(a[1]) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("number" == toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "refined_thing_can_be_an_array") +{ + CheckResult result = check(R"( + function foo(x, y) + if x then + return x[1] + else + return y + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("({a}, a) -> a" == toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "parameter_was_set_an_indexer_and_bounded_by_string") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + function f(t) + local s: string = t + t[5] = 7 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ("Parameter 't' has been reduced to never. This function is not callable with any possible value.", toString(result.errors[0])); + CHECK_EQ("Parameter 't' is required to be a subtype of 'string' here.", toString(result.errors[1])); + CHECK_EQ("Parameter 't' is required to be a subtype of '{number}' here.", toString(result.errors[2])); +} + +TEST_CASE_FIXTURE(Fixture, "parameter_was_set_an_indexer_and_bounded_by_another_parameter") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + function f(t1, t2) + t1[5] = 7 -- 't1 <: {number} + t2 = t1 -- 't1 <: 't2 + t1[5] = 7 -- 't1 <: {number} + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // FIXME CLI-114134. We need to simplify types more consistently. + CHECK_EQ("(unknown & {number} & {number}, unknown) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "write_to_union_property_not_all_present") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type Animal = {tag: "Cat", meow: boolean} | {tag: "Dog", woof: boolean} + function f(t: Animal) + t.tag = "Dog" + end + )"); + + // this should fail because `t` may be a `Cat` variant, and `"Dog"` is not a subtype of `"Cat"`. + LUAU_REQUIRE_ERRORS(result); + + CannotAssignToNever* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK(builtinTypes->stringType == tm->rhsType); + CHECK(CannotAssignToNever::Reason::PropertyNarrowed == tm->reason); + REQUIRE(tm->cause.size() == 2); + CHECK("\"Cat\"" == toString(tm->cause[0])); + CHECK("\"Dog\"" == toString(tm->cause[1])); +} + +TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug") +{ + CheckResult result = check(R"( + type MockedResponseBody = string | (() -> MockedResponseBody) + type MockedResponse = { type: 'body', body: MockedResponseBody } | { type: 'error' } + + local function mockedResponseToHttpResponse(mockedResponse: MockedResponse) + assert(mockedResponse.type == 'body', 'Mocked response is not a body') + if typeof(mockedResponse.body) == 'string' then + else + return mockedResponseToHttpResponse(mockedResponse) + end + end + )"); + + // we're primarily interested in knowing that this does not crash. + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug_2") +{ + CheckResult result = check(R"( + type MockedResponse = { type: 'body' } | { type: 'error' } + + local function mockedResponseToHttpResponse(mockedResponse: MockedResponse) + assert(mockedResponse.type == 'body', 'Mocked response is not a body') + + if typeof(mockedResponse.body) == 'string' then + elseif typeof(mockedResponse.body) == 'table' then + else + return mockedResponseToHttpResponse(mockedResponse) + end + end + )"); + + // we're primarily interested in knowing that this does not crash. + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiated_metatable_frozen_table_clone_mutation") +{ + fileResolver.source["game/worker"] = R"( +type WorkerImpl = { + destroy: (self: Worker) -> boolean, +} + +type WorkerProps = { id: number } + +export type Worker = typeof(setmetatable({} :: WorkerProps, {} :: WorkerImpl)) + +return {} + )"; + + fileResolver.source["game/library"] = R"( +local Worker = require(game.worker) + +export type Worker = Worker.Worker + +return {} + )"; + + CheckResult result = frontend.check("game/library"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "setprop_on_a_mutating_local_in_both_loops_and_functions") +{ + CheckResult result = check(R"( + local _ = 5 + + while (_) do + _._ = nil + function _() + _ = nil + end + end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cant_index_this") +{ + CheckResult result = check(R"( + local a: number = 9 + a[18] = "tomfoolery" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + NotATable* notATable = get(result.errors[0]); + REQUIRE(notATable); + + CHECK("number" == toString(notATable->ty)); +} + +TEST_CASE_FIXTURE(Fixture, "setindexer_multiple_tables_intersection") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local function f(t: { [string]: number } & { [thread]: boolean }, x) + local k = "a" + t[k] = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK("({ [string]: number } & { [thread]: boolean }, never) -> ()" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") +{ + CheckResult result = check(R"( + local function f(t) + local res = {} + + for k, a in t do + res[k] = f(a) + res[k] = a + end + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") +{ + CheckResult result = check(R"( + --!strict + + local a = {} + ipairs(a) + )"); + + // The old solver erroneously leaves a free type dangling here. The new + // solver does better. + if (FFlag::LuauSolverV2) + CHECK("{unknown}" == toString(requireType("a"), {true})); + else + CHECK("{a}" == toString(requireType("a"), {true})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_results_compare_to_nil") +{ + CheckResult result = check(R"( + --!strict + + function foo(tbl: {number}) + if tbl[2] == nil then + print("foo") + end + + if tbl[3] ~= nil then + print("bar") + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_normalization_preserves_tbl_scopes") +{ + CheckResult result = check(R"( +Module 'l0': +do end + +Module 'l1': +local _ = {n0=nil,} +if if nil then _ then +if nil and (_)._ ~= (_)._ then +do end +while _ do +_ = _ +do end +end +end +do end +end +local l0 +while _ do +_ = nil +(_[_])._ %= `{# _}{bit32.extract(# _,1)}` +end + +)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_literal_inference_assert") +{ + CheckResult result = check(R"( + local buttons = { + buttons = {}; + } + + buttons.Button = { + call = nil; + lightParts = nil; + litPropertyOverrides = nil; + model = nil; + pivot = nil; + unlitPropertyOverrides = nil; + } + buttons.Button.__index = buttons.Button + + local lightFuncs: { (self: types.Button, lit: boolean) -> nil } = { + ['\x00'] = function(self: types.Button, lit: boolean) + end; + } + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_table_assertion_crash") +{ + CheckResult result = check(R"( + local NexusInstance = {} + function NexusInstance:__InitMetaMethods(): () + local Metatable = {} + local OriginalIndexTable = getmetatable(self).__index + setmetatable(self, Metatable) + + Metatable.__newindex = function(_, Index: string, Value: any): () + --Return if the new and old values are the same. + if self[Index] == Value then + end + end + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table::insert_should_not_report_errors_when_correct_overload_is_picked") +{ + CheckResult result = check(R"( +type cs = { GetTagged : (cs, string) -> any} +local destroyQueue: {any} = {} -- pair of (time, coin) +local tick : () -> any +local CS : cs +local DESTROY_DELAY +local function SpawnCoin() + local spawns = CS:GetTagged('CoinSpawner') + local n : any + local StartPos = spawns[n].CFrame + local Coin = script.Coin:Clone() + Coin.CFrame = StartPos + Coin.Parent = workspace.Coins + + table.insert(destroyQueue, {tick() + DESTROY_DELAY, Coin}) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table") +{ + ScopedFastFlag sff{FFlag::LuauAcceptIndexingTableUnionsIntersections, true}; + + CheckResult result = check(R"( + local test = if true then { "meow", "woof" } else { 4, 81 } + local test2 = test[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // unfortunate type duplication in the union + if (FFlag::LuauSolverV2) + CHECK("number | string | string" == toString(requireType("test2"))); + else + CHECK("number | string" == toString(requireType("test2"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table2") +{ + ScopedFastFlag sff{FFlag::LuauAcceptIndexingTableUnionsIntersections, true}; + + CheckResult result = check(R"( + local test = if true then {} else {} + local test2 = test[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // unfortunate type duplication in the union + if (FFlag::LuauSolverV2) + CHECK("unknown | unknown" == toString(requireType("test2"))); + else + CHECK("any" == toString(requireType("test2"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "length_of_array_is_number") +{ + CheckResult result = check(R"( + local function TestFunc(ranges: {number}): number + if true then + ranges = {} :: {number} + end + local numRanges: number = #ranges + return numRanges + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "subtyping_with_a_metatable_table_path") +{ + // Builtin functions have to be setup for the new solver + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type self = {} & {} + type Class = typeof(setmetatable()) + local function _(): Class + return setmetatable({}::self, {}) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + "Type pack '{ @metatable { }, { } & { } }' could not be converted into 'Class'; at [0].metatable(), { } is not a subtype of nil\n" + "\ttype { @metatable { }, { } & { } }[0].table()[0] ({ }) is not a subtype of Class[0].table() (nil)\n" + "\ttype { @metatable { }, { } & { } }[0].table()[1] ({ }) is not a subtype of Class[0].table() (nil)", + toString(result.errors[0]) + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_union_type") +{ + + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + // This will have one (legitimate) error but previously would crash. + auto result = check(R"( + local function set(key, value) + local Message = {} + function Message.new(message) + local self = message or {} + setmetatable(self, Message) + return self + end + local self = Message.new(nil) + self[key] = value + end + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + "Cannot add indexer to table '{ @metatable t1, (nil & ~(false?)) | { } } where t1 = { new: (a) -> { @metatable t1, (a & ~(false?)) | { } } }'", + toString(result.errors[0]) + ); +} + +TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + // CLI-121540: All of these examples should have no errors. + + LUAU_CHECK_ERROR_COUNT(3, check(R"( + local function doTheThing(_: { [string]: unknown }) end + doTheThing({ + ['foo'] = 5, + ['bar'] = 'heyo', + }) + )")); + + LUAU_CHECK_ERROR_COUNT(1, check(R"( + type Input = { [string]: unknown } + + local i : Input = { + [('%s'):format('3.14')]=5, + ['stringField']='Heyo' + } + )")); + + // This example previously asserted due to eagerly mutating the underlying + // table type. + LUAU_CHECK_ERROR_COUNT(3, check(R"( + type Input = { [string]: unknown } + + local function doTheThing(_: Input) end + + doTheThing({ + [('%s'):format('3.14')]=5, + ['stringField']='Heyo' + }) + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 3865e83a8..42963f5e1 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -2,12 +2,14 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Frontend.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Type.h" #include "Luau/VisitType.h" #include "Fixture.h" +#include "ClassFixture.h" #include "ScopedFlags.h" #include "doctest.h" @@ -15,8 +17,13 @@ #include LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTINT(LuauCheckRecursionLimit); +LUAU_FASTINT(LuauNormalizeCacheLimit); +LUAU_FASTINT(LuauRecursionLimit); +LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) using namespace Luau; @@ -27,8 +34,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_hello_world") CheckResult result = check("local a = 7"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveType::Number); + CHECK("number" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "tc_propagation") @@ -43,20 +49,47 @@ TEST_CASE_FIXTURE(Fixture, "tc_propagation") TEST_CASE_FIXTURE(Fixture, "tc_error") { CheckResult result = check("local a = 7 local b = 'hi' a = b"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number | string" == toString(requireType("a"))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ( + result.errors[0], + (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}}) + ); + } } TEST_CASE_FIXTURE(Fixture, "tc_error_2") { CheckResult result = check("local a = 7 a = 'hi'"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 18}, Position{0, 22}}, TypeMismatch{ - requireType("a"), - typeChecker.stringType, - }})); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number | string" == toString(requireType("a"))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ( + result.errors[0], + (TypeError{ + Location{Position{0, 18}, Position{0, 22}}, + TypeMismatch{ + requireType("a"), + builtinTypes->stringType, + } + }) + ); + } } TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value") @@ -64,8 +97,27 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value") CheckResult result = check("local f = nil; f = 'hello world'"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId ty = requireType("f"); - CHECK_EQ(getPrimitiveType(ty), PrimitiveType::String); + if (FFlag::LuauSolverV2) + { + CHECK("string?" == toString(requireType("f"))); + } + else + { + TypeId ty = requireType("f"); + CHECK_EQ(getPrimitiveType(ty), PrimitiveType::String); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value_2") +{ + CheckResult result = check(R"( + local a = 2 + local b = a,nil + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); } TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") @@ -77,15 +129,25 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") f("foo") )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::LuauSolverV2) + { + CHECK("unknown" == toString(requireType("a"))); + CHECK("(unknown) -> ()" == toString(requireType("f"))); - CHECK_EQ("number", toString(requireType("a"))); + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("number", toString(requireType("a"))); + } } TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { ScopedFastFlag sff[]{ - {"DebugLuauDeferredConstraintResolution", false}, + {FFlag::LuauSolverV2, false}, }; CheckResult result = check(R"( @@ -102,6 +164,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "obvious_type_error_in_nocheck_mode") +{ + CheckResult result = check(R"( + --!nocheck + local x: string = 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "expr_statement") { CheckResult result = check("local foo = 5 foo()"); @@ -123,8 +195,16 @@ TEST_CASE_FIXTURE(Fixture, "if_statement") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("a")); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + if (FFlag::LuauSolverV2) + { + CHECK("string?" == toString(requireType("a"))); + CHECK("number?" == toString(requireType("b"))); + } + else + { + CHECK_EQ(*builtinTypes->stringType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); + } } TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") @@ -145,6 +225,8 @@ TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local o o:method() @@ -184,6 +266,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "weird_case") TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict local s @@ -213,7 +297,7 @@ TEST_CASE_FIXTURE(Fixture, "crazy_complexity") A:A():A():A():A():A():A():A():A():A():A():A() )"); - std::cout << "OK! Allocated " << typeChecker.types.size() << " types" << std::endl; + MESSAGE("OK! Allocated ", typeChecker.types.size(), " types"); } #endif @@ -235,7 +319,7 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") CHECK_EQ("x", err->key); // TODO: Should we assert anything about these tests when DCR is being used? - if (!FFlag::DebugLuauDeferredConstraintResolution) + if (!FFlag::LuauSolverV2) { CHECK_EQ("*error-type*", toString(requireType("c"))); CHECK_EQ("*error-type*", toString(requireType("d"))); @@ -301,6 +385,8 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + #if defined(LUAU_ENABLE_ASAN) int limit = 250; #elif defined(_DEBUG) || defined(_NOOPT) @@ -309,7 +395,7 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") int limit = 600; #endif - ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; + ScopedFastInt sfi{FInt::LuauCheckRecursionLimit, limit}; CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); @@ -327,8 +413,8 @@ TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") int limit = 600; #endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; + ScopedFastInt luauRecursionLimit{FInt::LuauRecursionLimit, limit + 100}; + ScopedFastInt luauCheckRecursionLimit{FInt::LuauCheckRecursionLimit, limit - 100}; CheckResult result = check(rep("do ", limit) + "local a = 1" + rep(" end", limit)); @@ -345,8 +431,8 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") #else int limit = 600; #endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; + ScopedFastInt luauRecursionLimit{FInt::LuauRecursionLimit, limit + 100}; + ScopedFastInt luauCheckRecursionLimit{FInt::LuauCheckRecursionLimit, limit - 100}; CheckResult result = check(R"(("foo"))" + rep(":lower()", limit)); @@ -356,6 +442,9 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") TEST_CASE_FIXTURE(Fixture, "globals") { + // The new solver does not permit assignments to globals like this. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!nonstrict foo = true @@ -368,6 +457,8 @@ TEST_CASE_FIXTURE(Fixture, "globals") TEST_CASE_FIXTURE(Fixture, "globals2") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!nonstrict foo = function() return 1 end @@ -416,6 +507,8 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CHECK_NOTHROW(check(R"( --!nonstrict f,g = ... @@ -473,8 +566,13 @@ struct FindFreeTypes return !foundOne; } - template - bool operator()(ID, Unifiable::Free) + bool operator()(TypeId, FreeType) + { + foundOne = true; + return false; + } + + bool operator()(TypePackId, FreeTypePack) { foundOne = true; return false; @@ -503,6 +601,8 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_in_error") { { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict local t = { x = 10, y = 20 } @@ -523,6 +623,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_ } { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict function string.() end @@ -580,7 +682,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } @@ -596,6 +698,8 @@ TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( foo )"); @@ -605,6 +709,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstExprError") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local a = foo: )"); @@ -654,11 +760,19 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*error-type*", toString(*t0)); - - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { - return get(err); - }); + if (FFlag::LuauSolverV2) + CHECK("any" == toString(*t0)); + else + CHECK_EQ("*error-type*", toString(*t0)); + + auto it = std::find_if( + result.errors.begin(), + result.errors.end(), + [](TypeError& err) + { + return get(err); + } + ); CHECK(it != result.errors.end()); } @@ -706,7 +820,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_heap_use_after_free_error") end )"); - LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") @@ -903,6 +1020,41 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") )"); } +/* + * We had a bug where we'd improperly cache the normalization of types that are + * not fully solved yet. This eventually caused a crash elsewhere in the type + * solver. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_found_this_2") +{ + (void)check(R"( + local _ + if _ then + _ = _ + while _() do + _ = # _ + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_a_cyclic_intersection_does_not_crash") +{ + (void)check(R"( + local _ + if _ then + while nil do + _ = _ + end + end + if _[if _ then ""] then + while nil do + _ = if _ then "" + end + end + )"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_crash") { CheckResult result = check(R"( @@ -948,6 +1100,8 @@ end TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( --!strict --!nolint @@ -973,10 +1127,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") function Policies:readField(options: ReadFieldOptions) local _ = self:getStoreFieldName(options) - --[[ - Type error: - TypeError { "MainModule", Location { { line = 25, col = 16 }, { line = 25, col = 20 } }, TypeMismatch { Policies, {- getStoreFieldName: (tp1) -> (a, b...) -} } } - ]] foo(self) end )"); @@ -988,16 +1138,23 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") // unsound. LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ( - R"(Type 't1 where t1 = {+ getStoreFieldName: (t1, {| fieldName: string |} & {| from: number? |}) -> (a, b...) +}' could not be converted into 'Policies' + const std::string expected = R"(Type 'Policies' from 'MainModule' could not be converted into 'Policies' from 'MainModule' caused by: - Property 'getStoreFieldName' is not compatible. Type 't1 where t1 = ({+ getStoreFieldName: t1 +}, {| fieldName: string |} & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' + Property 'getStoreFieldName' is not compatible. +Type + '(Policies, FieldSpecifier & {| from: number? |}) -> (a, b...)' +could not be converted into + '(Policies, FieldSpecifier) -> string' caused by: - Argument #2 type is not compatible. Type 'FieldSpecifier' could not be converted into 'FieldSpecifier & {| from: number? |}' + Argument #2 type is not compatible. +Type + 'FieldSpecifier' +could not be converted into + 'FieldSpecifier & {| from: number? |}' caused by: - Not all intersection parts are compatible. Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')", - toString(result.errors[0])); + Not all intersection parts are compatible. +Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')"; + CHECK_EQ(expected, toString(result.errors[0])); } else { @@ -1007,7 +1164,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") { - ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); + ScopedFastInt sfi(FInt::LuauTypeInferRecursionLimit, 2); CheckResult result = check(R"( function complex() @@ -1020,12 +1177,16 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); + + if (FFlag::LuauSolverV2) + CHECK("Type contains a self-recursive construct that cannot be resolved" == toString(result.errors[0])); + else + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") { - ScopedFastInt sfi("LuauTypeInferRecursionLimit", 10); + ScopedFastInt sfi(FInt::LuauTypeInferRecursionLimit, 10); CheckResult result = check(R"( function f() @@ -1034,13 +1195,21 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); + validateErrors(result.errors); + REQUIRE_MESSAGE(!result.errors.empty(), getErrors(result)); + + CHECK(1 == result.errors.size()); + + if (FFlag::LuauSolverV2) + CHECK(Location{{3, 22}, {3, 42}} == result.errors[0].location); + else + CHECK(Location{{3, 12}, {3, 46}} == result.errors[0].location); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") { - ScopedFastInt sfi("LuauNormalizeCacheLimit", 10); + ScopedFastInt sfi(FInt::LuauNormalizeCacheLimit, 10); CheckResult result = check(R"( local x : ((number) -> number) & ((string) -> string) & ((nil) -> nil) & (({}) -> {}) @@ -1048,11 +1217,14 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") { + // CLI-114134 + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local obj = {} @@ -1150,8 +1322,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_retu TEST_CASE_FIXTURE(Fixture, "fuzz_free_table_type_change_during_index_check") { - ScopedFastFlag sff{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( local _ = nil while _["" >= _] do @@ -1163,8 +1333,6 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "typechecking_in_type_guards") { - ScopedFastFlag sff{"LuauTypecheckTypeguards", true}; - CheckResult result = check(R"( local a = type(foo) == 'nil' local b = typeof(foo) ~= 'nil' @@ -1175,4 +1343,393 @@ local b = typeof(foo) ~= 'nil' CHECK(toString(result.errors[1]) == "Unknown global 'foo'"); } +TEST_CASE_FIXTURE(Fixture, "occurs_isnt_always_failure") +{ + CheckResult result = check(R"( +function f(x, c) -- x : X + local y = if c then x else nil -- y : X? + local z = if c then x else nil -- z : X? + y = z +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_parameter_type") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, true}, + }; + + CheckResult result = check(R"( + local b: any + + function f(x) + local a = b[1] or 'Cn' + local c = x[1] + + if a:sub(1, #c) == c then + end + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter") +{ + CheckResult result = check(R"( + local TRUE: true = true + + local function matches(value, t: true) + if value then + return true + end + end + + local function readValue(breakpoint) + if matches(breakpoint, TRUE) then + readValue(breakpoint) + end + end + )"); + + if (FFlag::LuauSolverV2) + CHECK("(unknown) -> ()" == toString(requireType("readValue"))); + else + CHECK("(a) -> ()" == toString(requireType("readValue"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2") +{ + CheckResult result = check(R"( + local function readValue(breakpoint) + if type(breakpoint) == 'number' then + readValue(breakpoint) + end + end + )"); + + if (FFlag::LuauSolverV2) + CHECK("(unknown) -> ()" == toString(requireType("readValue"))); + else + CHECK("(number) -> ()" == toString(requireType("readValue"))); +} + +/* + * We got into a case where, as we unified two nearly identical unions with one + * another, where we had a concatenated TxnLog that created a cycle between two + * free types. + * + * This code used to crash the type checker. See CLI-71190 + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "convoluted_case_where_two_TypeVars_were_bound_to_each_other") +{ + check(R"( + type React_Ref = { current: ElementType } | ((ElementType) -> ()) + + type React_AbstractComponent = { + render: ((ref: React_Ref) -> nil) + } + + local createElement : (React_AbstractComponent) -> () + + function ScrollView:render() + local one = table.unpack( + if true then a else b + ) + + createElement(one) + createElement(one) + end + )"); + + // If this code does not crash, we are in good shape. +} + +/* + * Under DCR we had an issue where constraint resolution resulted in the + * following: + * + * *blocked-55* ~ hasProp {- name: *blocked-55* -}, "name" + * + * This is a perfectly reasonable constraint, but one that doesn't actually + * constrain anything. When we encounter a constraint like this, we need to + * replace the result type by a free type that is scoped to the enclosing table. + * + * Conceptually, it's simplest to think of this constraint as one that is + * tautological. It does not actually contribute any new information. + */ +TEST_CASE_FIXTURE(Fixture, "handle_self_referential_HasProp_constraints") +{ + CheckResult result = check(R"( + local function calculateTopBarHeight(props) + end + local function isTopPage(props) + local topMostOpaquePage + if props.avatarRoute then + topMostOpaquePage = props.avatarRoute.opaque.name + else + topMostOpaquePage = props.opaquePage + end + end + + function TopBarContainer:updateTopBarHeight(prevProps, prevState) + calculateTopBarHeight(self.props) + isTopPage(self.props) + local topMostOpaquePage + if self.props.avatarRoute then + topMostOpaquePage = self.props.avatarRoute.opaque.name + -- ^--------------------------------^ + else + topMostOpaquePage = self.props.opaquePage + end + end + )"); +} + +/* We had an issue where we were unifying two type packs + * + * free-2-0... and (string, free-4-0...) + * + * The correct thing to do here is to promote everything on the right side to + * level 2-0 before binding the left pack to the right. If we fail to do this, + * then the code fragment here fails to typecheck because the argument and + * return types of C are generalized before we ever get to checking the body of + * C. + */ +TEST_CASE_FIXTURE(Fixture, "promote_tail_type_packs") +{ + CheckResult result = check(R"( + --!strict + + local A: any = nil + + local C + local D = A( + A({}, { + __call = function(a): string + local E: string = C(a) + return E + end + }), + { + F = function(s: typeof(C)) + end + } + ) + + function C(b: any): string + return '' + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "lti_must_record_contributing_locations") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local function f(a) + if math.random() > 0.5 then + math.abs(a) + else + string.len(a) + end + end + )"); + + // We inspect the actual errors in other tests; this test verifies that we + // actually recorded breadcrumbs for a. + LUAU_REQUIRE_ERROR_COUNT(3, result); + TypeId fnTy = requireType("f"); + const FunctionType* fn = get(fnTy); + REQUIRE(fn); + + TypeId argTy = *first(fn->argTypes); + std::vector> locations = getMainModule()->upperBoundContributors[argTy]; + CHECK(locations.size() == 2); +} + +/* + * CLI-49876 + * + * We had a bug where we would not use the correct TxnLog when evaluating a + * variadic overload. We could therefore get into a state where the TxnLog has + * logged that a generic matches to one type, but the variadic tail has already + * been bound to another type outside of that TxnLog. + * + * This caused type checking to succeed when it should have failed. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + local function concat(target: {T}, ...: {T} | T): {T} + return (nil :: any) :: {T} + end + + local res = concat({"alic"}, 1, 2) + )"); + + LUAU_REQUIRE_ERRORS(result); + + for (const auto& e : result.errors) + CHECK(5 == e.location.begin.line); +} + +/* + * We had an issue where this kind of typeof() call could produce the untestable type ~{} + */ +TEST_CASE_FIXTURE(Fixture, "typeof_cannot_refine_builtin_alias") +{ + GlobalTypes& globals = frontend.globals; + TypeArena& arena = globals.globalTypes; + + unfreeze(arena); + + globals.globalScope->exportedTypeBindings["GlobalTable"] = TypeFun{{}, arena.addType(TableType{TableState::Sealed, TypeLevel{}})}; + + freeze(arena); + + (void)check(R"( + function foo(x) + if typeof(x) == 'GlobalTable' then + end + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "bad_iter_metamethod") +{ + CheckResult result = check(R"( + function iter(): unknown + return nil + end + + local a = {__iter = iter} + setmetatable(a, a) + + for i in a do + end + )"); + + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CannotCallNonFunction* ccnf = get(result.errors[0]); + REQUIRE(ccnf); + + CHECK("unknown" == toString(ccnf->ty)); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar") +{ + CheckResult result = check(R"( + type Bar = | number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_question_mark") +{ + CheckResult result = check(R"( + type Bar = |? + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got '?'" == toString(result.errors[0])); + CHECK("*error-type*?" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand") +{ + CheckResult result = check(R"( + type Amp = & string + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string" == toString(requireTypeAlias("Amp"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_no_type") +{ + CheckResult result = check(R"( + type Bar = | + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand_no_type") +{ + CheckResult result = check(R"( + type Amp = & + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Amp"))); +} + +TEST_CASE_FIXTURE(Fixture, "react_lua_follow_free_type_ub") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + return function(Roact) + local Tree = Roact.Component:extend("Tree") + + function Tree:render() + local breadth, components, depth, id, wrap = + self.props.breadth, self.props.components, self.props.depth, self.props.id, self.props.wrap + local Box = components.Box + if depth == 0 then + Roact.createElement(Box, {}) + else + Roact.createElement(Tree, {}) + end + + end + end + )")); +} + +TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauNewSolverVisitErrorExprLvalues, true} + }; + + // This should always fail to parse, but shouldn't assert. Previously this + // would assert as we end up _roughly_ parsing this (with a lot of error + // nodes) as: + // + // do + // x :: T, y = z + // end + // + // We assume that `T` has some resolved type that is set up during + // constraint generation and resolved during constraint solving to + // be used during typechecking. We didn't descend into error nodes + // in lvalue positions. + LUAU_REQUIRE_ERRORS(check(R"( + --!strict + (::, + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 47b140a14..2a0a072a7 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -11,16 +11,20 @@ using namespace Luau; -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauUnifierRecursionOnRestart); struct TryUnifyFixture : Fixture { + // Cannot use `TryUnifyFixture` under DCR. + ScopedFastFlag noDcr{FFlag::LuauSolverV2, false}; + TypeArena arena; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}}; - Unifier state{NotNull{&normalizer}, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; + Unifier state{NotNull{&normalizer}, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -28,22 +32,25 @@ TEST_SUITE_BEGIN("TryUnifyTests"); TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") { Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; - Type numberTwo = numberOne; + Type numberTwo = numberOne.clone(); state.tryUnify(&numberTwo, &numberOne); + CHECK(!state.failure); CHECK(state.errors.empty()); } TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { - Type functionOne{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + }}; - Type functionTwo{TypeVariant{ - FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; + Type functionTwo{ + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))} + }; state.tryUnify(&functionTwo, &functionOne); + CHECK(!state.failure); CHECK(state.errors.empty()); state.log.commit(); @@ -54,18 +61,19 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionOne{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + }}; - Type functionOneSaved = functionOne; + Type functionOneSaved = functionOne.clone(); TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionTwo{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.stringType}))}}; + Type functionTwo{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) + }}; - Type functionTwoSaved = functionTwo; + Type functionTwoSaved = functionTwo.clone(); state.tryUnify(&functionTwo, &functionOne); + CHECK(state.failure); CHECK(!state.errors.empty()); CHECK_EQ(functionOne, functionOneSaved); @@ -82,39 +90,49 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); state.tryUnify(&tableTwo, &tableOne); + CHECK(!state.failure); CHECK(state.errors.empty()); state.log.commit(); - CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); } TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { Type tableOne{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.numberType}}}, std::nullopt, globalScope->level, - TableState::Unsealed}, + TableType{ + {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, + std::nullopt, + globalScope->level, + TableState::Unsealed + }, }}; Type tableTwo{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.stringType}}}, std::nullopt, globalScope->level, - TableState::Unsealed}, + TableType{ + {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, + std::nullopt, + globalScope->level, + TableState::Unsealed + }, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); state.tryUnify(&tableTwo, &tableOne); + CHECK(state.failure); CHECK_EQ(1, state.errors.size()); - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") +TEST_CASE_FIXTURE(Fixture, "uninhabited_intersection_sub_never") { CheckResult result = check(R"( function f(arg : string & number) : never @@ -124,7 +142,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_anything") +TEST_CASE_FIXTURE(Fixture, "uninhabited_intersection_sub_anything") { CheckResult result = check(R"( function f(arg : string & number) : boolean @@ -134,11 +152,9 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_anything") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") +TEST_CASE_FIXTURE(Fixture, "uninhabited_table_sub_never") { - ScopedFastFlag sffs[]{ - {"LuauUninhabitedSubAnything2", true}, - }; + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( function f(arg : { prop : string & number }) : never @@ -148,11 +164,9 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") +TEST_CASE_FIXTURE(Fixture, "uninhabited_table_sub_anything") { - ScopedFastFlag sffs[]{ - {"LuauUninhabitedSubAnything2", true}, - }; + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; CheckResult result = check(R"( function f(arg : { prop : string & number }) : boolean @@ -162,8 +176,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_unified_with_errorType") +TEST_CASE_FIXTURE(Fixture, "members_of_failed_typepack_unification_are_unified_with_errorType") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function f(arg: number) end local a @@ -173,12 +189,14 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("*error-type*", toString(requireType("b"))); } -TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") +TEST_CASE_FIXTURE(Fixture, "result_of_failed_typepack_unification_is_constrained") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function f(arg: number) return arg end local a @@ -188,12 +206,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("*error-type*", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("c"))); } -TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") +TEST_CASE_FIXTURE(Fixture, "typepack_unification_should_trim_free_tails") { CheckResult result = check(R"( --!strict @@ -214,24 +232,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") { - TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; - TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; + TypePackVar testPack{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, std::nullopt}}; + TypePackVar variadicPack{VariadicTypePack{builtinTypes->numberType}}; state.tryUnify(&testPack, &variadicPack); + CHECK(state.failure); CHECK(!state.errors.empty()); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") { - TypePackVar variadicPack{VariadicTypePack{typeChecker.booleanType}}; - TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; - TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; + TypePackVar variadicPack{VariadicTypePack{builtinTypes->booleanType}}; + TypePackVar a{TypePack{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->booleanType, builtinTypes->booleanType}}}; + TypePackVar b{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, &variadicPack}}; state.tryUnify(&b, &a); + CHECK(!state.failure); CHECK(state.errors.empty()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") +TEST_CASE_FIXTURE(Fixture, "variadics_should_use_reversed_properly") { CheckResult result = check(R"( --!strict @@ -258,7 +278,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unifica LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(toString(result.errors[0]), "No overload for function accepts 0 arguments."); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) CHECK_EQ(toString(result.errors[1]), "Available overloads: ({V}, V) -> (); and ({V}, number, V) -> ()"); else CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); @@ -266,11 +286,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unifica TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") { - TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); - TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + TypePackId threeNumbers = + arena.addTypePack(TypePack{{builtinTypes->numberType, builtinTypes->numberType, builtinTypes->numberType}, std::nullopt}); + TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{builtinTypes->numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); - ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); - CHECK(unifyErrors.size() == 0); + CHECK(state.canUnify(numberAndFreeTail, threeNumbers).empty()); } TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") @@ -279,7 +299,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") Type table{TableType{}}; Type metatable{MetatableType{&redirect, &table}}; redirect = BoundType{&metatable}; // Now we have a metatable that is recursive on the table type - Type variant{UnionType{{&metatable, typeChecker.numberType}}}; + Type variant{UnionType{{&metatable, builtinTypes->numberType}}}; state.tryUnify(&metatable, &variant); } @@ -293,13 +313,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") state.tryUnify(&free, &target); // Shouldn't assert or error. - state.tryUnify(&func, typeChecker.anyType); + state.tryUnify(&func, builtinTypes->anyType); } TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { TypeId a = arena.addType(Type{FreeType{TypeLevel{}}}); - TypeId b = typeChecker.numberType; + TypeId b = builtinTypes->numberType; state.tryUnify(a, b); state.log.commit(); @@ -310,7 +330,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") { TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); - TypePackId b = typeChecker.anyTypePack; + TypePackId b = builtinTypes->anyTypePack; state.tryUnify(a, b); state.log.commit(); @@ -320,16 +340,14 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table") { - ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); - TableType::Props freeProps{ - {"foo", {typeChecker.numberType}}, + {"foo", {builtinTypes->numberType}}, }; TypeId free = arena.addType(TableType{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); TableType::Props indexProps{ - {"foo", {typeChecker.stringType}}, + {"foo", {builtinTypes->stringType}}, }; TypeId index = arena.addType(TableType{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); @@ -343,22 +361,25 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table TypeId target = arena.addType(TableType{TableState::Unsealed, TypeLevel{}}); TypeId metatable = arena.addType(MetatableType{target, mt}); + state.enableNewSolver(); state.tryUnify(metatable, free); state.log.commit(); REQUIRE_EQ(state.errors.size(), 1); - - std::string expected = "Type '{ @metatable {| __index: {| foo: string |} |}, { } }' could not be converted into '{- foo: number -}'\n" - "caused by:\n" - " Type 'number' could not be converted into 'string'"; - CHECK_EQ(toString(state.errors[0]), expected); + const std::string expected = R"(Type + '{ @metatable {| __index: {| foo: string |} |}, { } }' +could not be converted into + '{- foo: number -}' +caused by: + Type 'number' could not be converted into 'string')"; + CHECK_EQ(expected, toString(state.errors[0])); } TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") { - TypePackVar variadicAny{VariadicTypePack{typeChecker.anyType}}; - TypePackVar packTmp{TypePack{{typeChecker.anyType}, &variadicAny}}; - TypePackVar packSub{TypePack{{typeChecker.anyType, typeChecker.anyType}, &packTmp}}; + TypePackVar variadicAny{VariadicTypePack{builtinTypes->anyType}}; + TypePackVar packTmp{TypePack{{builtinTypes->anyType}, &variadicAny}}; + TypePackVar packSub{TypePack{{builtinTypes->anyType, builtinTypes->anyType}, &packTmp}}; Type freeTy{FreeType{TypeLevel{}}}; TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; @@ -379,4 +400,129 @@ local l0:(any)&(typeof(_)),l0:(any)|(any) = _,_ LUAU_REQUIRE_ERRORS(result); } +static TypeId createTheType(TypeArena& arena, NotNull builtinTypes, Scope* scope, TypeId freeTy) +{ + /* + ({| + render: ( + (('a) -> ()) | {| current: 'a |} + ) -> nil + |}) -> () + */ + TypePackId emptyPack = arena.addTypePack({}); + + return arena.addType(FunctionType{ + arena.addTypePack({arena.addType(TableType{ + TableType::Props{ + {{"render", + Property(arena.addType(FunctionType{ + arena.addTypePack({arena.addType(UnionType{ + {arena.addType(FunctionType{arena.addTypePack({freeTy}), emptyPack}), + arena.addType(TableType{TableType::Props{{"current", {freeTy}}}, std::nullopt, TypeLevel{}, scope, TableState::Sealed})} + })}), + arena.addTypePack({builtinTypes->nilType}) + }))}} + }, + std::nullopt, + TypeLevel{}, + scope, + TableState::Sealed + })}), + emptyPack + }); +}; + +// See CLI-71190 +TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_create_a_BoundType_cycle") +{ + const std::shared_ptr scope = globalScope; + const std::shared_ptr nestedScope = std::make_shared(scope); + + const TypeId outerType = arena.freshType(scope.get()); + const TypeId outerType2 = arena.freshType(scope.get()); + + const TypeId innerType = arena.freshType(nestedScope.get()); + + state.enableNewSolver(); + + SUBCASE("equal_scopes") + { + TypeId one = createTheType(arena, builtinTypes, scope.get(), outerType); + TypeId two = createTheType(arena, builtinTypes, scope.get(), outerType2); + + state.tryUnify(one, two); + state.log.commit(); + + ToStringOptions opts; + + CHECK(follow(outerType) == follow(outerType2)); + } + + SUBCASE("outer_scope_is_subtype") + { + TypeId one = createTheType(arena, builtinTypes, scope.get(), outerType); + TypeId two = createTheType(arena, builtinTypes, scope.get(), innerType); + + state.tryUnify(one, two); + state.log.commit(); + + ToStringOptions opts; + + CHECK(follow(outerType) == follow(innerType)); + + // The scope of outerType exceeds that of innerType. The latter should be bound to the former. + const BoundType* bt = get_if(&innerType->ty); + REQUIRE(bt); + CHECK(bt->boundTo == outerType); + } + + SUBCASE("outer_scope_is_supertype") + { + TypeId one = createTheType(arena, builtinTypes, scope.get(), innerType); + TypeId two = createTheType(arena, builtinTypes, scope.get(), outerType); + + state.tryUnify(one, two); + state.log.commit(); + + ToStringOptions opts; + + CHECK(follow(outerType) == follow(innerType)); + + // The scope of outerType exceeds that of innerType. The latter should be bound to the former. + const BoundType* bt = get_if(&innerType->ty); + REQUIRE(bt); + CHECK(bt->boundTo == outerType); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_full_restart_recursion") +{ + ScopedFastFlag luauUnifierRecursionOnRestart{FFlag::LuauUnifierRecursionOnRestart, true}; + + CheckResult result = check(R"( +local A, B, C, D + +E = function(a, b) + local mt = getmetatable(b) + if mt.tm:bar(A) == nil and mt.tm:bar(B) == nil then end + if mt.foo == true then D(b, 3) end + mt.foo:call(false, b) +end + +A = function(a, b) + local mt = getmetatable(b) + if mt.foo == true then D(b, 3) end + C(mt, 3) +end + +B = function(a, b) + local mt = getmetatable(b) + if mt.foo == true then D(b, 3) end + C(mt, 3) +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.test.cpp similarity index 78% rename from tests/TypeInfer.typePacks.cpp rename to tests/TypeInfer.typePacks.test.cpp index 78eb6d477..8b489c449 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.test.cpp @@ -9,6 +9,9 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauInstantiateInSubtyping); + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -27,8 +30,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, follow(returns[0])); - CHECK_EQ(typeChecker.numberType, follow(returns[1])); + CHECK_EQ(builtinTypes->numberType, follow(returns[0])); + CHECK_EQ(builtinTypes->numberType, follow(returns[1])); CHECK(!tail); } @@ -74,9 +77,9 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, follow(rets[0])); - CHECK_EQ(typeChecker.numberType, follow(rets[1])); - CHECK_EQ(typeChecker.numberType, follow(rets[2])); + CHECK_EQ(builtinTypes->numberType, follow(rets[0])); + CHECK_EQ(builtinTypes->numberType, follow(rets[1])); + CHECK_EQ(builtinTypes->numberType, follow(rets[2])); CHECK(!tail); } @@ -184,28 +187,28 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_varargs_returns_any") TEST_CASE_FIXTURE(Fixture, "variadic_packs") { - TypeArena& arena = typeChecker.globalTypes; + TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypePackId listOfNumbers = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.stringType}}); + TypePackId listOfNumbers = arena.addTypePack(TypePackVar{VariadicTypePack{builtinTypes->numberType}}); + TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{builtinTypes->stringType}}); // clang-format off - addGlobalBinding(frontend, "foo", + addGlobalBinding(frontend.globals, "foo", arena.addType( FunctionType{ listOfNumbers, - arena.addTypePack({typeChecker.numberType}) + arena.addTypePack({builtinTypes->numberType}) } ), "@test" ); - addGlobalBinding(frontend, "bar", + addGlobalBinding(frontend.globals, "bar", arena.addType( FunctionType{ - arena.addTypePack({{typeChecker.numberType}, listOfStrings}), - arena.addTypePack({typeChecker.numberType}) + arena.addTypePack({{builtinTypes->numberType}, listOfStrings}), + arena.addTypePack({builtinTypes->numberType}) } ), "@test" @@ -223,9 +226,16 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location(Position{3, 21}, Position{3, 26}), TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK(Location{Position{3, 21}, Position{3, 26}} == result.errors[0].location); + CHECK(Location{Position{4, 29}, Position{4, 30}} == result.errors[1].location); + + CHECK_EQ( + result.errors[0], (TypeError{Location(Position{3, 21}, Position{3, 26}), TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}}) + ); - CHECK_EQ(result.errors[1], (TypeError{Location(Position{4, 29}, Position{4, 30}), TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); + CHECK_EQ( + result.errors[1], (TypeError{Location(Position{4, 29}, Position{4, 30}), TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}}) + ); } TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") @@ -243,6 +253,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); } +#if 0 TEST_CASE_FIXTURE(Fixture, "type_pack_hidden_free_tail_infinite_growth") { CheckResult result = check(R"( @@ -259,6 +270,7 @@ end LUAU_REQUIRE_ERRORS(result); } +#endif TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") { @@ -304,12 +316,18 @@ local c: Packed tf = lookupType("Packed"); REQUIRE(tf); CHECK_EQ(toString(*tf), "Packed"); - CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(*tf, {true}), "{ f: (T, U...) -> (T, U...) }"); + else + CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); - CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireType("a"), {true}), "{ f: (number) -> number }"); + else + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); REQUIRE(ttvA->instantiatedTypeParams.size() == 1); REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); @@ -318,7 +336,10 @@ local c: Packed auto ttvB = get(requireType("b")); REQUIRE(ttvB); CHECK_EQ(toString(requireType("b")), "Packed"); - CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireType("b"), {true}), "{ f: (string, number) -> (string, number) }"); + else + CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); REQUIRE(ttvB->instantiatedTypeParams.size() == 1); REQUIRE(ttvB->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string"); @@ -327,7 +348,10 @@ local c: Packed auto ttvC = get(requireType("c")); REQUIRE(ttvC); CHECK_EQ(toString(requireType("c")), "Packed"); - CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireType("c"), {true}), "{ f: (string, number, boolean) -> (string, number, boolean) }"); + else + CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); REQUIRE(ttvC->instantiatedTypeParams.size() == 1); REQUIRE(ttvC->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string"); @@ -356,12 +380,25 @@ local d: { a: typeof(c) } auto tf = lookupImportedType("Import", "Packed"); REQUIRE(tf); CHECK_EQ(toString(*tf), "Packed"); - CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}"); - CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}"); - CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}"); - CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}"); - CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); + if (FFlag::LuauSolverV2) + { + CHECK_EQ(toString(*tf, {true}), "{ a: T, b: (U...) -> () }"); + + CHECK_EQ(toString(requireType("a"), {true}), "{ a: number, b: () -> () }"); + CHECK_EQ(toString(requireType("b"), {true}), "{ a: string, b: (number) -> () }"); + CHECK_EQ(toString(requireType("c"), {true}), "{ a: string, b: (number, boolean) -> () }"); + CHECK_EQ(toString(requireType("d")), "{ a: Packed }"); + } + else + { + CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}"); + CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}"); + CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "type_pack_type_parameters") @@ -384,19 +421,31 @@ type C = Import.Packed auto tf = lookupType("Alias"); REQUIRE(tf); CHECK_EQ(toString(*tf), "Alias"); - CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(*tf, {true}), "{ a: S, b: (T, R...) -> () }"); + else + CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}"); - CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(requireType("a"), {true}), "{ a: string, b: (number, boolean) -> () }"); + else + CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}"); tf = lookupType("B"); REQUIRE(tf); CHECK_EQ(toString(*tf), "B"); - CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(*tf, {true}), "{ a: string, b: (X...) -> () }"); + else + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}"); tf = lookupType("C"); REQUIRE(tf); CHECK_EQ(toString(*tf), "C"); - CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(*tf, {true}), "{ a: string, b: (number, X...) -> () }"); + else + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}"); } TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") @@ -412,9 +461,11 @@ type Packed4 = (Packed3, T...) -> (Packed3, T...) auto tf = lookupType("Packed4"); REQUIRE(tf); - CHECK_EQ(toString(*tf), + CHECK_EQ( + toString(*tf), "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> " - "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)"); + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)" + ); } TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") @@ -524,12 +575,12 @@ local b: Y<(), ()> TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") { CheckResult result = check(R"( -type X = () -> T -type Y = (T) -> U + type X = () -> T + type Y = (T) -> U -type A = X<(number)> -type B = Y<(number), (boolean)> -type C = Y<(number), boolean> + type A = X<(number)> + type B = Y<(number), (boolean)> + type C = Y<(number), boolean> )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -736,48 +787,69 @@ local d: Y ()> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( -type Y = { a: T } -local a: Y = { a = 2 } + type Y = { a: T } + local a: Y = { a = 2 } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); +} - result = check(R"( -type Y = { a: (T...) -> () } -local a: Y<> +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors2") +{ + CheckResult result = check(R"( + type Y = { a: (T...) -> () } + local a: Y<> )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); +} - result = check(R"( -type Y = { a: (T) -> U... } -local a: Y<...number> +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors3") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + type Y = { a: (T) -> U... } + local a: Y<...number> )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), "Generic type 'Y' expects at least 1 type argument, but none are specified"); +} - result = check(R"( -type Packed = (T) -> T -local a: Packed +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors4") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + type Packed = (T) -> T + local a: Packed )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); +} - result = check(R"( -type Y = { a: T } -local a: Y +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors5") +{ + CheckResult result = check(R"( + type Y = { a: T } + local a: Y )"); LUAU_REQUIRE_ERRORS(result); +} - result = check(R"( -type Y = { a: T } -local a: Y<...number> +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors6") +{ + CheckResult result = check(R"( + type Y = { a: T } + local a: Y<...number> )"); LUAU_REQUIRE_ERRORS(result); @@ -863,7 +935,10 @@ type R = { m: F } LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}"); + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = { m: (t1) -> (t1) -> () }"); + else + CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}"); } TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") @@ -875,9 +950,27 @@ a = b )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)' + + if (FFlag::LuauSolverV2) + { + const std::string expected = + "Type\n" + " '() -> (number, ...boolean)'\n" + "could not be converted into\n" + " '() -> (number, ...string)'; at returns().tail().variadic(), boolean is not a subtype of string"; + + CHECK(expected == toString(result.errors[0])); + } + else + { + const std::string expected = R"(Type + '() -> (number, ...boolean)' +could not be converted into + '() -> (number, ...string)' caused by: - Type 'boolean' could not be converted into 'string')"); + Type 'boolean' could not be converted into 'string')"; + CHECK_EQ(expected, toString(result.errors[0])); + } } // TODO: File a Jira about this @@ -964,11 +1057,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks2") end )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK("Unknown type 't0'" == toString(result.errors[0])); + CHECK(get(result.errors[1])); } TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( function foo(...: string): number return 1 @@ -996,14 +1094,20 @@ TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'boolean'"); + if (FFlag::LuauSolverV2) + CHECK( + toString(result.errors.at(0)) == + "Type pack '...number' could not be converted into 'boolean'; type ...number.tail() (...number) is not a subtype of boolean (boolean)" + ); + else + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'boolean'"); } TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") { std::optional sff; - if (FFlag::DebugLuauDeferredConstraintResolution) - sff = {"LuauInstantiateInSubtyping", true}; + if (FFlag::LuauSolverV2) + sff = {FFlag::LuauInstantiateInSubtyping, true}; CheckResult result = check(R"( local function wrapReject(fn: (self: any, ...TArg) -> ...TResult): (self: any, ...TArg) -> ...TResult @@ -1023,8 +1127,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") TEST_CASE_FIXTURE(BuiltinsFixture, "generalize_expectedTypes_with_proper_scope") { ScopedFastFlag sff[] = { - {"DebugLuauDeferredConstraintResolution", true}, - {"LuauInstantiateInSubtyping", true}, + {FFlag::LuauSolverV2, true}, + {FFlag::LuauInstantiateInSubtyping, true}, }; CheckResult result = check(R"( @@ -1058,4 +1162,14 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_param_overflow") +{ + CheckResult result = check(R"( + type Two = { a: T, b: U } + local x: Two = { a = 1, b = 'c' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typestates.test.cpp b/tests/TypeInfer.typestates.test.cpp new file mode 100644 index 000000000..0bce75465 --- /dev/null +++ b/tests/TypeInfer.typestates.test.cpp @@ -0,0 +1,526 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "doctest.h" + +LUAU_FASTFLAG(LuauSolverV2) + +using namespace Luau; + +namespace +{ +struct TypeStateFixture : BuiltinsFixture +{ + ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; +}; +} // namespace + +TEST_SUITE_BEGIN("TypeStatesTest"); + +TEST_CASE_FIXTURE(TypeStateFixture, "initialize_x_of_type_string_or_nil_with_nil") +{ + CheckResult result = check(R"( + local x: string? = nil + local a = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string?" == toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "extraneous_lvalues_are_populated_with_nil") +{ + CheckResult result = check(R"( + local function f(): (string, number) + return "hello", 5 + end + + local x, y, z = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Function only returns 2 values, but 3 are required here" == toString(result.errors[0])); + CHECK("string" == toString(requireType("x"))); + CHECK("number" == toString(requireType("y"))); + CHECK("nil" == toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "assign_different_values_to_x") +{ + CheckResult result = check(R"( + local x: string? = nil + local a = x + x = "hello!" + local b = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string?" == toString(requireType("a"))); + CHECK("string" == toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "parameter_x_was_constrained_by_two_types") +{ + // Parameter `x` has a fresh type `'x` bounded by `never` and `unknown`. + // The first use of `x` constrains `x`'s upper bound by `string | number`. + // The second use of `x`, aliased by `y`, constrains `x`'s upper bound by `string?`. + // This results in `'x <: (string | number) & (string?)`. + // The principal type of the upper bound is `string`. + CheckResult result = check(R"( + local function f(x): string? + local y: string | number = x + return y + end + )"); + + if (FFlag::LuauSolverV2) + { + // `y` is annotated `string | number` which is explicitly not compatible with `string?` + // as such, we produce an error here for that mismatch. + // + // this is not necessarily the best inference here, since we can indeed produce `string` + // as a type for `x`, but it's a limitation we can accept for now. + LUAU_REQUIRE_ERRORS(result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK("string?" == toString(tpm->wantedTp)); + CHECK("number | string" == toString(tpm->givenTp)); + + CHECK("(number | string) -> string?" == toString(requireType("f"))); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(string) -> string?" == toString(requireType("f"))); + } +} + +#if 0 +TEST_CASE_FIXTURE(TypeStateFixture, "local_that_will_be_assigned_later") +{ + CheckResult result = check(R"( + local x: string + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "refine_a_local_and_then_assign_it") +{ + CheckResult result = check(R"( + local function f(x: string?) + if typeof(x) == "string" then + x = nil + end + + local y: nil = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "assign_a_local_and_then_refine_it") +{ + CheckResult result = check(R"( + local function f(x: string?) + x = nil + + if typeof(x) == "string" then + local y: typeof(x) = "hello" + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Type 'string' could not be converted into 'never'" == toString(result.errors[0])); +} +#endif + +TEST_CASE_FIXTURE(TypeStateFixture, "recursive_local_function") +{ + CheckResult result = check(R"( + local function f(x) + f(5) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "recursive_function") +{ + CheckResult result = check(R"( + function f(x) + f(5) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "compound_assignment") +{ + CheckResult result = check(R"( + local x = 5 + x += 7 + + local a = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "assignment_identity") +{ + CheckResult result = check(R"( + local x = 5 + x = x + + local a = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number" == toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "assignment_swap") +{ + CheckResult result = check(R"( + local x, y = 5, "hello" + x, y = y, x + + local a, b = x, y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string" == toString(requireType("a"))); + CHECK("number" == toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "parameter_x_was_constrained_by_two_types_2") +{ + CheckResult result = check(R"( + local function f(x): number? + local y: string? = nil -- 'y <: string? + y = x -- 'y ~ 'x + return y -- 'y <: number? + + -- We therefore infer 'y <: (string | nil) & (number | nil) + -- or 'y <: nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(nil) -> number?" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "parameter_x_is_some_type_or_optional_then_assigned_with_alternate_value") +{ + CheckResult result = check(R"( + local function f(x: number?) + x = x or 5 + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(number?) -> number" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "local_assigned_in_either_branches_that_falls_through") +{ + CheckResult result = check(R"( + local x = nil + if math.random() > 0.5 then + x = 5 + else + x = "hello" + end + local y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number | string" == toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "local_assigned_in_only_one_branch_that_falls_through") +{ + CheckResult result = check(R"( + local x = nil + if math.random() > 0.5 then + x = 5 + end + local y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number?" == toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "then_branch_assigns_and_else_branch_also_assigns_but_is_met_with_return") +{ + CheckResult result = check(R"( + local x = nil + if math.random() > 0.5 then + x = 5 + else + x = "hello" + return + end + local y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number?" == toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "then_branch_assigns_but_is_met_with_return_and_else_branch_assigns") +{ + CheckResult result = check(R"( + local x = nil + if math.random() > 0.5 then + x = 5 + return + else + x = "hello" + end + local y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string?" == toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "invalidate_type_refinements_upon_assignments") +{ + CheckResult result = check(R"( + type Ok = { tag: "ok", val: T } + type Err = { tag: "err", err: E } + type Result = Ok | Err + + local function f(res: Result) + assert(res.tag == "ok") + local tag: "ok", val: T = res.tag, res.val + res = { tag = "err" :: "err", err = (5 :: any) :: E } + local tag: "err", err: E = res.tag, res.err + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +#if 0 +TEST_CASE_FIXTURE(TypeStateFixture, "local_t_is_assigned_a_fresh_table_with_x_assigned_a_union_and_then_assert_restricts_actual_outflow_of_types") +{ + CheckResult result = check(R"( + local t = nil + + if math.random() > 0.5 then + t = {} + t.x = if math.random() > 0.5 then 5 else "hello" + assert(typeof(t.x) == "string") + else + t = {} + t.x = if math.random() > 0.5 then 7 else true + assert(typeof(t.x) == "boolean") + end + + local x = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // CHECK("boolean | string" == toString(requireType("x"))); + CHECK("boolean | number | number | string" == toString(requireType("x"))); +} +#endif + +TEST_CASE_FIXTURE(TypeStateFixture, "captured_locals_are_unions_of_all_assignments") +{ + CheckResult result = check(R"( + local x = nil + + function f() + print(x) + x = "five" + end + + x = 5 + f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(number | string)?" == toString(requireTypeAtPosition({4, 18}))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "captured_locals_are_unions_of_all_assignments_2") +{ + CheckResult result = check(R"( + local t = {x = nil} + + function f() + print(t.x) + t = {x = "five"} + end + + t = {x = 5} + f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("{ x: nil } | { x: number } | { x: string }" == toString(requireTypeAtPosition({4, 18}), {true})); + CHECK("(number | string)?" == toString(requireTypeAtPosition({4, 20}))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "prototyped_recursive_functions") +{ + CheckResult result = check(R"( + local f + function f() + if math.random() > 0.5 then + f() + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(() -> ())?" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyped_recursive_functions_but_has_future_assignments") +{ + // early return if the flag isn't set since this is blocking gated commits + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local f + function f() + if math.random() > 0.5 then + f() + end + end + f = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK("((() -> ()) | number)?" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "prototyped_recursive_functions_but_has_previous_assignments") +{ + CheckResult result = check(R"( + local f + f = 5 + function f() + if math.random() > 0.5 then + f() + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("((() -> ()) | number)?" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "multiple_assignments_in_loops") +{ + CheckResult result = check(R"( + local x = nil + + for i = 1, 10 do + x = 5 + x = "hello" + end + + print(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(number | string)?" == toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "typestates_preserve_error_suppression") +{ + CheckResult result = check(R"( + local a: any = 51 + a = "pickles" -- We'll have a new DefId for this iteration of `a`. Its type must also be error-suppressing + print(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("*error-type* | string" == toString(requireTypeAtPosition({3, 14}), {true})); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "typestates_preserve_error_suppression_properties") +{ + // early return if the flag isn't set since this is blocking gated commits + // unconditional return + // CLI-117098 Type states with error suppressing properties doesn't infer the correct type for properties. + if (!FFlag::LuauSolverV2 || FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + local a: {x: any} = {x = 51} + a.x = "pickles" -- We'll have a new DefId for this iteration of `a.x`. Its type must also be error-suppressing + print(a.x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("*error-type* | string" == toString(requireTypeAtPosition({3, 16}), {true})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typestates_do_not_apply_to_the_initial_local_definition") +{ + // early return if the flag isn't set since this is blocking gated commits + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type MyType = number | string + local foo: MyType = 5 + print(foo) + foo = 7 + print(foo) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number | string" == toString(requireTypeAtPosition({3, 14}), {true})); + CHECK("number" == toString(requireTypeAtPosition({5, 14}), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "typestate_globals") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + loadDefinition(R"( + declare foo: string | number + declare function f(x: string): () + )"); + + CheckResult result = check(R"( + foo = "a" + f(foo) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "typestate_unknown_global") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + x = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 6f69d6827..0303f546f 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -8,10 +8,36 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauAcceptIndexingTableUnionsIntersections) + TEST_SUITE_BEGIN("UnionTypes"); +TEST_CASE_FIXTURE(Fixture, "fuzzer_union_with_one_part_assertion") +{ + CheckResult result = check(R"( +local _ = {},nil +repeat + +_,_ = if _.number == "" or _.number or _._ then + _ + elseif _.__index == _._G then + tostring + elseif _ then + _ + else + ``,_._G + +until _._ + )"); +} + TEST_CASE_FIXTURE(Fixture, "return_types_can_be_disjoint") { + // CLI-114134 We need egraphs to consistently reduce the cyclic union + // introduced by the increment here. + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local count = 0 function most_of_the_natural_numbers(): number? @@ -30,45 +56,56 @@ TEST_CASE_FIXTURE(Fixture, "return_types_can_be_disjoint") REQUIRE(utv != nullptr); } -TEST_CASE_FIXTURE(Fixture, "allow_specific_assign") +TEST_CASE_FIXTURE(Fixture, "return_types_can_be_disjoint_using_compound_assignment") { CheckResult result = check(R"( - local a:number|string = 22 + local count = 0 + function most_of_the_natural_numbers(): number? + if count < 10 then + -- count = count + 1 + count += 1 + return count + else + return nil + end + end )"); LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionType* utv = get(requireType("most_of_the_natural_numbers")); + REQUIRE(utv != nullptr); } -TEST_CASE_FIXTURE(Fixture, "allow_more_specific_assign") +TEST_CASE_FIXTURE(Fixture, "allow_specific_assign") { CheckResult result = check(R"( local a:number|string = 22 - local b:number|string|nil = a )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign") +TEST_CASE_FIXTURE(Fixture, "allow_more_specific_assign") { CheckResult result = check(R"( - local a:number = 10 - local b:number|string = 20 - a = b + function f(a: number | string, b: (number | string)?) + b = a + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign2") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign") { CheckResult result = check(R"( - local a:number? = 10 - local b:number|string? = 20 - a = b + function f(a: number, b: number | string) + a = b + end )"); - REQUIRE_EQ(1, result.errors.size()); + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(Fixture, "optional_arguments") @@ -84,6 +121,9 @@ TEST_CASE_FIXTURE(Fixture, "optional_arguments") TEST_CASE_FIXTURE(Fixture, "optional_arguments_table") { + // CLI-115588 - Bidirectional inference does not happen for assignments + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local a:{a:string, b:string?} a = {a="ok"} @@ -125,13 +165,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_property_guaranteed_to_ex CheckResult result = check(R"( type A = {x: number} type B = {x: number} - local t: A | B - local r = t.x + function f(t: A | B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("r")); + CHECK_EQ("(A | B) -> number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_mixed_types") @@ -139,13 +180,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_mixed_types") CheckResult result = check(R"( type A = {x: number} type B = {x: string} - local t: A | B - local r = t.x + function f(t: A | B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireType("r"))); + CHECK_EQ("(A | B) -> number | string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_works_at_arbitrary_depth") @@ -153,13 +195,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_works_at_arbitrary_depth") CheckResult result = check(R"( type A = {x: {y: {z: {thing: number}}}} type B = {x: {y: {z: {thing: string}}}} - local t: A | B - local r = t.x.y.z.thing + function f(t: A | B) + return t.x.y.z.thing + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireType("r"))); + CHECK_EQ("(A | B) -> number | string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") @@ -167,13 +210,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") CheckResult result = check(R"( type A = {x: number} type B = {x: number?} - local t: A | B - local r = t.x + function f(t: A | B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number?", toString(requireType("r"))); + CHECK_EQ("(A | B) -> number?", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") @@ -181,23 +225,22 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CheckResult result = check(R"( type A = {x: number} type B = {} - local t: A | B - local r = t.x + function f(t: A | B) + return t.x + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); MissingUnionProperty* mup = get(result.errors[0]); REQUIRE(mup); - CHECK_EQ(mup->type, requireType("t")); - REQUIRE(mup->missing.size() == 1); - std::optional bTy = lookupType("B"); - REQUIRE(bTy); - CHECK_EQ(mup->missing[0], *bTy); - CHECK_EQ(mup->key, "x"); + CHECK_EQ("Key 'x' is missing from 'B' in the type 'A | B'", toString(result.errors[0])); - CHECK_EQ("*error-type*", toString(requireType("r"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("(A | B) -> number", toString(requireType("f"))); + else + CHECK_EQ("(A | B) -> *error-type*", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") @@ -205,13 +248,14 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any" CheckResult result = check(R"( type A = {x: number} type B = {x: any} - local t: A | B - local r = t.x + function f(t: A | B) + return t.x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("r")); + CHECK_EQ("(A | B) -> any", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") @@ -221,14 +265,13 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") type B = number | nil type C = number | boolean - local a: A = 1 - local b: B = nil - local c: C = true - local n = 1 + function f(a: A, b: B, c: C) + local n = 1 - local x = a == b - local y = a == n - local z = a == c + local x = a == b + local y = a == n + local z = a == c + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -237,16 +280,17 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") TEST_CASE_FIXTURE(Fixture, "optional_union_members") { CheckResult result = check(R"( -local a = { a = { x = 1, y = 2 }, b = 3 } -type A = typeof(a) -local b: A? = a -local bf = b -local c = bf.a.y + local a = { a = { x = 1, y = 2 }, b = 3 } + type A = typeof(a) + function f(b: A?) + return b.a.y + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); + CHECK_EQ("(A?) -> number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "optional_union_functions") @@ -255,37 +299,41 @@ TEST_CASE_FIXTURE(Fixture, "optional_union_functions") local a = {} function a.foo(x:number, y:number) return x + y end type A = typeof(a) - local b: A? = a - local c = b.foo(1, 2) + function f(b: A?) + return b.foo(1, 2) + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); + CHECK_EQ("(A?) -> number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "optional_union_methods") { CheckResult result = check(R"( -local a = {} -function a:foo(x:number, y:number) return x + y end -type A = typeof(a) -local b: A? = a -local c = b:foo(1, 2) + local a = {} + function a:foo(x:number, y:number) return x + y end + type A = typeof(a) + function f(b: A?) + return b:foo(1, 2) + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); + CHECK_EQ("(A?) -> number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "optional_union_follow") { CheckResult result = check(R"( -local y: number? = 2 -local x = y -local function f(a: number, b: typeof(x), c: typeof(x)) return -a end -return f() + local y: number? = 2 + local x = y + function f(a: number, b: number?, c: number?) return -a end + return f() )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -300,10 +348,11 @@ return f() TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") { CheckResult result = check(R"( -type A = { x: number } -local b: A? = { x = 2 } -local c = b.x -local d = b.y + type A = { x: number } + function f(b: A?) + local c = b.x + local d = b.y + end )"); LUAU_REQUIRE_ERROR_COUNT(3, result); @@ -315,9 +364,10 @@ local d = b.y TEST_CASE_FIXTURE(Fixture, "optional_index_error") { CheckResult result = check(R"( -type A = {number} -local a: A? = {1, 2, 3} -local b = a[1] + type A = {number} + function f(a: A?) + local b = a[1] + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -327,9 +377,10 @@ local b = a[1] TEST_CASE_FIXTURE(Fixture, "optional_call_error") { CheckResult result = check(R"( -type A = (number) -> number -local a: A? = function(a) return -a end -local b = a(4) + type A = (number) -> number + function f(a: A?) + local b = a(4) + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -339,60 +390,75 @@ local b = a(4) TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors") { CheckResult result = check(R"( -type A = { x: number } -local a: A? = { x = 2 } -a.x = 2 + type A = { x: number } + function f(a: A?) + a.x = 2 + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); +} - result = check(R"( -type A = { x: number } & { y: number } -local a: A? = { x = 2, y = 3 } -a.x = 2 +TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors_2") +{ + CheckResult result = check(R"( + type A = { x: number } & { y: number } + function f(a: A?) + a.x = 2 + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); + auto s = toString(result.errors[0]); + if (FFlag::LuauSolverV2) + CHECK_EQ("Value of type '({ x: number } & { y: number })?' could be nil", s); + else + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", s); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") { + + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + CheckResult result = check(R"( -type A = {number} -local a: A? = {1, 2, 3} -local b = #a + type A = {number} + function f(a: A?) + local b = #a + end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); + // CLI-119936: This shouldn't double error but does under the new solver. + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Operator '#' could not be applied to operand of type A?; there is no corresponding overload for __len", toString(result.errors[0])); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[1])); } TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") { CheckResult result = check(R"( -type A = { x: number, y: number } -type B = { x: number, y: number } -type C = { x: number } -type D = { x: number } - -local a: A|B|C|D -local b = a.y + type A = { x: number, y: number } + type B = { x: number, y: number } + type C = { x: number } + type D = { x: number } -local c: A|(B|C)?|D -local d = c.y + function f(a: A | B | C | D) + local y = a.y + local z = a.z + end -local e = a.z + function g(c: A | B | C | D | nil) + local d = c.y + end )"); LUAU_REQUIRE_ERROR_COUNT(4, result); CHECK_EQ("Key 'y' is missing from 'C', 'D' in the type 'A | B | C | D'", toString(result.errors[0])); + CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[1])); - CHECK_EQ("Value of type '(A | B | C | D)?' could be nil", toString(result.errors[1])); - CHECK_EQ("Key 'y' is missing from 'C', 'D' in the type 'A | B | C | D'", toString(result.errors[2])); - - CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[3])); + CHECK_EQ("Value of type '(A | B | C | D)?' could be nil", toString(result.errors[2])); + CHECK_EQ("Key 'y' is missing from 'C', 'D' in the type 'A | B | C | D'", toString(result.errors[3])); } TEST_CASE_FIXTURE(Fixture, "optional_iteration") @@ -412,6 +478,8 @@ end TEST_CASE_FIXTURE(Fixture, "unify_unsealed_table_union_check") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( local x = { x = 3 } type A = number? @@ -464,30 +532,49 @@ type Z = { z: number } type XYZ = X | Y | Z -local a: XYZ -local b: { w: number } = a +function f(a: XYZ) + local b: { w: number } = a +end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' + + if (FFlag::LuauSolverV2) + { + CHECK_EQ( + toString(result.errors[0]), + "Type 'X | Y | Z' could not be converted into '{ w: number }'; type X | Y | Z[0] (X) is not a subtype " + "of { w: number } ({ w: number })\n\t" + "type X | Y | Z[1] (Y) is not a subtype of { w: number } ({ w: number })\n\t" + "type X | Y | Z[2] (Z) is not a subtype of { w: number } ({ w: number })" + ); + } + else + { + CHECK_EQ(toString(result.errors[0]), R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' caused by: - Not all union options are compatible. Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"); + Not all union options are compatible. +Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"); + } } TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") { CheckResult result = check(R"( -type X = { x: number } -type Y = { y: number } -type Z = { z: number } + type X = { x: number } + type Y = { y: number } + type Z = { z: number } -type XYZ = X | Y | Z + type XYZ = X | Y | Z -local a: XYZ = { w = 4 } + local a: XYZ = { w = 4 } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); + if (FFlag::LuauSolverV2) + CHECK(toString(result.errors[0]) == "Type '{ w: number }' could not be converted into 'X | Y | Z'"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); } TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") @@ -499,9 +586,16 @@ local a: X? = { w = 4 } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X?' + if (FFlag::LuauSolverV2) + CHECK("Type '{ w: number }' could not be converted into 'X?'" == toString(result.errors[0])); + else + { + const std::string expected = R"(Type 'a' could not be converted into 'X?' caused by: - None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); + None of the union options are compatible. For example: +Table type 'a' not compatible with type 'X' because the former is missing field 'x')"; + CHECK_EQ(expected, toString(result.errors[0])); + } } // We had a bug where a cyclic union caused a stack overflow. @@ -520,37 +614,80 @@ TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash") +{ + // It shouldn't be possible to craft a cyclic union, but even if we do, we + // shouldn't blow up. + + TypeArena& arena = frontend.globals.globalTypes; + unfreeze(arena); + + TypeId badCyclicUnionTy = arena.freshType(frontend.globals.globalScope.get()); + UnionType u; + + u.options.push_back(badCyclicUnionTy); + u.options.push_back(arena.addType(TableType{ + {}, TableIndexer{builtinTypes->numberType, builtinTypes->numberType}, TypeLevel{}, frontend.globals.globalScope.get(), TableState::Sealed + })); + + asMutable(badCyclicUnionTy)->ty.emplace(std::move(u)); + + frontend.globals.globalScope->exportedTypeBindings["BadCyclicUnion"] = TypeFun{{}, badCyclicUnionTy}; + + freeze(arena); + + CheckResult result = check(R"( + function f(x: BadCyclicUnion) + return x[0] + end + )"); + + // this is a cyclic union of number arrays, so it _is_ a table, even if it's a nonsense type. + // no need to generate a NotATable error here. The new solver automatically handles this and + // correctly reports no errors. + if (FFlag::LuauAcceptIndexingTableUnionsIntersections || FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } - local a:A = nil - - function a.y(x) - return tostring(x * 2) - end + function f(a: A) + function a.y(x) + return tostring(x * 2) + end - function a.y(x: string): number - return tonumber(x) or 0 + function a.y(x: string): number + return tonumber(x) or 0 + end end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); // NOTE: union normalization will improve this message - CHECK_EQ(toString(result.errors[0]), - R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); + const std::string expected = R"(Type + '(string) -> number' +could not be converted into + '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_true_and_false") { CheckResult result = check(R"( - local x : boolean - local y1 : (true | false) = x -- OK - local y2 : (true | false | (string & number)) = x -- OK - local y3 : (true | (string & number) | false) = x -- OK - local y4 : (true | (boolean & true) | false) = x -- OK - )"); + function f(x : boolean) + local y1 : (true | false) = x -- OK + local y2 : (true | false | (string & number)) = x -- OK + local y3 : (true | (string & number) | false) = x -- OK + local y4 : (true | (boolean & true) | false) = x -- OK + end + )"); LUAU_REQUIRE_NO_ERRORS(result); } @@ -558,8 +695,9 @@ TEST_CASE_FIXTURE(Fixture, "union_true_and_false") TEST_CASE_FIXTURE(Fixture, "union_of_functions") { CheckResult result = check(R"( - local x : (number) -> number? - local y : ((number?) -> number?) | ((number) -> number) = x -- OK + function f(x : (number) -> number?) + local y : ((number?) -> number?) | ((number) -> number) = x -- OK + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -568,8 +706,9 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions") TEST_CASE_FIXTURE(Fixture, "union_of_generic_functions") { CheckResult result = check(R"( - local x : (a) -> a? - local y : ((a?) -> a?) | ((b) -> b) = x -- Not OK + function f(x : (a) -> a?) + local y : ((a?) -> a?) | ((b) -> b) = x -- Not OK + end )"); // TODO: should this example typecheck? @@ -579,8 +718,9 @@ TEST_CASE_FIXTURE(Fixture, "union_of_generic_functions") TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") { CheckResult result = check(R"( - local x : (number, a...) -> (number?, a...) - local y : ((number?, a...) -> (number?, a...)) | ((number, b...) -> (number, b...)) = x -- Not OK + function f(x : (number, a...) -> (number?, a...)) + local y : ((number?, a...) -> (number?, a...)) | ((number, b...) -> (number, b...)) = x -- Not OK + end )"); // TODO: should this example typecheck? @@ -589,102 +729,151 @@ TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - function f() - local x : (a) -> a? - local y : ((a?) -> nil) | ((a) -> a) = x -- OK - local z : ((b?) -> nil) | ((b) -> b) = x -- Not OK - end - )"); + function f() + function g(x : (a) -> a?) + local y : ((a?) -> nil) | ((a) -> a) = x -- OK + local z : ((b?) -> nil) | ((b) -> b) = x -- Not OK + end + end + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '(a) -> a?' could not be converted into '((b) -> b) | ((b?) -> nil)'; none of the union options are compatible"); + CHECK_EQ( + toString(result.errors[0]), + "Type '(a) -> a?' could not be converted into '((b) -> b) | ((b?) -> nil)'; none of the union options are compatible" + ); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - function f() - local x : (number, a...) -> (number?, a...) - local y : ((number | string, a...) -> (number, a...)) | ((number?, a...) -> (nil, a...)) = x -- OK - local z : ((number) -> number) | ((number?, a...) -> (number?, a...)) = x -- Not OK - end - )"); + function f() + function g(x : (number, a...) -> (number?, a...)) + local y : ((number | string, a...) -> (number, a...)) | ((number?, a...) -> (nil, a...)) = x -- OK + local z : ((number) -> number) | ((number?, a...) -> (number?, a...)) = x -- Not OK + end + end + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(number, a...) -> (number?, a...)' could not be converted into '((number) -> number) | ((number?, " - "a...) -> (number?, a...))'; none of the union options are compatible"); + const std::string expected = R"(Type + '(number, a...) -> (number?, a...)' +could not be converted into + '((number) -> number) | ((number?, a...) -> (number?, a...))'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - local x : (number) -> number? - local y : ((number?) -> number) | ((number | string) -> nil) = x -- OK - local z : ((number, string?) -> number) | ((number) -> nil) = x -- Not OK + function f(x : (number) -> number?) + local y : ((number?) -> number) | ((number | string) -> nil) = x -- OK + local z : ((number, string?) -> number) | ((number) -> nil) = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(number) -> number?' could not be converted into '((number) -> nil) | ((number, string?) -> " - "number)'; none of the union options are compatible"); + const std::string expected = R"(Type + '(number) -> number?' +could not be converted into + '((number) -> nil) | ((number, string?) -> number)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - local x : () -> (number | string) - local y : (() -> number) | (() -> string) = x -- OK - local z : (() -> number) | (() -> (string, string)) = x -- Not OK + function f(x : () -> (number | string)) + local y : (() -> number) | (() -> string) = x -- OK + local z : (() -> number) | (() -> (string, string)) = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '() -> number | string' could not be converted into '(() -> (string, string)) | (() -> number)'; none " - "of the union options are compatible"); + const std::string expected = R"(Type + '() -> number | string' +could not be converted into + '(() -> (string, string)) | (() -> number)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - local x : (...nil) -> (...number?) - local y : ((...string?) -> (...number)) | ((...number?) -> nil) = x -- OK - local z : ((...string?) -> (...number)) | ((...string?) -> nil) = x -- OK + function f(x : (...nil) -> (...number?)) + local y : ((...string?) -> (...number)) | ((...number?) -> nil) = x -- OK + local z : ((...string?) -> (...number)) | ((...string?) -> nil) = x -- OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(...nil) -> (...number?)' could not be converted into '((...string?) -> (...number)) | ((...string?) " - "-> nil)'; none of the union options are compatible"); + const std::string expected = R"(Type + '(...nil) -> (...number?)' +could not be converted into + '((...string?) -> (...number)) | ((...string?) -> nil)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") { CheckResult result = check(R"( - local x : (number) -> () - local y : ((number?) -> ()) | ((...number) -> ()) = x -- OK - local z : ((number?) -> ()) | ((...number?) -> ()) = x -- Not OK + function f(x : (number) -> ()) + local y : ((number?) -> ()) | ((...number) -> ()) = x -- OK + local z : ((number?) -> ()) | ((...number?) -> ()) = x -- Not OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '(number) -> ()' could not be converted into '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible"); + if (FFlag::LuauSolverV2) + { + CHECK(R"(Type + '(number) -> ()' +could not be converted into + '((...number?) -> ()) | ((number?) -> ())')" == toString(result.errors[0])); + } + else + { + const std::string expected = R"(Type + '(number) -> ()' +could not be converted into + '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + CheckResult result = check(R"( - local x : () -> (number?, ...number) - local y : (() -> (...number)) | (() -> nil) = x -- OK - local z : (() -> (...number)) | (() -> number) = x -- OK + function f(x : () -> (number?, ...number)) + local y : (() -> (...number)) | (() -> nil) = x -- OK + local z : (() -> (...number)) | (() -> number) = x -- OK + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '() -> (number?, ...number)' could not be converted into '(() -> (...number)) | (() -> number)'; none " - "of the union options are compatible"); + const std::string expected = R"(Type + '() -> (number?, ...number)' +could not be converted into + '(() -> (...number)) | (() -> number)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + if (!FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -696,12 +885,12 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("({| x: number |} | {| x: string |}) -> {| x: number |} | {| x: string |}", toString(requireType("f"))); + CHECK_EQ("(({ read x: unknown } & { x: number }) | ({ read x: unknown } & { x: string })) -> { x: number } | { x: string }", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") { - if (!FFlag::DebugLuauDeferredConstraintResolution) + if (!FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -712,7 +901,96 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("({| x: number |} | {| x: string |}) -> number | string", toString(requireType("f"))); + CHECK_EQ("({ x: number } | { x: string }) -> number | string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "union_table_any_property") +{ + CheckResult result = check(R"( + function f(x) + -- x : X + -- sup : { p : { q : X } }? + local sup = if true then { p = { q = x } } else nil + local sub : { p : any } + sup = nil + sup = sub + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_function_any_args") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + function f(sup : ((...any) -> (...any))?, sub : ((number) -> (...any))) + sup = sub + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_any") +{ + CheckResult result = check(R"( + function f(sup : any?, sub : number) + sup = sub + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_function_with_optional_arg") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + function f(x : T?) : {T} + local result = {} + if x then + result[1] = x + end + return result + end + local t : {string} = f(nil) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions") +{ + CheckResult result = check(R"( + local function mergeOptions(options: T & ({} | {})) + return options.variables + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const UnknownProperty* unknownProp = get(result.errors[0]); + REQUIRE(unknownProp); + + CHECK("variables" == unknownProp->key); +} + +TEST_CASE_FIXTURE(Fixture, "suppress_errors_for_prop_lookup_of_a_union_that_includes_error") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + registerHiddenTypes(&frontend); + + CheckResult result = check(R"( + function f(a: err | Not) + local b = a.foo + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index f17ada20e..0c62d0b6a 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -6,6 +6,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSolverV2); + TEST_SUITE_BEGIN("TypeInferUnknownNever"); TEST_CASE_FIXTURE(Fixture, "string_subtype_and_unknown_supertype") @@ -116,14 +118,14 @@ TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable" local x, y, z = f() )"); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Function only returns 2 values, but 3 are required here", toString(result.errors[0])); + CHECK("Function only returns 2 values, but 3 are required here" == toString(result.errors[0])); - CHECK_EQ("string", toString(requireType("x"))); - CHECK_EQ("never", toString(requireType("y"))); - CHECK_EQ("*error-type*", toString(requireType("z"))); + CHECK("string" == toString(requireType("x"))); + CHECK("never" == toString(requireType("y"))); + CHECK("nil" == toString(requireType("z"))); } else { @@ -147,7 +149,7 @@ TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable2 LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (FFlag::LuauSolverV2) { CHECK_EQ("string", toString(requireType("x1"))); CHECK_EQ("never", toString(requireType("x2"))); @@ -191,12 +193,20 @@ TEST_CASE_FIXTURE(Fixture, "call_never") TEST_CASE_FIXTURE(Fixture, "assign_to_local_which_is_never") { + // CLI-117119 - What do we do about assigning to never? CheckResult result = check(R"( local t: never t = 3 )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "assign_to_global_which_is_never") @@ -213,8 +223,9 @@ TEST_CASE_FIXTURE(Fixture, "assign_to_global_which_is_never") TEST_CASE_FIXTURE(Fixture, "assign_to_prop_which_is_never") { CheckResult result = check(R"( - local t: never - t.x = 5 + local function f(t: never) + t.x = 5 + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -223,8 +234,9 @@ TEST_CASE_FIXTURE(Fixture, "assign_to_prop_which_is_never") TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") { CheckResult result = check(R"( - local t: never - t[5] = 7 + local function f(t: never) + t[5] = 7 + end )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -253,10 +265,17 @@ TEST_CASE_FIXTURE(Fixture, "pick_never_from_variadic_type_pack") TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_never") { + // CLI-117116 - We are erroneously warning when passing a valid table literal where we expect a union of tables. + if (FFlag::LuauSolverV2) + return; CheckResult result = check(R"( type Disjoint = {foo: never, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} - local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} - local foo = disjoint.foo + + function f(disjoint: Disjoint) + return disjoint.foo + end + + local foo = f({foo = 5 :: never, bar = true, tag = "ok"}) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -266,10 +285,17 @@ TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_neve TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_sorta_never") { + // CLI-117116 - We are erroneously warning when passing a valid table literal where we expect a union of tables. + if (FFlag::LuauSolverV2) + return; CheckResult result = check(R"( type Disjoint = {foo: string, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} - local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} - local foo = disjoint.foo + + function f(disjoint: Disjoint) + return disjoint.foo + end + + local foo = f({foo = 5 :: never, bar = true, tag = "ok"}) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -301,10 +327,6 @@ TEST_CASE_FIXTURE(Fixture, "length_of_never") TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators") { - ScopedFastFlag sff[]{ - {"LuauTryhardAnd", true}, - }; - CheckResult result = check(R"( local function ord(x: nil, y) return x ~= nil and x > y @@ -313,13 +335,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); + if (FFlag::LuauSolverV2) + CHECK_EQ("(nil, unknown) -> boolean", toString(requireType("ord"))); else - { - // Widening doesn't normalize yet, so the result is a bit strange - CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); - } + CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); } TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") @@ -330,8 +349,88 @@ TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") end )"); + if (FFlag::LuauSolverV2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + + // CLI-114134 Egraph-based simplification. + // CLI-116549 x ~= nil : false when x : nil + CHECK("(nil, a) -> and>" == toString(requireType("mul"))); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil, a) -> boolean", toString(requireType("mul"))); + } +} + +TEST_CASE_FIXTURE(Fixture, "compare_never") +{ + CheckResult result = check(R"( + local function cmp(x: nil, y: number) + return x ~= nil and x > y and x < y -- infers boolean | never, which is normalized into boolean + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil, number) -> boolean", toString(requireType("cmp"))); +} + +TEST_CASE_FIXTURE(Fixture, "lti_error_at_declaration_for_never_normalizations") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local function num(x: number) end + local function str(x: string) end + local function cond(): boolean return false end + + local function f(a) + if cond() then + num(a) + else + str(a) + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK(toString(result.errors[0]) == "Parameter 'a' has been reduced to never. This function is not callable with any possible value."); + CHECK(toString(result.errors[1]) == "Parameter 'a' is required to be a subtype of 'number' here."); + CHECK(toString(result.errors[2]) == "Parameter 'a' is required to be a subtype of 'string' here."); +} + +TEST_CASE_FIXTURE(Fixture, "lti_permit_explicit_never_annotation") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local function num(x: number) end + local function str(x: string) end + local function cond(): boolean return false end + + local function f(a: never) + if cond() then + num(a) + else + str(a) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cast_from_never_does_not_error") +{ + CheckResult result = check(R"( + local function f(x: never): number + return x :: number + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(nil, a) -> boolean", toString(requireType("mul"))); } TEST_SUITE_END(); diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 20404434a..7d8ed38f7 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -25,7 +25,7 @@ struct TypePackFixture TypePackId freshTypePack() { - typePacks.emplace_back(new TypePackVar{Unifiable::Free{TypeLevel{}}}); + typePacks.emplace_back(new TypePackVar{FreeTypePack{TypeLevel{}}}); return typePacks.back().get(); } diff --git a/tests/TypePath.test.cpp b/tests/TypePath.test.cpp new file mode 100644 index 000000000..be6e84faf --- /dev/null +++ b/tests/TypePath.test.cpp @@ -0,0 +1,620 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypePath.h" + +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePack.h" + +#include "ClassFixture.h" +#include "doctest.h" +#include "Fixture.h" +#include "ScopedFlags.h" + +#include + +using namespace Luau; +using namespace Luau::TypePath; + +LUAU_FASTFLAG(LuauSolverV2); +LUAU_DYNAMIC_FASTINT(LuauTypePathMaximumTraverseSteps); + +struct TypePathFixture : Fixture +{ + ScopedFastFlag sff1{FFlag::LuauSolverV2, true}; +}; + +struct TypePathBuiltinsFixture : BuiltinsFixture +{ + ScopedFastFlag sff1{FFlag::LuauSolverV2, true}; +}; + +TEST_SUITE_BEGIN("TypePathManipulation"); + +TEST_CASE("append") +{ + SUBCASE("empty_paths") + { + Path p; + CHECK(p.append(Path{}).empty()); + } + + SUBCASE("empty_path_with_path") + { + Path p1; + Path p2(TypeField::Metatable); + + Path result = p1.append(p2); + CHECK(result == Path(TypeField::Metatable)); + } + + SUBCASE("two_paths") + { + Path p1(TypeField::IndexLookup); + Path p2(TypeField::Metatable); + + Path result = p1.append(p2); + CHECK(result == Path({TypeField::IndexLookup, TypeField::Metatable})); + } + + SUBCASE("all_components") + { + Path p1({TypeField::IndexLookup, TypeField::Metatable}); + Path p2({TypeField::Metatable, PackField::Arguments}); + + Path result = p1.append(p2); + CHECK(result == Path({TypeField::IndexLookup, TypeField::Metatable, TypeField::Metatable, PackField::Arguments})); + } + + SUBCASE("does_not_mutate") + { + Path p1(TypeField::IndexLookup); + Path p2(TypeField::Metatable); + + p1.append(p2); + CHECK(p1 == Path(TypeField::IndexLookup)); + CHECK(p2 == Path(TypeField::Metatable)); + } +} + +TEST_CASE("push") +{ + Path p; + Path result = p.push(TypeField::Metatable); + + CHECK(p.empty()); + CHECK(result == Path(TypeField::Metatable)); +} + +TEST_CASE("pop") +{ + SUBCASE("empty_path") + { + Path p; + CHECK(p.empty()); + CHECK(p.pop().empty()); + } +} + +TEST_SUITE_END(); // TypePathManipulation + +TEST_SUITE_BEGIN("TypePathTraversal"); + +#define TYPESOLVE_CODE(code) \ + do \ + { \ + CheckResult result = check(code); \ + LUAU_REQUIRE_NO_ERRORS(result); \ + } while (false); + +TEST_CASE_FIXTURE(TypePathFixture, "empty_traversal") +{ + CHECK(traverseForType(builtinTypes->numberType, kEmpty, builtinTypes) == builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(TypePathFixture, "table_property") +{ + TYPESOLVE_CODE(R"( + local x = { y = 123 } + )"); + + CHECK(traverseForType(requireType("x"), Path(TypePath::Property{"y", true}), builtinTypes) == builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(ClassFixture, "class_property") +{ + CHECK(traverseForType(vector2InstanceType, Path(TypePath::Property{"X", true}), builtinTypes) == builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(TypePathBuiltinsFixture, "metatable_property") +{ + SUBCASE("meta_does_not_contribute") + { + TYPESOLVE_CODE(R"( + local x = setmetatable({ x = 123 }, {}) + )"); + } + + SUBCASE("meta_and_table_supply_property") + { + // since the table takes priority, the __index property won't matter + TYPESOLVE_CODE(R"( + local x = setmetatable({ x = 123 }, { __index = { x = 'foo' } }) + )"); + } + + SUBCASE("only_meta_supplies_property") + { + TYPESOLVE_CODE(R"( + local x = setmetatable({}, { __index = { x = 123 } }) + )"); + } + + CHECK(traverseForType(requireType("x"), Path(TypePath::Property::read("x")), builtinTypes) == builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(TypePathFixture, "index") +{ + SUBCASE("unions") + { + TYPESOLVE_CODE(R"( + type T = number | string | boolean + )"); + + SUBCASE("in_bounds") + { + CHECK(traverseForType(requireTypeAlias("T"), Path(TypePath::Index{1}), builtinTypes) == builtinTypes->stringType); + } + + SUBCASE("out_of_bounds") + { + CHECK(traverseForType(requireTypeAlias("T"), Path(TypePath::Index{97}), builtinTypes) == std::nullopt); + } + } + + SUBCASE("intersections") + { + // use functions to avoid the intersection being normalized away + TYPESOLVE_CODE(R"( + type T = (() -> ()) & ((true) -> false) & ((false) -> true) + )"); + + SUBCASE("in_bounds") + { + auto result = traverseForType(requireTypeAlias("T"), Path(TypePath::Index{1}), builtinTypes); + CHECK(result); + + if (result) + CHECK(toString(*result) == "(true) -> false"); + } + + SUBCASE("out_of_bounds") + { + CHECK(traverseForType(requireTypeAlias("T"), Path(TypePath::Index{97}), builtinTypes) == std::nullopt); + } + } + + SUBCASE("type_packs") + { + // use functions to avoid the intersection being normalized away + TYPESOLVE_CODE(R"( + type T = (number, string, true, false) -> () + )"); + + SUBCASE("in_bounds") + { + Path path = Path({TypePath::PackField::Arguments, TypePath::Index{1}}); + auto result = traverseForType(requireTypeAlias("T"), path, builtinTypes); + CHECK(result == builtinTypes->stringType); + } + + SUBCASE("out_of_bounds") + { + Path path = Path({TypePath::PackField::Arguments, TypePath::Index{72}}); + auto result = traverseForType(requireTypeAlias("T"), path, builtinTypes); + CHECK(result == std::nullopt); + } + } +} + +TEST_CASE_FIXTURE(ClassFixture, "metatables") +{ + SUBCASE("string") + { + auto result = traverseForType(builtinTypes->stringType, Path(TypeField::Metatable), builtinTypes); + CHECK(result == getMetatable(builtinTypes->stringType, builtinTypes)); + } + + SUBCASE("string_singleton") + { + TYPESOLVE_CODE(R"( + type T = "foo" + )"); + + auto result = traverseForType(requireTypeAlias("T"), Path(TypeField::Metatable), builtinTypes); + CHECK(result == getMetatable(builtinTypes->stringType, builtinTypes)); + } + + SUBCASE("table") + { + TYPESOLVE_CODE(R"( + type Table = { foo: number } + type Metatable = { bar: number } + local tbl: Table = { foo = 123 } + local mt: Metatable = { bar = 456 } + local res = setmetatable(tbl, mt) + )"); + + // Tricky test setup because 'setmetatable' mutates the argument 'tbl' type + auto result = traverseForType(requireType("res"), Path(TypeField::Table), builtinTypes); + auto expected = lookupType("Table"); + REQUIRE(expected); + CHECK(result == follow(*expected)); + } + + SUBCASE("metatable") + { + TYPESOLVE_CODE(R"( + local mt = { foo = 123 } + local tbl = setmetatable({}, mt) + )"); + + auto result = traverseForType(requireType("tbl"), Path(TypeField::Metatable), builtinTypes); + CHECK(result == requireType("mt")); + } + + SUBCASE("class") + { + auto result = traverseForType(vector2InstanceType, Path(TypeField::Metatable), builtinTypes); + // ClassFixture's Vector2 metatable is just an empty table, but it's there. + CHECK(result); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "bounds") +{ + SUBCASE("free_type") + { + TypeArena& arena = frontend.globals.globalTypes; + unfreeze(arena); + + TypeId ty = arena.freshType(frontend.globals.globalScope.get()); + FreeType* ft = getMutable(ty); + + SUBCASE("upper") + { + ft->upperBound = builtinTypes->numberType; + auto result = traverseForType(ty, Path(TypeField::UpperBound), builtinTypes); + CHECK(result == builtinTypes->numberType); + } + + SUBCASE("lower") + { + ft->lowerBound = builtinTypes->booleanType; + auto result = traverseForType(ty, Path(TypeField::LowerBound), builtinTypes); + CHECK(result == builtinTypes->booleanType); + } + } + + SUBCASE("unbounded_type") + { + CHECK(traverseForType(builtinTypes->numberType, Path(TypeField::UpperBound), builtinTypes) == std::nullopt); + CHECK(traverseForType(builtinTypes->numberType, Path(TypeField::LowerBound), builtinTypes) == std::nullopt); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "indexers") +{ + SUBCASE("table") + { + SUBCASE("lookup_indexer") + { + TYPESOLVE_CODE(R"( + type T = { [string]: boolean } + )"); + + auto lookupResult = traverseForType(requireTypeAlias("T"), Path(TypeField::IndexLookup), builtinTypes); + auto resultResult = traverseForType(requireTypeAlias("T"), Path(TypeField::IndexResult), builtinTypes); + + CHECK(lookupResult == builtinTypes->stringType); + CHECK(resultResult == builtinTypes->booleanType); + } + + SUBCASE("no_indexer") + { + TYPESOLVE_CODE(R"( + type T = { y: number } + )"); + + auto lookupResult = traverseForType(requireTypeAlias("T"), Path(TypeField::IndexLookup), builtinTypes); + auto resultResult = traverseForType(requireTypeAlias("T"), Path(TypeField::IndexResult), builtinTypes); + + CHECK(lookupResult == std::nullopt); + CHECK(resultResult == std::nullopt); + } + } + + // TODO: Class types +} + +TEST_CASE_FIXTURE(TypePathFixture, "negated") +{ + SUBCASE("valid") + { + TypeArena& arena = frontend.globals.globalTypes; + unfreeze(arena); + + TypeId ty = arena.addType(NegationType{builtinTypes->numberType}); + auto result = traverseForType(ty, Path(TypeField::Negated), builtinTypes); + CHECK(result == builtinTypes->numberType); + } + + SUBCASE("not_negation") + { + auto result = traverseForType(builtinTypes->numberType, Path(TypeField::Negated), builtinTypes); + CHECK(result == std::nullopt); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "variadic") +{ + SUBCASE("valid") + { + TypeArena& arena = frontend.globals.globalTypes; + unfreeze(arena); + + TypePackId tp = arena.addTypePack(VariadicTypePack{builtinTypes->numberType}); + auto result = traverseForType(tp, Path(TypeField::Variadic), builtinTypes); + CHECK(result == builtinTypes->numberType); + } + + SUBCASE("not_variadic") + { + auto result = traverseForType(builtinTypes->numberType, Path(TypeField::Variadic), builtinTypes); + CHECK(result == std::nullopt); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "arguments") +{ + SUBCASE("function") + { + TYPESOLVE_CODE(R"( + function f(x: number, y: string) + end + )"); + + auto result = traverseForPack(requireType("f"), Path(PackField::Arguments), builtinTypes); + CHECK(result); + if (result) + CHECK(toString(*result) == "number, string"); + } + + SUBCASE("not_function") + { + auto result = traverseForPack(builtinTypes->booleanType, Path(PackField::Arguments), builtinTypes); + CHECK(result == std::nullopt); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "returns") +{ + SUBCASE("function") + { + TYPESOLVE_CODE(R"( + function f(): (number, string) + return 123, "foo" + end + )"); + + auto result = traverseForPack(requireType("f"), Path(PackField::Returns), builtinTypes); + CHECK(result); + if (result) + CHECK(toString(*result) == "number, string"); + } + + SUBCASE("not_function") + { + auto result = traverseForPack(builtinTypes->booleanType, Path(PackField::Returns), builtinTypes); + CHECK(result == std::nullopt); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "tail") +{ + SUBCASE("has_tail") + { + TYPESOLVE_CODE(R"( + type T = (number, string, ...boolean) -> () + )"); + + auto result = traverseForPack(requireTypeAlias("T"), Path({PackField::Arguments, PackField::Tail}), builtinTypes); + CHECK(result); + if (result) + CHECK(toString(*result) == "...boolean"); + } + + SUBCASE("finite_pack") + { + TYPESOLVE_CODE(R"( + type T = (number, string) -> () + )"); + + auto result = traverseForPack(requireTypeAlias("T"), Path({PackField::Arguments, PackField::Tail}), builtinTypes); + CHECK(result == std::nullopt); + } + + SUBCASE("type") + { + auto result = traverseForPack(builtinTypes->stringType, Path({PackField::Arguments, PackField::Tail}), builtinTypes); + CHECK(result == std::nullopt); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "cycles" * doctest::timeout(0.5)) +{ + // This will fail an occurs check, but it's a quick example of a cyclic type + // where there _is_ no traversal. + SUBCASE("bound_cycle") + { + TypeArena& arena = frontend.globals.globalTypes; + unfreeze(arena); + + TypeId a = arena.addType(BlockedType{}); + TypeId b = arena.addType(BoundType{a}); + asMutable(a)->ty.emplace(b); + + CHECK_THROWS(traverseForType(a, Path(TypeField::IndexResult), builtinTypes)); + } + + SUBCASE("table_contains_itself") + { + TypeArena& arena = frontend.globals.globalTypes; + unfreeze(arena); + + TypeId tbl = arena.addType(TableType{}); + getMutable(tbl)->props["a"] = Luau::Property(tbl); + + auto result = traverseForType(tbl, Path(TypePath::Property{"a", true}), builtinTypes); + CHECK(result == tbl); + } +} + +TEST_CASE_FIXTURE(TypePathFixture, "step_limit") +{ + ScopedFastInt sfi(DFInt::LuauTypePathMaximumTraverseSteps, 2); + + TYPESOLVE_CODE(R"( + type T = { + x: { + y: { + z: number + } + } + } + )"); + + TypeId root = requireTypeAlias("T"); + Path path = PathBuilder().readProp("x").readProp("y").readProp("z").build(); + auto result = traverseForType(root, path, builtinTypes); + CHECK(!result); +} + +TEST_CASE_FIXTURE(TypePathBuiltinsFixture, "complex_chains") +{ + SUBCASE("add_metamethod_return_type") + { + TYPESOLVE_CODE(R"( + type Meta = { + __add: (Tab, Tab) -> number, + } + + type Tab = typeof(setmetatable({}, {} :: Meta)) + )"); + + TypeId root = requireTypeAlias("Tab"); + Path path = PathBuilder().mt().readProp("__add").rets().index(0).build(); + auto result = traverseForType(root, path, builtinTypes); + CHECK(result == builtinTypes->numberType); + } + + SUBCASE("overloaded_fn_overload_one_argument_two") + { + TYPESOLVE_CODE(R"( + type Obj = { + method: ((true, false) -> string) & ((string) -> number) + } + )"); + + TypeId root = requireTypeAlias("Obj"); + Path path = PathBuilder().readProp("method").index(0).args().index(1).build(); + auto result = traverseForType(root, path, builtinTypes); + CHECK(*result == builtinTypes->falseType); + } +} + +TEST_SUITE_END(); // TypePathTraversal + +TEST_SUITE_BEGIN("TypePathToString"); + +TEST_CASE("field") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + + CHECK(toString(PathBuilder().prop("foo").build()) == R"(["foo"])"); +} + +TEST_CASE("index") +{ + CHECK(toString(PathBuilder().index(0).build()) == "[0]"); +} + +TEST_CASE("chain") +{ + CHECK(toString(PathBuilder().index(0).mt().build()) == "[0].metatable()"); +} + +TEST_SUITE_END(); // TypePathToString + +TEST_SUITE_BEGIN("TypePathBuilder"); + +TEST_CASE("empty_path") +{ + Path p = PathBuilder().build(); + CHECK(p.empty()); +} + +TEST_CASE("prop") +{ + ScopedFastFlag sff[] = { + {FFlag::LuauSolverV2, false}, + }; + + Path p = PathBuilder().prop("foo").build(); + CHECK(p == Path(TypePath::Property{"foo"})); +} + +TEST_CASE_FIXTURE(TypePathFixture, "readProp") +{ + Path p = PathBuilder().readProp("foo").build(); + CHECK(p == Path(TypePath::Property::read("foo"))); +} + +TEST_CASE_FIXTURE(TypePathFixture, "writeProp") +{ + Path p = PathBuilder().writeProp("foo").build(); + CHECK(p == Path(TypePath::Property::write("foo"))); +} + +TEST_CASE("index") +{ + Path p = PathBuilder().index(0).build(); + CHECK(p == Path(TypePath::Index{0})); +} + +TEST_CASE("fields") +{ + CHECK(PathBuilder().mt().build() == Path(TypeField::Metatable)); + CHECK(PathBuilder().lb().build() == Path(TypeField::LowerBound)); + CHECK(PathBuilder().ub().build() == Path(TypeField::UpperBound)); + CHECK(PathBuilder().indexKey().build() == Path(TypeField::IndexLookup)); + CHECK(PathBuilder().indexValue().build() == Path(TypeField::IndexResult)); + CHECK(PathBuilder().negated().build() == Path(TypeField::Negated)); + CHECK(PathBuilder().variadic().build() == Path(TypeField::Variadic)); + CHECK(PathBuilder().args().build() == Path(PackField::Arguments)); + CHECK(PathBuilder().rets().build() == Path(PackField::Returns)); + CHECK(PathBuilder().tail().build() == Path(PackField::Tail)); +} + +TEST_CASE("chained") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CHECK( + PathBuilder().index(0).readProp("foo").mt().readProp("bar").args().index(1).build() == + Path({Index{0}, TypePath::Property::read("foo"), TypeField::Metatable, TypePath::Property::read("bar"), PackField::Arguments, Index{1}}) + ); +} + +TEST_SUITE_END(); // TypePathBuilder diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp deleted file mode 100644 index 582725b74..000000000 --- a/tests/TypeReduction.test.cpp +++ /dev/null @@ -1,1491 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeReduction.h" - -#include "Fixture.h" -#include "doctest.h" - -using namespace Luau; - -namespace -{ -struct ReductionFixture : Fixture -{ - TypeReductionOptions typeReductionOpts{/* allowTypeReductionsFromOtherArenas */ true}; - ToStringOptions toStringOpts{true}; - - TypeArena arena; - InternalErrorReporter iceHandler; - UnifierSharedState unifierState{&iceHandler}; - TypeReduction reduction{NotNull{&arena}, builtinTypes, NotNull{&iceHandler}, typeReductionOpts}; - - ReductionFixture() - { - registerHiddenTypes(&frontend); - createSomeClasses(&frontend); - } - - TypeId reductionof(TypeId ty) - { - std::optional reducedTy = reduction.reduce(ty); - REQUIRE(reducedTy); - return *reducedTy; - } - - TypeId reductionof(const std::string& annotation) - { - check("type _Res = " + annotation); - return reductionof(requireTypeAlias("_Res")); - } - - std::string toStringFull(TypeId ty) - { - return toString(ty, toStringOpts); - } -}; -} // namespace - -TEST_SUITE_BEGIN("TypeReductionTests"); - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded") -{ - ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - - CheckResult result = check(R"( - type T - = string - & (number | string | boolean) - & (number | string | boolean) - )"); - - CHECK(!reduction.reduce(requireTypeAlias("T"))); - // LUAU_REQUIRE_ERROR_COUNT(1, result); - // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded_with_normal_limit") -{ - CheckResult result = check(R"( - type T - = string -- 1 = 1 - & (number | string | boolean) -- 1 * 3 = 3 - & (number | string | boolean) -- 3 * 3 = 9 - & (number | string | boolean) -- 9 * 3 = 27 - & (number | string | boolean) -- 27 * 3 = 81 - & (number | string | boolean) -- 81 * 3 = 243 - & (number | string | boolean) -- 243 * 3 = 729 - & (number | string | boolean) -- 729 * 3 = 2187 - & (number | string | boolean) -- 2187 * 3 = 6561 - & (number | string | boolean) -- 6561 * 3 = 19683 - & (number | string | boolean) -- 19683 * 3 = 59049 - & (number | string) -- 59049 * 2 = 118098 - )"); - - CHECK(!reduction.reduce(requireTypeAlias("T"))); - // LUAU_REQUIRE_ERROR_COUNT(1, result); - // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_is_zero") -{ - ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - - CheckResult result = check(R"( - type T - = string - & (number | string | boolean) - & (number | string | boolean) - & never - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(ReductionFixture, "stress_test_recursion_limits") -{ - TypeId ty = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); - for (size_t i = 0; i < 20'000; ++i) - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {ty}; - ty = arena.addType(IntersectionType{{arena.addType(table), arena.addType(table)}}); - } - - CHECK(!reduction.reduce(ty)); -} - -TEST_CASE_FIXTURE(ReductionFixture, "caching") -{ - SUBCASE("free_tables") - { - TypeId ty1 = arena.addType(TableType{}); - getMutable(ty1)->state = TableState::Free; - getMutable(ty1)->props["x"] = {builtinTypes->stringType}; - - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("{- x: string -} & {| |}" == toStringFull(reductionof(intersectionTy))); - - getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("unsealed_tables") - { - TypeId ty1 = arena.addType(TableType{}); - getMutable(ty1)->state = TableState::Unsealed; - getMutable(ty1)->props["x"] = {builtinTypes->stringType}; - - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - - getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("free_types") - { - TypeId ty1 = arena.freshType(nullptr); - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("a & {| |}" == toStringFull(reductionof(intersectionTy))); - - *asMutable(ty1) = BoundType{ty2}; - CHECK("{| |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("we_can_see_that_the_cache_works_if_we_mutate_a_normally_not_mutated_type") - { - TypeId ty1 = arena.addType(BoundType{builtinTypes->stringType}); - TypeId ty2 = builtinTypes->numberType; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ never - - *asMutable(ty1) = BoundType{ty2}; - CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ number, but the cache is `never`. - } - - SUBCASE("ptr_eq_irreducible_unions") - { - TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->numberType}}); - TypeId reducedTy = reductionof(unionTy); - REQUIRE(unionTy == reducedTy); - } - - SUBCASE("ptr_eq_irreducible_intersections") - { - TypeId intersectionTy = arena.addType(IntersectionType{{builtinTypes->stringType, arena.addType(GenericType{"G"})}}); - TypeId reducedTy = reductionof(intersectionTy); - REQUIRE(intersectionTy == reducedTy); - } - - SUBCASE("ptr_eq_free_table") - { - TypeId tableTy = arena.addType(TableType{}); - getMutable(tableTy)->state = TableState::Free; - - TypeId reducedTy = reductionof(tableTy); - REQUIRE(tableTy == reducedTy); - } - - SUBCASE("ptr_eq_unsealed_table") - { - TypeId tableTy = arena.addType(TableType{}); - getMutable(tableTy)->state = TableState::Unsealed; - - TypeId reducedTy = reductionof(tableTy); - REQUIRE(tableTy == reducedTy); - } -} // caching - -TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") -{ - SUBCASE("string_and_string") - { - TypeId ty = reductionof("string & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("never_and_string") - { - TypeId ty = reductionof("never & string"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_never") - { - TypeId ty = reductionof("string & never"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("unknown_and_string") - { - TypeId ty = reductionof("unknown & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_unknown") - { - TypeId ty = reductionof("string & unknown"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("any_and_string") - { - TypeId ty = reductionof("any & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_any") - { - TypeId ty = reductionof("string & any"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_string") - { - TypeId ty = reductionof("(string | number) & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_string_or_number") - { - TypeId ty = reductionof("string & (string | number)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_a") - { - TypeId ty = reductionof(R"(string & "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("boolean_and_true") - { - TypeId ty = reductionof("boolean & true"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("boolean_and_a") - { - TypeId ty = reductionof(R"(boolean & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_a") - { - TypeId ty = reductionof(R"("a" & "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("a_and_b") - { - TypeId ty = reductionof(R"("a" & "b")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_true") - { - TypeId ty = reductionof(R"("a" & true)"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_true") - { - TypeId ty = reductionof(R"(true & false)"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_type_and_function") - { - TypeId ty = reductionof("() -> () & fun"); - CHECK("() -> ()" == toStringFull(ty)); - } - - SUBCASE("function_type_and_string") - { - TypeId ty = reductionof("() -> () & string"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("parent_and_child") - { - TypeId ty = reductionof("Parent & Child"); - CHECK("Child" == toStringFull(ty)); - } - - SUBCASE("child_and_parent") - { - TypeId ty = reductionof("Child & Parent"); - CHECK("Child" == toStringFull(ty)); - } - - SUBCASE("child_and_unrelated") - { - TypeId ty = reductionof("Child & Unrelated"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_table") - { - TypeId ty = reductionof("string & {}"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_child") - { - TypeId ty = reductionof("string & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_function") - { - TypeId ty = reductionof("string & () -> ()"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_table") - { - TypeId ty = reductionof("() -> () & {}"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_class") - { - TypeId ty = reductionof("() -> () & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_function") - { - TypeId ty = reductionof("() -> () & () -> ()"); - CHECK("(() -> ()) & (() -> ())" == toStringFull(ty)); - } - - SUBCASE("table_and_table") - { - TypeId ty = reductionof("{} & {}"); - CHECK("{| |}" == toStringFull(ty)); - } - - SUBCASE("table_and_metatable") - { - // No setmetatable in ReductionFixture, so we mix and match. - BuiltinsFixture fixture; - fixture.check(R"( - type Ty = {} & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } } & {| |}" == toStringFull(ty)); - } - - SUBCASE("a_and_string") - { - TypeId ty = reductionof(R"("a" & string)"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("reducible_function_and_function") - { - TypeId ty = reductionof("((string | string) -> (number | number)) & fun"); - CHECK("(string) -> number" == toStringFull(ty)); - } - - SUBCASE("string_and_error") - { - TypeId ty = reductionof("string & err"); - CHECK("*error-type* & string" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_number") - { - TypeId ty = reductionof("{ p: string } & { p: number }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_string") - { - TypeId ty = reductionof("{ p: string } & { p: string }"); - CHECK("{| p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_x_table_p_string_and_table_x_table_p_number") - { - TypeId ty = reductionof("{ x: { p: string } } & { x: { p: number } }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_p_and_table_q") - { - TypeId ty = reductionof("{ p: string } & { q: number }"); - CHECK("{| p: string, q: number |}" == toStringFull(ty)); - } - - SUBCASE("table_tag_a_or_table_tag_b_and_table_b") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { b: string }"); - CHECK("{| a: number, b: string, tag: string |} | {| b: string, tag: number |}" == toStringFull(ty)); - } - - SUBCASE("table_string_number_indexer_and_table_string_number_indexer") - { - TypeId ty = reductionof("{ [string]: number } & { [string]: number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("table_string_number_indexer_and_empty_table") - { - TypeId ty = reductionof("{ [string]: number } & {}"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("empty_table_table_string_number_indexer") - { - TypeId ty = reductionof("{} & { [string]: number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("string_number_indexer_and_number_number_indexer") - { - TypeId ty = reductionof("{ [string]: number } & { [number]: number }"); - CHECK("{number} & {| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_indexer_number_number") - { - TypeId ty = reductionof("{ p: string } & { [number]: number }"); - CHECK("{| [number]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_indexer_string_number") - { - TypeId ty = reductionof("{ p: string } & { [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_string_plus_indexer_string_number") - { - TypeId ty = reductionof("{ p: string } & { p: string, [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("array_number_and_array_string") - { - TypeId ty = reductionof("{number} & {string}"); - CHECK("{never}" == toStringFull(ty)); - } - - SUBCASE("array_string_and_array_string") - { - TypeId ty = reductionof("{string} & {string}"); - CHECK("{string}" == toStringFull(ty)); - } - - SUBCASE("array_string_or_number_and_array_string") - { - TypeId ty = reductionof("{string | number} & {string}"); - CHECK("{string}" == toStringFull(ty)); - } - - SUBCASE("fresh_type_and_string") - { - TypeId freshTy = arena.freshType(nullptr); - TypeId ty = reductionof(arena.addType(IntersectionType{{freshTy, builtinTypes->stringType}})); - CHECK("a & string" == toStringFull(ty)); - } - - SUBCASE("string_and_fresh_type") - { - TypeId freshTy = arena.freshType(nullptr); - TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, freshTy}})); - CHECK("a & string" == toStringFull(ty)); - } - - SUBCASE("generic_and_string") - { - TypeId genericTy = arena.addType(GenericType{"G"}); - TypeId ty = reductionof(arena.addType(IntersectionType{{genericTy, builtinTypes->stringType}})); - CHECK("G & string" == toStringFull(ty)); - } - - SUBCASE("string_and_generic") - { - TypeId genericTy = arena.addType(GenericType{"G"}); - TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, genericTy}})); - CHECK("G & string" == toStringFull(ty)); - } - - SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated") - { - TypeId ty = reductionof("Parent & (Child | AnotherChild | Unrelated)"); - CHECK("AnotherChild | Child" == toString(ty)); - } - - SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated_2") - { - TypeId ty = reductionof("(Parent & Child) | (Parent & AnotherChild) | (Parent & Unrelated)"); - CHECK("AnotherChild | Child" == toString(ty)); - } - - SUBCASE("top_table_and_table") - { - TypeId ty = reductionof("tbl & {}"); - CHECK("{| |}" == toString(ty)); - } - - SUBCASE("top_table_and_non_table") - { - TypeId ty = reductionof("tbl & \"foo\""); - CHECK("never" == toString(ty)); - } - - SUBCASE("top_table_and_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = tbl & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } }" == toString(ty)); - } -} // intersections_without_negations - -TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") -{ - SUBCASE("nil_and_not_nil") - { - TypeId ty = reductionof("nil & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("nil_and_not_false") - { - TypeId ty = reductionof("nil & Not"); - CHECK("nil" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_nil") - { - TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_false_or_nil") - { - TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_false_and_not_nil") - { - TypeId ty = reductionof("(string?) & Not & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("not_false_and_bool") - { - TypeId ty = reductionof("Not & boolean"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("function_type_and_not_function") - { - TypeId ty = reductionof("() -> () & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_type_and_not_string") - { - TypeId ty = reductionof("() -> () & Not"); - CHECK("() -> ()" == toStringFull(ty)); - } - - SUBCASE("not_a_and_string_or_nil") - { - TypeId ty = reductionof(R"(Not<"a"> & (string | nil))"); - CHECK(R"((string & ~"a")?)" == toStringFull(ty)); - } - - SUBCASE("not_a_and_a") - { - TypeId ty = reductionof(R"(Not<"a"> & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_a_and_b") - { - TypeId ty = reductionof(R"(Not<"a"> & "b")"); - CHECK(R"("b")" == toStringFull(ty)); - } - - SUBCASE("not_string_and_a") - { - TypeId ty = reductionof(R"(Not & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_bool_and_true") - { - TypeId ty = reductionof("Not & true"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_string_and_true") - { - TypeId ty = reductionof("Not & true"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("parent_and_not_child") - { - TypeId ty = reductionof("Parent & Not"); - CHECK("Parent & ~Child" == toStringFull(ty)); - } - - SUBCASE("not_child_and_parent") - { - TypeId ty = reductionof("Not & Parent"); - CHECK("Parent & ~Child" == toStringFull(ty)); - } - - SUBCASE("child_and_not_parent") - { - TypeId ty = reductionof("Child & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_parent_and_child") - { - TypeId ty = reductionof("Not & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_parent_and_unrelated") - { - TypeId ty = reductionof("Not & Unrelated"); - CHECK("Unrelated" == toStringFull(ty)); - } - - SUBCASE("unrelated_and_not_parent") - { - TypeId ty = reductionof("Unrelated & Not"); - CHECK("Unrelated" == toStringFull(ty)); - } - - SUBCASE("not_unrelated_and_parent") - { - TypeId ty = reductionof("Not & Parent"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("parent_and_not_unrelated") - { - TypeId ty = reductionof("Parent & Not"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("reducible_function_and_not_function") - { - TypeId ty = reductionof("((string | string) -> (number | number)) & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_not_error") - { - TypeId ty = reductionof("string & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_not_number") - { - TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("{| p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_not_string") - { - TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_x_table_p_string_and_table_x_table_p_not_number") - { - TypeId ty = reductionof("{ x: { p: string } } & { x: { p: Not } }"); - CHECK("{| x: {| p: string |} |}" == toStringFull(ty)); - } - - SUBCASE("table_or_nil_and_truthy") - { - TypeId ty = reductionof("({ x: number | string }?) & Not"); - CHECK("{| x: number | string |}" == toString(ty)); - } - - SUBCASE("not_top_table_and_table") - { - TypeId ty = reductionof("Not & {}"); - CHECK("never" == toString(ty)); - } - - SUBCASE("not_top_table_and_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = Not & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("never" == toString(ty)); - } -} // intersections_with_negations - -TEST_CASE_FIXTURE(ReductionFixture, "unions_without_negations") -{ - SUBCASE("never_or_string") - { - TypeId ty = reductionof("never | string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_never") - { - TypeId ty = reductionof("string | never"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("unknown_or_string") - { - TypeId ty = reductionof("unknown | string"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("string_or_unknown") - { - TypeId ty = reductionof("string | unknown"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("any_or_string") - { - TypeId ty = reductionof("any | string"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("string_or_any") - { - TypeId ty = reductionof("string | any"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("string_or_string_and_number") - { - TypeId ty = reductionof("string | (string & number)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_string") - { - TypeId ty = reductionof("string | string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_number") - { - TypeId ty = reductionof("string | number"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("number_or_string") - { - TypeId ty = reductionof("number | string"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_or_string") - { - TypeId ty = reductionof("(string | number) | string"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_or_string_2") - { - TypeId ty = reductionof("string | (number | string)"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_number") - { - TypeId ty = reductionof("string | (string | number)"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_number_or_boolean") - { - TypeId ty = reductionof("string | (string | number | boolean)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_boolean_or_number") - { - TypeId ty = reductionof("string | (string | boolean | number)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_boolean_or_string_or_number") - { - TypeId ty = reductionof("string | (boolean | string | number)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("boolean_or_string_or_number_or_string") - { - TypeId ty = reductionof("(boolean | string | number) | string"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("boolean_or_true") - { - TypeId ty = reductionof("boolean | true"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("boolean_or_false") - { - TypeId ty = reductionof("boolean | false"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("boolean_or_true_or_false") - { - TypeId ty = reductionof("boolean | true | false"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("string_or_a") - { - TypeId ty = reductionof(R"(string | "a")"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("a_or_a") - { - TypeId ty = reductionof(R"("a" | "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("a_or_b") - { - TypeId ty = reductionof(R"("a" | "b")"); - CHECK(R"("a" | "b")" == toStringFull(ty)); - } - - SUBCASE("a_or_b_or_string") - { - TypeId ty = reductionof(R"("a" | "b" | string)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("unknown_or_any") - { - TypeId ty = reductionof("unknown | any"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("any_or_unknown") - { - TypeId ty = reductionof("any | unknown"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("function_type_or_function") - { - TypeId ty = reductionof("() -> () | fun"); - CHECK("function" == toStringFull(ty)); - } - - SUBCASE("function_or_string") - { - TypeId ty = reductionof("fun | string"); - CHECK("function | string" == toStringFull(ty)); - } - - SUBCASE("parent_or_child") - { - TypeId ty = reductionof("Parent | Child"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("child_or_parent") - { - TypeId ty = reductionof("Child | Parent"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("parent_or_unrelated") - { - TypeId ty = reductionof("Parent | Unrelated"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_child_or_unrelated") - { - TypeId ty = reductionof("Parent | Child | Unrelated"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_unrelated_or_child") - { - TypeId ty = reductionof("Parent | Unrelated | Child"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_child_or_unrelated_or_child") - { - TypeId ty = reductionof("Parent | Child | Unrelated | Child"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("string_or_true") - { - TypeId ty = reductionof("string | true"); - CHECK("string | true" == toStringFull(ty)); - } - - SUBCASE("string_or_function") - { - TypeId ty = reductionof("string | () -> ()"); - CHECK("(() -> ()) | string" == toStringFull(ty)); - } - - SUBCASE("string_or_err") - { - TypeId ty = reductionof("string | err"); - CHECK("*error-type* | string" == toStringFull(ty)); - } - - SUBCASE("top_table_or_table") - { - TypeId ty = reductionof("tbl | {}"); - CHECK("table" == toString(ty)); - } - - SUBCASE("top_table_or_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = tbl | typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("table" == toString(ty)); - } - - SUBCASE("top_table_or_non_table") - { - TypeId ty = reductionof("tbl | number"); - CHECK("number | table" == toString(ty)); - } -} // unions_without_negations - -TEST_CASE_FIXTURE(ReductionFixture, "unions_with_negations") -{ - SUBCASE("string_or_not_string") - { - TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_string_or_string") - { - TypeId ty = reductionof("Not | string"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_number_or_string") - { - TypeId ty = reductionof("Not | string"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("string_or_not_number") - { - TypeId ty = reductionof("string | Not"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("not_hi_or_string_and_not_hi") - { - TypeId ty = reductionof(R"(Not<"hi"> | (string & Not<"hi">))"); - CHECK(R"(~"hi")" == toStringFull(ty)); - } - - SUBCASE("string_and_not_hi_or_not_hi") - { - TypeId ty = reductionof(R"((string & Not<"hi">) | Not<"hi">)"); - CHECK(R"(~"hi")" == toStringFull(ty)); - } - - SUBCASE("string_or_not_never") - { - TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_a") - { - TypeId ty = reductionof(R"(Not<"a"> | Not<"a">)"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("not_a_or_a") - { - TypeId ty = reductionof(R"(Not<"a"> | "a")"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("a_or_not_a") - { - TypeId ty = reductionof(R"("a" | Not<"a">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_string") - { - TypeId ty = reductionof(R"(Not<"a"> | string)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("string_or_not_a") - { - TypeId ty = reductionof(R"(string | Not<"a">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_string_or_a") - { - TypeId ty = reductionof(R"(Not | "a")"); - CHECK(R"("a" | ~string)" == toStringFull(ty)); - } - - SUBCASE("a_or_not_string") - { - TypeId ty = reductionof(R"("a" | Not)"); - CHECK(R"("a" | ~string)" == toStringFull(ty)); - } - - SUBCASE("not_number_or_a") - { - TypeId ty = reductionof(R"(Not | "a")"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("a_or_not_number") - { - TypeId ty = reductionof(R"("a" | Not)"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_b") - { - TypeId ty = reductionof(R"(Not<"a"> | Not<"b">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("boolean_or_not_false") - { - TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("boolean_or_not_true") - { - TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("false_or_not_false") - { - TypeId ty = reductionof("false | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("true_or_not_false") - { - TypeId ty = reductionof("true | Not"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("not_boolean_or_true") - { - TypeId ty = reductionof("Not | true"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("not_false_or_not_boolean") - { - TypeId ty = reductionof("Not | Not"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("function_type_or_not_function") - { - TypeId ty = reductionof("() -> () | Not"); - CHECK("(() -> ()) | ~function" == toStringFull(ty)); - } - - SUBCASE("not_parent_or_child") - { - TypeId ty = reductionof("Not | Child"); - CHECK("Child | ~Parent" == toStringFull(ty)); - } - - SUBCASE("child_or_not_parent") - { - TypeId ty = reductionof("Child | Not"); - CHECK("Child | ~Parent" == toStringFull(ty)); - } - - SUBCASE("parent_or_not_child") - { - TypeId ty = reductionof("Parent | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_child_or_parent") - { - TypeId ty = reductionof("Not | Parent"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("parent_or_not_unrelated") - { - TypeId ty = reductionof("Parent | Not"); - CHECK("~Unrelated" == toStringFull(ty)); - } - - SUBCASE("not_string_or_string_and_not_a") - { - TypeId ty = reductionof(R"(Not | (string & Not<"a">))"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("not_string_or_not_string") - { - TypeId ty = reductionof("Not | Not"); - CHECK("~string" == toStringFull(ty)); - } - - SUBCASE("not_string_or_not_number") - { - TypeId ty = reductionof("Not | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_boolean") - { - TypeId ty = reductionof(R"(Not<"a"> | Not)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_boolean") - { - TypeId ty = reductionof(R"(Not<"a"> | boolean)"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("string_or_err") - { - TypeId ty = reductionof("string | Not"); - CHECK("string | ~*error-type*" == toStringFull(ty)); - } - - SUBCASE("not_top_table_or_table") - { - TypeId ty = reductionof("Not | {}"); - CHECK("{| |} | ~table" == toString(ty)); - } - - SUBCASE("not_top_table_or_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = Not | typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } } | ~table" == toString(ty)); - } -} // unions_with_negations - -TEST_CASE_FIXTURE(ReductionFixture, "tables") -{ - SUBCASE("reduce_props") - { - TypeId ty = reductionof("{ x: string | string, y: number | number }"); - CHECK("{| x: string, y: number |}" == toStringFull(ty)); - } - - SUBCASE("reduce_indexers") - { - TypeId ty = reductionof("{ [string | string]: number | number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("reduce_instantiated_type_parameters") - { - check(R"( - type Foo = { x: T } - local foo: Foo = { x = "hello" } - )"); - - TypeId ty = reductionof(requireType("foo")); - CHECK("Foo" == toString(ty)); - } - - SUBCASE("reduce_instantiated_type_pack_parameters") - { - check(R"( - type Foo = { x: () -> T... } - local foo: Foo = { x = function() return "hi", 5 end } - )"); - - TypeId ty = reductionof(requireType("foo")); - CHECK("Foo" == toString(ty)); - } - - SUBCASE("reduce_tables_within_tables") - { - TypeId ty = reductionof("{ x: { y: string & number } }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("array_of_never") - { - TypeId ty = reductionof("{never}"); - CHECK("{never}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "metatables") -{ - SUBCASE("reduce_table_part") - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; - TypeId tableTy = arena.addType(std::move(table)); - - TypeId ty = reductionof(arena.addType(MetatableType{tableTy, arena.addType(TableType{})})); - CHECK("{ @metatable { }, {| x: string |} }" == toStringFull(ty)); - } - - SUBCASE("reduce_metatable_part") - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; - TypeId tableTy = arena.addType(std::move(table)); - - TypeId ty = reductionof(arena.addType(MetatableType{arena.addType(TableType{}), tableTy})); - CHECK("{ @metatable {| x: string |}, { } }" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "functions") -{ - SUBCASE("reduce_parameters") - { - TypeId ty = reductionof("(string | string) -> ()"); - CHECK("(string) -> ()" == toStringFull(ty)); - } - - SUBCASE("reduce_returns") - { - TypeId ty = reductionof("() -> (string | string)"); - CHECK("() -> string" == toStringFull(ty)); - } - - SUBCASE("reduce_parameters_and_returns") - { - TypeId ty = reductionof("(string | string) -> (number | number)"); - CHECK("(string) -> number" == toStringFull(ty)); - } - - SUBCASE("reduce_tail") - { - TypeId ty = reductionof("() -> ...(string | string)"); - CHECK("() -> (...string)" == toStringFull(ty)); - } - - SUBCASE("reduce_head_and_tail") - { - TypeId ty = reductionof("() -> (string | string, number | number, ...(boolean | boolean))"); - CHECK("() -> (string, number, ...boolean)" == toStringFull(ty)); - } - - SUBCASE("reduce_overloaded_functions") - { - TypeId ty = reductionof("((number | number) -> ()) & ((string | string) -> ())"); - CHECK("((number) -> ()) & ((string) -> ())" == toStringFull(ty)); - } -} // functions - -TEST_CASE_FIXTURE(ReductionFixture, "negations") -{ - SUBCASE("not_unknown") - { - TypeId ty = reductionof("Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_never") - { - TypeId ty = reductionof("Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_any") - { - TypeId ty = reductionof("Not"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("not_not_reduction") - { - TypeId ty = reductionof("Not>"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_string") - { - TypeId ty = reductionof("Not"); - CHECK("~string" == toStringFull(ty)); - } - - SUBCASE("not_string_or_number") - { - TypeId ty = reductionof("Not"); - CHECK("~number & ~string" == toStringFull(ty)); - } - - SUBCASE("not_string_and_number") - { - TypeId ty = reductionof("Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_error") - { - TypeId ty = reductionof("Not"); - CHECK("~*error-type*" == toStringFull(ty)); - } -} // negations - -TEST_CASE_FIXTURE(ReductionFixture, "discriminable_unions") -{ - SUBCASE("cat_or_dog_and_dog") - { - TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: "dog" })"); - CHECK(R"({| dogfood: string, tag: "dog" |})" == toStringFull(ty)); - } - - SUBCASE("cat_or_dog_and_not_dog") - { - TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: Not<"dog"> })"); - CHECK(R"({| catfood: string, tag: "cat" |})" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_number") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: string }"); - CHECK("{| a: number, tag: string |}" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_number") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: number }"); - CHECK("{| b: string, tag: number |}" == toStringFull(ty)); - } - - SUBCASE("child_or_unrelated_and_parent") - { - TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Parent }"); - CHECK("{| tag: Child, x: number |}" == toStringFull(ty)); - } - - SUBCASE("child_or_unrelated_and_not_parent") - { - TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Not }"); - CHECK("{| tag: Unrelated, y: string |}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "cycles") -{ - SUBCASE("recursively_defined_function") - { - check("type F = (f: F) -> ()"); - - TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("t1 where t1 = (t1) -> ()" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_function_and_function") - { - check("type F = (f: F & fun) -> ()"); - - TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("t1 where t1 = (function & t1) -> ()" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table") - { - check("type T = { x: T }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table") - { - check("type T = { x: T & {} }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 & {| |} |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table_2") - { - check("type T = { x: T } & { x: number }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: number |} & {| x: t1 |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table_3") - { - check("type T = { x: T } & { x: T }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 |} & {| x: t1 |}" == toStringFull(ty)); - } -} - -TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 36e437e24..9e21b1e03 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -2,7 +2,6 @@ #include "Luau/Scope.h" #include "Luau/Type.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include "Fixture.h" @@ -16,13 +15,13 @@ TEST_SUITE_BEGIN("TypeTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") { - REQUIRE_EQ(typeChecker.booleanType, typeChecker.booleanType); + REQUIRE_EQ(builtinTypes->booleanType, builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "bound_type_is_equal_to_that_which_it_is_bound") { - Type bound(BoundType(typeChecker.booleanType)); - REQUIRE_EQ(bound, *typeChecker.booleanType); + Type bound(BoundType(builtinTypes->booleanType)); + REQUIRE_EQ(bound, *builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "equivalent_cyclic_tables_are_equal") @@ -54,8 +53,8 @@ TEST_CASE_FIXTURE(Fixture, "different_cyclic_tables_are_not_equal") TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just_one_value") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> number", res); @@ -64,8 +63,8 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just_one_value") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.numberType}}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType, builtinTypes->numberType}}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> (number, number)", res); @@ -74,10 +73,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_free") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto free = Unifiable::Free(TypeLevel()); + auto free = FreeTypePack(TypeLevel()); auto freePack = TypePackVar{TypePackVariant{free}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}, &freePack}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}, &freePack}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ(res, "() -> (number, a...)"); @@ -86,9 +85,9 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_ TEST_CASE_FIXTURE(Fixture, "subset_check") { UnionType super, sub, notSub; - super.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType}; - sub.options = {typeChecker.numberType, typeChecker.stringType}; - notSub.options = {typeChecker.numberType, typeChecker.nilType}; + super.options = {builtinTypes->numberType, builtinTypes->stringType, builtinTypes->booleanType}; + sub.options = {builtinTypes->numberType, builtinTypes->stringType}; + notSub.options = {builtinTypes->numberType, builtinTypes->nilType}; CHECK(isSubset(super, sub)); CHECK(!isSubset(super, notSub)); @@ -97,7 +96,7 @@ TEST_CASE_FIXTURE(Fixture, "subset_check") TEST_CASE_FIXTURE(Fixture, "iterate_over_UnionType") { UnionType utv; - utv.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.anyType}; + utv.options = {builtinTypes->numberType, builtinTypes->stringType, builtinTypes->anyType}; std::vector result; for (TypeId ty : &utv) @@ -110,19 +109,38 @@ TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypes") { Type subunion{UnionType{}}; UnionType* innerUtv = getMutable(&subunion); - innerUtv->options = {typeChecker.numberType, typeChecker.stringType}; + innerUtv->options = {builtinTypes->numberType, builtinTypes->stringType}; UnionType utv; - utv.options = {typeChecker.anyType, &subunion}; + utv.options = {builtinTypes->anyType, &subunion}; std::vector result; for (TypeId ty : &utv) result.push_back(ty); REQUIRE_EQ(result.size(), 3); - CHECK_EQ(result[0], typeChecker.anyType); - CHECK_EQ(result[2], typeChecker.stringType); - CHECK_EQ(result[1], typeChecker.numberType); + CHECK_EQ(result[0], builtinTypes->anyType); + CHECK_EQ(result[2], builtinTypes->stringType); + CHECK_EQ(result[1], builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypes_postfix_operator_plus_plus") +{ + Type subunion{UnionType{}}; + UnionType* innerUtv = getMutable(&subunion); + innerUtv->options = {builtinTypes->numberType, builtinTypes->stringType}; + + UnionType utv; + utv.options = {builtinTypes->anyType, &subunion}; + + std::vector result; + for (auto it = begin(&utv); it != end(&utv); it++) + result.push_back(*it); + + REQUIRE_EQ(result.size(), 3); + CHECK_EQ(result[0], builtinTypes->anyType); + CHECK_EQ(result[2], builtinTypes->stringType); + CHECK_EQ(result[1], builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_them") @@ -132,8 +150,8 @@ TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_th Type btv{UnionType{}}; UnionType* utv2 = getMutable(&btv); - utv2->options.push_back(typeChecker.numberType); - utv2->options.push_back(typeChecker.stringType); + utv2->options.push_back(builtinTypes->numberType); + utv2->options.push_back(builtinTypes->stringType); utv2->options.push_back(&atv); utv1->options.push_back(&btv); @@ -143,14 +161,14 @@ TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_th result.push_back(ty); REQUIRE_EQ(result.size(), 2); - CHECK_EQ(result[0], typeChecker.numberType); - CHECK_EQ(result[1], typeChecker.stringType); + CHECK_EQ(result[0], builtinTypes->numberType); + CHECK_EQ(result[1], builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") { - Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; - Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + Type tv1{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; + Type tv2{UnionType{{&tv1, builtinTypes->booleanType}}}; auto utv = get(&tv2); std::vector result; @@ -158,19 +176,19 @@ TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") result.push_back(ty); REQUIRE_EQ(result.size(), 3); - CHECK_EQ(result[0], typeChecker.stringType); - CHECK_EQ(result[1], typeChecker.numberType); - CHECK_EQ(result[2], typeChecker.booleanType); + CHECK_EQ(result[0], builtinTypes->stringType); + CHECK_EQ(result[1], builtinTypes->numberType); + CHECK_EQ(result[2], builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_vector_iter_ctor") { - Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; - Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + Type tv1{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; + Type tv2{UnionType{{&tv1, builtinTypes->booleanType}}}; auto utv = get(&tv2); std::vector actual(begin(utv), end(utv)); - std::vector expected{typeChecker.stringType, typeChecker.numberType, typeChecker.booleanType}; + std::vector expected{builtinTypes->stringType, builtinTypes->numberType, builtinTypes->booleanType}; CHECK_EQ(actual, expected); } @@ -273,12 +291,23 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; - typeChecker.currentModule = std::make_shared(); - typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(builtinTypes->anyTypePack)); - - TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); - - CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); + ModulePtr currentModule = std::make_shared(); + Anyification anyification( + ¤tModule->internalTypes, + frontend.globals.globalScope, + builtinTypes, + &frontend.iceHandler, + builtinTypes->anyType, + builtinTypes->anyTypePack + ); + std::optional any = anyification.substitute(root); + + REQUIRE(!anyification.normalizationTooComplex); + REQUIRE(any.has_value()); + if (FFlag::LuauSolverV2) + CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(*any)); + else + CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(*any)); } TEST_CASE("tagging_tables") @@ -291,7 +320,7 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test", {}}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); CHECK(Luau::hasTag(&base, "foo")); @@ -299,8 +328,8 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; - Type derived{ClassType{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; + Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test", {}}}; + Type derived{ClassType{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test", {}}}; CHECK(!Luau::hasTag(&base, "foo")); CHECK(!Luau::hasTag(&derived, "foo")); diff --git a/tests/Unifier2.test.cpp b/tests/Unifier2.test.cpp new file mode 100644 index 000000000..657341036 --- /dev/null +++ b/tests/Unifier2.test.cpp @@ -0,0 +1,135 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/Unifier2.h" +#include "Luau/Error.h" + +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) + +struct Unifier2Fixture +{ + TypeArena arena; + BuiltinTypes builtinTypes; + Scope scope{builtinTypes.anyTypePack}; + InternalErrorReporter iceReporter; + Unifier2 u2{NotNull{&arena}, NotNull{&builtinTypes}, NotNull{&scope}, NotNull{&iceReporter}}; + ToStringOptions opts; + + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + std::pair freshType() + { + FreeType ft{&scope, builtinTypes.neverType, builtinTypes.unknownType}; + + TypeId ty = arena.addType(ft); + FreeType* ftv = getMutable(ty); + REQUIRE(ftv != nullptr); + + return {ty, ftv}; + } + + std::string toString(TypeId ty) + { + return ::Luau::toString(ty, opts); + } + + std::string toString(TypePackId ty) + { + return ::Luau::toString(ty, opts); + } +}; + +TEST_SUITE_BEGIN("Unifier2"); + +TEST_CASE_FIXTURE(Unifier2Fixture, "T <: number") +{ + auto [left, freeLeft] = freshType(); + + CHECK(u2.unify(left, builtinTypes.numberType)); + + CHECK("never" == toString(freeLeft->lowerBound)); + CHECK("number" == toString(freeLeft->upperBound)); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "number <: T") +{ + auto [right, freeRight] = freshType(); + + CHECK(u2.unify(builtinTypes.numberType, right)); + + CHECK("number" == toString(freeRight->lowerBound)); + CHECK("unknown" == toString(freeRight->upperBound)); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "T <: U") +{ + auto [left, freeLeft] = freshType(); + auto [right, freeRight] = freshType(); + + CHECK(u2.unify(left, right)); + + CHECK("t1 where t1 = ('a <: (t1 <: 'b))" == toString(left)); + CHECK("t1 where t1 = (('a <: t1) <: 'b)" == toString(right)); + + CHECK("never" == toString(freeLeft->lowerBound)); + CHECK("t1 where t1 = (('a <: t1) <: 'b)" == toString(freeLeft->upperBound)); + + CHECK("t1 where t1 = ('a <: (t1 <: 'b))" == toString(freeRight->lowerBound)); + CHECK("unknown" == toString(freeRight->upperBound)); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "(string) -> () <: (X) -> Y...") +{ + TypeId stringToUnit = arena.addType(FunctionType{arena.addTypePack({builtinTypes.stringType}), arena.addTypePack({})}); + + auto [x, xFree] = freshType(); + TypePackId y = arena.freshTypePack(&scope); + + TypeId xToY = arena.addType(FunctionType{arena.addTypePack({x}), y}); + + u2.unify(stringToUnit, xToY); + + CHECK("string" == toString(xFree->upperBound)); + + const TypePack* yPack = get(follow(y)); + REQUIRE(yPack != nullptr); + + CHECK(0 == yPack->head.size()); + CHECK(!yPack->tail); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "unify_binds_free_subtype_tail_pack") +{ + TypePackId numberPack = arena.addTypePack({builtinTypes.numberType}); + + TypePackId freeTail = arena.freshTypePack(&scope); + TypeId freeHead = arena.addType(FreeType{&scope, builtinTypes.neverType, builtinTypes.unknownType}); + TypePackId freeAndFree = arena.addTypePack({freeHead}, freeTail); + + u2.unify(freeAndFree, numberPack); + + CHECK("('a <: number)" == toString(freeAndFree)); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "unify_binds_free_supertype_tail_pack") +{ + TypePackId numberPack = arena.addTypePack({builtinTypes.numberType}); + + TypePackId freeTail = arena.freshTypePack(&scope); + TypeId freeHead = arena.addType(FreeType{&scope, builtinTypes.neverType, builtinTypes.unknownType}); + TypePackId freeAndFree = arena.addTypePack({freeHead}, freeTail); + + u2.unify(numberPack, freeAndFree); + + CHECK("(number <: 'a)" == toString(freeAndFree)); +} + +TEST_SUITE_END(); diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index 83eec519a..4abb77f0c 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -177,15 +177,19 @@ TEST_CASE("Visit") // void-returning visitor, const variants std::string r1; visit( - [&](const auto& v) { + [&](const auto& v) + { r1 += ToStringVisitor()(v); }, - v1c); + v1c + ); visit( - [&](const auto& v) { + [&](const auto& v) + { r1 += ToStringVisitor()(v); }, - v2c); + v2c + ); CHECK(r1 == "12345"); // value-returning visitor, const variants @@ -203,17 +207,21 @@ TEST_CASE("Visit") // value-returning visitor, mutable variant std::string r3; r3 += visit( - [&](auto& v) { + [&](auto& v) + { IncrementVisitor()(v); return ToStringVisitor()(v); }, - v1); + v1 + ); r3 += visit( - [&](auto& v) { + [&](auto& v) + { IncrementVisitor()(v); return ToStringVisitor()(v); }, - v2); + v2 + ); CHECK(r3 == "1231147"); } diff --git a/tests/VecDeque.test.cpp b/tests/VecDeque.test.cpp new file mode 100644 index 000000000..c8d10d4a0 --- /dev/null +++ b/tests/VecDeque.test.cpp @@ -0,0 +1,800 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/VecDeque.h" + +#include "doctest.h" +#include + +TEST_SUITE_BEGIN("VecDequeTests"); + +TEST_CASE("forward_queue_test_no_initial_capacity") +{ + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue{}; + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(i); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 11); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), j); + CHECK_EQ(queue.back(), 9); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } +} + +TEST_CASE("forward_queue_test") +{ + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(i); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 13); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), j); + CHECK_EQ(queue.back(), 9); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } +} + +TEST_CASE("forward_queue_test_initializer_list") +{ + Luau::VecDeque queue{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), j); + CHECK_EQ(queue.back(), 9); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } +} + +TEST_CASE("reverse_queue_test") +{ + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_front(i); + // q: 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 13); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), 9); + CHECK_EQ(queue.back(), j); + + REQUIRE(!queue.empty()); + queue.pop_back(); + } +} + +TEST_CASE("random_access_queue_test") +{ + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(i); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.at(j), j); + CHECK_EQ(queue[j], j); + } +} + +TEST_CASE("clear_works_on_queue") +{ + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(i); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], j); + + queue.clear(); + CHECK(queue.empty()); + CHECK(queue.size() == 0); +} + +TEST_CASE("pop_front_at_end") +{ + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + // setting up the internal buffer to be: 1234567890 by the end (i.e. 0 at the end of the buffer) + queue.push_front(0); + + for (int i = 1; i < 10; i++) + queue.push_back(i); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), j); + CHECK_EQ(queue.back(), 9); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } +} + +TEST_CASE("pop_back_at_front") +{ + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + // setting up the internal buffer to be: 9012345678 by the end (i.e. 9 at the front of the buffer) + queue.push_back(0); + + for (int i = 1; i < 10; i++) + queue.push_front(i); + // q: 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), 9); + CHECK_EQ(queue.back(), j); + + REQUIRE(!queue.empty()); + queue.pop_back(); + } +} + +TEST_CASE("queue_is_contiguous") +{ + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue{}; + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(i); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 11); + CHECK(queue.is_contiguous()); +} + +TEST_CASE("queue_is_not_contiguous") +{ + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue{}; + + REQUIRE(queue.empty()); + + for (int i = 5; i < 10; i++) + queue.push_back(i); + for (int i = 4; i >= 0; i--) + queue.push_front(i); + // buffer: 56789......01234 + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 11); + CHECK(!queue.is_contiguous()); + + // checking that it is indeed sequential integers from 0 to 9 + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], j); +} + +TEST_CASE("shrink_to_fit_works") +{ + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue{}; + + REQUIRE(queue.empty()); + + for (int i = 5; i < 10; i++) + queue.push_back(i); + for (int i = 4; i >= 0; i--) + queue.push_front(i); + // buffer: 56789......01234 + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + REQUIRE_EQ(queue.capacity(), 11); + CHECK(!queue.is_contiguous()); + + // checking that it is indeed sequential integers from 0 to 9 + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], j); + + queue.shrink_to_fit(); + // shrink to fit always makes a contiguous buffer + CHECK(queue.is_contiguous()); + // the capacity should be exactly the size now + CHECK_EQ(queue.capacity(), queue.size()); + + REQUIRE(!queue.empty()); + + // checking that it is still sequential integers from 0 to 9 + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], j); +} + +const static std::string testStringSet[2][10] = { + + // To hit potential SSO issues showing memory management issues, we need small strings + {"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"}, + + // This list of non-SSO test strings consists of quotes from Ursula K. Le Guin. + {"Love doesn't just sit there, like a stone, it has to be made, like bread; remade all the time, made new.", + "People who deny the existence of dragons are often eaten by dragons. From within.", + "It is good to have an end to journey toward; but it is the journey that matters, in the end.", + "We're each of us alone, to be sure. What can you do but hold your hand out in the dark?", + "When you light a candle, you also cast a shadow.", + "You cannot buy the revolution. You cannot make the revolution. You can only be the revolution. It is in your spirit, or it is nowhere.", + "To learn which questions are unanswerable, and not to answer them: this skill is most needful in times of stress and darkness.", + "What sane person could live in this world and not be crazy?", + "The only thing that makes life possible is permanent, intolerable uncertainty: not knowing what comes next.", + "My imagination makes me human and makes me a fool; it gives me all the world and exiles me from it."} +}; + +TEST_CASE("string_queue_test_no_initial_capacity") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue; + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(testStrings[i]); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 11); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), testStrings[j]); + CHECK_EQ(queue.back(), testStrings[9]); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } + } +} + +TEST_CASE("string_queue_test") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(testStrings[i]); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 13); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), testStrings[j]); + CHECK_EQ(queue.back(), testStrings[9]); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } + } +} + +TEST_CASE("string_queue_test_initializer_list") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + Luau::VecDeque queue{ + testStrings[0], + testStrings[1], + testStrings[2], + testStrings[3], + testStrings[4], + testStrings[5], + testStrings[6], + testStrings[7], + testStrings[8], + testStrings[9], + }; + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), testStrings[j]); + CHECK_EQ(queue.back(), testStrings[9]); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } + } +} + +TEST_CASE("reverse_string_queue_test") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_front(testStrings[i]); + // q: 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 13); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), testStrings[9]); + CHECK_EQ(queue.back(), testStrings[j]); + + REQUIRE(!queue.empty()); + queue.pop_back(); + } + } +} + +TEST_CASE("random_access_string_queue_test") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(testStrings[i]); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.at(j), testStrings[j]); + CHECK_EQ(queue[j], testStrings[j]); + } + } +} + +TEST_CASE("clear_works_on_string_queue") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(testStrings[i]); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], testStrings[j]); + + queue.clear(); + CHECK(queue.empty()); + CHECK(queue.size() == 0); + } +} + +TEST_CASE("pop_front_string_at_end") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + // setting up the internal buffer to be: 1234567890 by the end (i.e. 0 at the end of the buffer) + queue.push_front(testStrings[0]); + + for (int i = 1; i < 10; i++) + queue.push_back(testStrings[i]); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), testStrings[j]); + CHECK_EQ(queue.back(), testStrings[9]); + + REQUIRE(!queue.empty()); + queue.pop_front(); + } + } +} + +TEST_CASE("pop_back_string_at_front") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity set to 5 so that a grow is necessary + Luau::VecDeque queue; + queue.reserve(5); + + REQUIRE(queue.empty()); + + // setting up the internal buffer to be: 9012345678 by the end (i.e. 9 at the front of the buffer) + queue.push_back(testStrings[0]); + + for (int i = 1; i < 10; i++) + queue.push_front(testStrings[i]); + // q: 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + for (int j = 0; j < 10; j++) + { + CHECK_EQ(queue.front(), testStrings[9]); + CHECK_EQ(queue.back(), testStrings[j]); + + REQUIRE(!queue.empty()); + queue.pop_back(); + } + } +} + +TEST_CASE("string_queue_is_contiguous") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue{}; + + REQUIRE(queue.empty()); + + for (int i = 0; i < 10; i++) + queue.push_back(testStrings[i]); + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 11); + CHECK(queue.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], testStrings[j]); + + // Check copy construction + Luau::VecDeque queue2 = queue; + + REQUIRE(!queue2.empty()); + REQUIRE(queue2.size() == 10); + + CHECK_EQ(queue2.capacity(), 11); + CHECK(queue2.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue2[j], testStrings[j]); + + // Check copy assignment + Luau::VecDeque queue3; + queue3 = queue; + + REQUIRE(!queue3.empty()); + REQUIRE(queue3.size() == 10); + + CHECK_EQ(queue3.capacity(), 11); + CHECK(queue3.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue3[j], testStrings[j]); + + // Check move construction + Luau::VecDeque queue4 = std::move(queue3); + + REQUIRE(!queue4.empty()); + REQUIRE(queue4.size() == 10); + + CHECK_EQ(queue4.capacity(), 11); + CHECK(queue4.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue4[j], testStrings[j]); + + // Check move assignment + Luau::VecDeque queue5; + queue5 = std::move(queue2); + + REQUIRE(!queue5.empty()); + REQUIRE(queue5.size() == 10); + + CHECK_EQ(queue5.capacity(), 11); + CHECK(queue5.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue5[j], testStrings[j]); + } +} + +TEST_CASE("string_queue_is_not_contiguous") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue{}; + + REQUIRE(queue.empty()); + + for (int i = 5; i < 10; i++) + queue.push_back(testStrings[i]); + for (int i = 4; i >= 0; i--) + queue.push_front(testStrings[i]); + // buffer: 56789......01234 + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + CHECK_EQ(queue.capacity(), 11); + CHECK(!queue.is_contiguous()); + + // checking that it is indeed sequential integers from 0 to 9 + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], testStrings[j]); + + // Check copy construction + Luau::VecDeque queue2 = queue; + + REQUIRE(!queue2.empty()); + REQUIRE(queue2.size() == 10); + + CHECK_EQ(queue2.capacity(), 11); + CHECK(!queue2.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue2[j], testStrings[j]); + + // Check copy assignment + Luau::VecDeque queue3; + queue3 = queue; + + REQUIRE(!queue3.empty()); + REQUIRE(queue3.size() == 10); + + CHECK_EQ(queue3.capacity(), 11); + CHECK(queue3.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue3[j], testStrings[j]); + + // Check move construction + Luau::VecDeque queue4 = std::move(queue); + + REQUIRE(!queue4.empty()); + REQUIRE(queue4.size() == 10); + + CHECK_EQ(queue4.capacity(), 11); + CHECK(!queue4.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue4[j], testStrings[j]); + + // Check move assignment + Luau::VecDeque queue5; + queue5 = std::move(queue2); + + REQUIRE(!queue5.empty()); + REQUIRE(queue5.size() == 10); + + CHECK_EQ(queue5.capacity(), 11); + CHECK(!queue5.is_contiguous()); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue5[j], testStrings[j]); + + // Check that grow from discontiguous is handled well + queue4.push_back("zero"); + queue4.push_back("?"); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue4[j], testStrings[j]); + CHECK_EQ(queue4[10], "zero"); + CHECK_EQ(queue4[11], "?"); + + // Check that reserve from discontiguous is handled well + queue5.reserve(20); + + for (int j = 0; j < 10; j++) + CHECK_EQ(queue5[j], testStrings[j]); + } +} + +TEST_CASE("shrink_to_fit_works_with_strings") +{ + for (size_t stringSet = 0; stringSet < 2; stringSet++) + { + auto testStrings = testStringSet[stringSet]; + + // initial capacity is not set, so this should grow to be 11 + Luau::VecDeque queue{}; + + REQUIRE(queue.empty()); + + for (int i = 5; i < 10; i++) + queue.push_back(testStrings[i]); + for (int i = 4; i >= 0; i--) + queue.push_front(testStrings[i]); + // buffer: 56789......01234 + // q: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + + REQUIRE(!queue.empty()); + REQUIRE(queue.size() == 10); + + REQUIRE_EQ(queue.capacity(), 11); + CHECK(!queue.is_contiguous()); + + // checking that it is indeed sequential integers from 0 to 9 + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], testStrings[j]); + + queue.shrink_to_fit(); + // shrink to fit always makes a contiguous buffer + CHECK(queue.is_contiguous()); + // the capacity should be exactly the size now + CHECK_EQ(queue.capacity(), queue.size()); + + REQUIRE(!queue.empty()); + + // checking that it is still sequential integers from 0 to 9 + for (int j = 0; j < 10; j++) + CHECK_EQ(queue[j], testStrings[j]); + } +} + +struct TestStruct +{ +}; + +// Verify that elements pushed to the front of the queue are properly destroyed when the queue is destroyed. +TEST_CASE("push_front_elements_are_destroyed_correctly") +{ + std::shared_ptr t = std::make_shared(); + { + Luau::VecDeque> queue{}; + REQUIRE(queue.empty()); + queue.reserve(10); + queue.push_front(t); + queue.push_front(t); + REQUIRE(t.use_count() == 3); // Num of references to the TestStruct instance is now 3 + // <-- call destructor here + + // Extra check for correct copies + Luau::VecDeque> queue2 = queue; + Luau::VecDeque> queue3; + queue3 = queue; + } + + // At this point the destructor should be called and we should be back down to one instance of TestStruct + REQUIRE(t.use_count() == 1); +} + +TEST_SUITE_END(); diff --git a/tests/VisitType.test.cpp b/tests/VisitType.test.cpp index 4fba694a8..186afaa52 100644 --- a/tests/VisitType.test.cpp +++ b/tests/VisitType.test.cpp @@ -8,26 +8,40 @@ using namespace Luau; -LUAU_FASTINT(LuauVisitRecursionLimit) +LUAU_FASTINT(LuauVisitRecursionLimit); +LUAU_FASTFLAG(LuauSolverV2) -TEST_SUITE_BEGIN("VisitTypeVar"); +TEST_SUITE_BEGIN("VisitType"); TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") { - ScopedFastInt sfi{"LuauVisitRecursionLimit", 3}; + if (FFlag::LuauSolverV2) + { + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + ScopedFastInt sfi{FInt::LuauVisitRecursionLimit, 3}; + TypeId tType = requireType("t"); - CheckResult result = check(R"( - local t : {a: {b: {c: {d: {e: boolean}}}}} - )"); + CHECK_THROWS_AS(toString(tType), RecursionLimitException); + } + else + { + ScopedFastInt sfi{FInt::LuauVisitRecursionLimit, 3}; - TypeId tType = requireType("t"); + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); - CHECK_THROWS_AS(toString(tType), RecursionLimitException); + TypeId tType = requireType("t"); + + CHECK_THROWS_AS(toString(tType), RecursionLimitException); + } } TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") { - ScopedFastInt sfi{"LuauVisitRecursionLimit", 8}; + ScopedFastInt sfi{FInt::LuauVisitRecursionLimit, 8}; CheckResult result = check(R"( local t : {a: {b: {c: {d: {e: boolean}}}}} @@ -38,4 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") (void)toString(tType); } +TEST_CASE_FIXTURE(Fixture, "some_free_types_do_not_have_bounds") +{ + Type t{FreeType{TypeLevel{}}}; + + (void)toString(&t); +} + +TEST_CASE_FIXTURE(Fixture, "some_free_types_have_bounds") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + Scope scope{builtinTypes->anyTypePack}; + Type t{FreeType{&scope, builtinTypes->neverType, builtinTypes->numberType}}; + + CHECK("('a <: number)" == toString(&t)); +} + TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 274166237..8db62d96c 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -22,4 +22,12 @@ function getpi() return pi end +function largealloc() + table.create(1000000) +end + +function oops() + return "oops" +end + return('OK') diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index e23c1a53f..98f8000ed 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -91,6 +91,15 @@ assert((function() local a = 1 a = a + 2 return a end)() == 3) assert((function() local a = 1 a = a - 2 return a end)() == -1) assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) + +-- floor division should always round towards -Infinity +assert((function() local a = 1 a = a // 2 return a end)() == 0) +assert((function() local a = 3 a = a // 2 return a end)() == 1) +assert((function() local a = 3.5 a = a // 2 return a end)() == 1) +assert((function() local a = -1 a = a // 2 return a end)() == -1) +assert((function() local a = -3 a = a // 2 return a end)() == -2) +assert((function() local a = -3.5 a = a // 2 return a end)() == -2) + assert((function() local a = 5 a = a % 2 return a end)() == 1) assert((function() local a = 3 a = a ^ 2 return a end)() == 9) assert((function() local a = 3 a = a ^ 3 return a end)() == 27) @@ -120,6 +129,8 @@ assert((function() local a a = nil local b = 2 b = a and b return b end)() == ni assert((function() local a a = 1 local b = 2 b = a or b return b end)() == 1) assert((function() local a a = nil local b = 2 b = a or b return b end)() == 2) +assert((function(a) return 12 % a end)(5) == 2) + -- binary arithmetics coerces strings to numbers (sadly) assert(1 + "2" == 3) assert(2 * "0xa" == 20) @@ -166,6 +177,33 @@ assert((function() local a = 1 for b=1,9 do a = a * 2 if a == 128 then break els -- make sure internal index is protected against modification assert((function() local a = 1 for b=9,1,-2 do a = a * 2 b = nil end return a end)() == 32) +-- make sure that when step is 0, we treat it as backward iteration (and as such, iterate zero times or indefinitely) +-- this is consistent with Lua 5.1; future Lua versions emit an error when step is 0; LuaJIT instead treats 0 as forward iteration +-- we repeat tests twice, with and without constant folding +local zero = tonumber("0") +assert((function() local c = 0 for i=1,10,0 do c += 1 if c > 10 then break end end return c end)() == 0) +assert((function() local c = 0 for i=10,1,0 do c += 1 if c > 10 then break end end return c end)() == 11) +assert((function() local c = 0 for i=1,10,zero do c += 1 if c > 10 then break end end return c end)() == 0) +assert((function() local c = 0 for i=10,1,zero do c += 1 if c > 10 then break end end return c end)() == 11) + +-- make sure that when limit is nan, we iterate zero times (this is consistent with Lua 5.1; future Lua versions break this) +-- we repeat tests twice, with and without constant folding +local nan = tonumber("nan") +assert((function() local c = 0 for i=1,0/0 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,0/0,-1 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,nan do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,nan,-1 do c += 1 end return c end)() == 0) + +-- make sure that when step is nan, we treat it as backward iteration and as such iterate once iff start<=limit +assert((function() local c = 0 for i=1,10,0/0 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=10,1,0/0 do c += 1 end return c end)() == 1) +assert((function() local c = 0 for i=1,10,nan do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=10,1,nan do c += 1 end return c end)() == 1) + +-- make sure that when index becomes nan mid-iteration, we correctly exit the loop (this is broken in Lua 5.1; future Lua versions fix this) +assert((function() local c = 0 for i=-math.huge,0,math.huge do c += 1 end return c end)() == 1) +assert((function() local c = 0 for i=math.huge,math.huge,-math.huge do c += 1 end return c end)() == 1) + -- generic for -- ipairs assert((function() local a = '' for k in ipairs({5, 6, 7}) do a = a .. k end return a end)() == "123") @@ -275,6 +313,10 @@ assert((function() return result end)() == "ArcticDunesCanyonsWaterMountainsHillsLavaflowPlainsMarsh") +-- table literals may contain duplicate fields; the language doesn't specify assignment order but we currently assign left to right +assert((function() local t = {data = 4, data = nil, data = 42} return t.data end)() == 42) +assert((function() local t = {data = 4, data = nil, data = 42, data = nil} return t.data end)() == nil) + -- multiple returns -- local= assert((function() function foo() return 2, 3, 4 end local a, b, c = foo() return ''..a..b..c end)() == "234") @@ -492,6 +534,7 @@ local function vec3t(x, y, z) __sub = function(l, r) return vec3t(l.x - r.x, l.y - r.y, l.z - r.z) end, __mul = function(l, r) return type(r) == "number" and vec3t(l.x * r, l.y * r, l.z * r) or vec3t(l.x * r.x, l.y * r.y, l.z * r.z) end, __div = function(l, r) return type(r) == "number" and vec3t(l.x / r, l.y / r, l.z / r) or vec3t(l.x / r.x, l.y / r.y, l.z / r.z) end, + __idiv = function(l, r) return type(r) == "number" and vec3t(l.x // r, l.y // r, l.z // r) or vec3t(l.x // r.x, l.y // r.y, l.z // r.z) end, __unm = function(v) return vec3t(-v.x, -v.y, -v.z) end, __tostring = function(v) return string.format("%g, %g, %g", v.x, v.y, v.z) end }) @@ -502,10 +545,13 @@ assert((function() return tostring(vec3t(1,2,3) + vec3t(4,5,6)) end)() == "5, 7, assert((function() return tostring(vec3t(1,2,3) - vec3t(4,5,6)) end)() == "-3, -3, -3") assert((function() return tostring(vec3t(1,2,3) * vec3t(4,5,6)) end)() == "4, 10, 18") assert((function() return tostring(vec3t(1,2,3) / vec3t(2,4,8)) end)() == "0.5, 0.5, 0.375") +assert((function() return tostring(vec3t(1,2,3) // vec3t(2,4,2)) end)() == "0, 0, 1") +assert((function() return tostring(vec3t(1,2,3) // vec3t(-2,-4,-2)) end)() == "-1, -1, -2") -- reg vs constant assert((function() return tostring(vec3t(1,2,3) * 2) end)() == "2, 4, 6") assert((function() return tostring(vec3t(1,2,3) / 2) end)() == "0.5, 1, 1.5") +assert((function() return tostring(vec3t(1,2,3) // 2) end)() == "0, 1, 1") -- unary assert((function() return tostring(-vec3t(1,2,3)) end)() == "-1, -2, -3") @@ -680,16 +726,15 @@ assert((function() return sum end)() == 15) --- the reason why this test is interesting is that the table created here has arraysize=0 and a single hash element with key = 1.0 --- ipairs must iterate through that +-- ipairs will not iterate through hash part assert((function() - local arr = { [1] = 42 } + local arr = { [1] = 1, [42] = 42, x = 10 } local sum = 0 for i,v in ipairs(arr) do sum = sum + v end return sum -end)() == 42) +end)() == 1) -- the reason why this test is interesting is it ensures we do correct mutability analysis for locals local function chainTest(n) @@ -942,6 +987,11 @@ end)(true) == 5050) assert(pcall(typeof) == false) assert(pcall(type) == false) +function nothing() end + +assert(pcall(function() return typeof(nothing()) end) == false) +assert(pcall(function() return type(nothing()) end) == false) + -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index 3b117892d..c25365089 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -72,6 +72,7 @@ for _, b in pairs(c) do assert(bit32.bxor(b, b) == 0) assert(bit32.bxor(b, 0) == b) assert(bit32.bxor(b, b, b) == b) + assert(bit32.bxor(b, b, b, b) == 0) assert(bit32.bnot(b) ~= b) assert(bit32.bnot(bit32.bnot(b)) == b) assert(bit32.bnot(b) == 2^32 - 1 - b) @@ -101,6 +102,7 @@ assert(bit32.extract(0xa0001111, 28, 4) == 0xa) assert(bit32.extract(0xa0001111, 31, 1) == 1) assert(bit32.extract(0x50000111, 31, 1) == 0) assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679) +assert(bit32.extract(0xa0001111, 0) == 1) assert(bit32.extract(0xa0001111, 16) == 0) assert(bit32.extract(0xa0001111, 31) == 1) assert(bit32.extract(42, 1, 3) == 5) @@ -134,6 +136,16 @@ assert(bit32.countrz(0x80000000) == 31) assert(bit32.countrz(0x40000000) == 30) assert(bit32.countrz(0x7fffffff) == 0) +-- testing byteswap +assert(bit32.byteswap(0x10203040) == 0x40302010) +assert(bit32.byteswap(0) == 0) +assert(bit32.byteswap(-1) == 0xffffffff) + +-- bit32.bor(n, 0) must clear top bits +-- we check this obscuring the constant through a global to make sure this gets evaluated fully +high32 = 0x42_1234_5678 +assert(bit32.bor(high32, 0) == 0x1234_5678) + --[[ This test verifies a fix in luauF_replace() where if the 4th parameter was not a number, but the first three are numbers, it will @@ -164,5 +176,6 @@ assert(bit32.btest("1", 3) == true) assert(bit32.countlz("42") == 26) assert(bit32.countrz("42") == 1) assert(bit32.extract("42", 1, 3) == 5) +assert(bit32.byteswap("0xa1b2c3d4") == 0xd4c3b2a1) return('OK') diff --git a/tests/conformance/buffers.lua b/tests/conformance/buffers.lua new file mode 100644 index 000000000..5da2a688c --- /dev/null +++ b/tests/conformance/buffers.lua @@ -0,0 +1,626 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing byte buffer library") + +function call(fn, ...) + local ok, res = pcall(fn, ...) + assert(ok) + return res +end + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +local function simple_byte_reads() + local b = buffer.create(1024) + + assert(buffer.len(b) == 1024) + + assert(buffer.readi8(b, 5) == 0) + buffer.writei8(b, 10, 32) + assert(buffer.readi8(b, 10) == 32) + buffer.writei8(b, 15, 5) + buffer.writei8(b, 14, 4) + buffer.writei8(b, 13, 3) + buffer.writei8(b, 12, 2) + buffer.writei8(b, 11, 1) + assert(buffer.readi8(b, 11) == 1) + assert(buffer.readi8(b, 12) == 2) + assert(buffer.readi8(b, 13) == 3) + assert(buffer.readi8(b, 14) == 4) + assert(buffer.readi8(b, 15) == 5) + + local x = buffer.readi8(b, 14) + buffer.readi8(b, 13) + assert(x == 7) + + buffer.writei8(b, 16, x) +end + +simple_byte_reads() + +local function offset_byte_reads(start: number) + local b = buffer.create(1024) + + buffer.writei8(b, start, 32) + assert(buffer.readi8(b, start) == 32) + buffer.writei8(b, start + 5, 5) + buffer.writei8(b, start + 4, 4) + buffer.writei8(b, start + 3, 3) + buffer.writei8(b, start + 2, 2) + buffer.writei8(b, start + 1, 1) + assert(buffer.readi8(b, start + 1) == 1) + assert(buffer.readi8(b, start + 2) == 2) + assert(buffer.readi8(b, start + 3) == 3) + assert(buffer.readi8(b, start + 4) == 4) + assert(buffer.readi8(b, start + 5) == 5) + + local x = buffer.readi8(b, start + 4) + buffer.readi8(b, start + 3) + assert(x == 7) +end + +offset_byte_reads(5) +offset_byte_reads(30) + +local function simple_float_reinterpret() + local b = buffer.create(1024) + + buffer.writei32(b, 10, 0x3f800000) + local one = buffer.readf32(b, 10) + assert(one == 1.0) + + buffer.writef32(b, 10, 2.75197) + local magic = buffer.readi32(b, 10) + assert(magic == 0x40302047) + + buffer.writef32(b, 10, one) + local magic2 = buffer.readi32(b, 10) + + assert(magic2 == 0x3f800000) +end + +simple_float_reinterpret() + +local function simple_double_reinterpret() + local b = buffer.create(1024) + + buffer.writei32(b, 10, 0x00000000) + buffer.writei32(b, 14, 0x3ff00000) + local one = buffer.readf64(b, 10) + assert(one == 1.0) + + buffer.writef64(b, 10, 1.437576533064206) + local magic1 = buffer.readi32(b, 10) + local magic2 = buffer.readi32(b, 14) + + assert(magic1 == 0x40302010) + assert(magic2 == 0x3ff70050) + + buffer.writef64(b, 10, one) + local magic3 = buffer.readi32(b, 10) + local magic4 = buffer.readi32(b, 14) + + assert(magic3 == 0x00000000) + assert(magic4 == 0x3ff00000) +end + +simple_double_reinterpret() + +local function simple_string_ops() + local b = buffer.create(1024) + + buffer.writestring(b, 15, " world") + buffer.writestring(b, 10, "hello") + buffer.writei8(b, 21, string.byte('!')) + assert(buffer.readstring(b, 10, 12) == "hello world!") + + buffer.writestring(b, 10, "hellommm", 5) + assert(buffer.readstring(b, 10, 12) == "hello world!") + + buffer.writestring(b, 10, string.rep("hellommm", 1000), 5) + assert(buffer.readstring(b, 10, 12) == "hello world!") +end + +simple_string_ops() + +local function simple_copy_ops() + local b1 = buffer.create(1024) + local b2 = buffer.create(1024) + + buffer.writestring(b1, 200, "hello") + buffer.writestring(b1, 100, "world") + + buffer.copy(b1, 300, b1, 100, 5) + + buffer.writei8(b2, 35, string.byte(' ')) + buffer.writei8(b2, 41, string.byte('!')) + + buffer.copy(b2, 30, b1, 200, 5) + buffer.copy(b2, 36, b1, 300, 5) + + assert(buffer.readstring(b2, 30, 12) == "hello world!") + + local b3 = buffer.create(9) + buffer.writestring(b3, 0, "say hello") + buffer.copy(b2, 36, b3, 4) + assert(buffer.readstring(b2, 30, 12) == "hello hello!") + + local b4 = buffer.create(5) + buffer.writestring(b4, 0, "world") + buffer.copy(b2, 36, b4) + assert(buffer.readstring(b2, 30, 12) == "hello world!") + + buffer.writestring(b1, 200, "abcdefgh"); + buffer.copy(b1, 200, b1, 202, 6) + assert(buffer.readstring(b1, 200, 8) == "cdefghgh") + buffer.copy(b1, 202, b1, 200, 6) + assert(buffer.readstring(b1, 200, 8) == "cdcdefgh") +end + +simple_copy_ops() + +-- bounds checking + +local function createchecks() + assert(ecall(function() buffer.create(-1) end) == "invalid argument #1 to 'create' (size)") + assert(ecall(function() buffer.create(-1000000) end) == "invalid argument #1 to 'create' (size)") +end + +createchecks() + +local function boundchecks() + local b = buffer.create(1024) + + assert(call(function() return buffer.readi8(b, 1023) end) == 0) + assert(ecall(function() buffer.readi8(b, 1024) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writei8(b, 1023, 0) end) + assert(ecall(function() buffer.writei8(b, 1024, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -100000, 0) end) == "buffer access out of bounds") + + -- i16 + assert(call(function() return buffer.readi16(b, 1022) end) == 0) + assert(ecall(function() buffer.readi16(b, 1023) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -100000) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x7fffffff) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x7ffffffe) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x7ffffffd) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x80000000) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x0fffffff) end) == "buffer access out of bounds") + + call(function() buffer.writei16(b, 1022, 0) end) + assert(ecall(function() buffer.writei16(b, 1023, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -100000, 0) end) == "buffer access out of bounds") + + -- i32 + assert(call(function() return buffer.readi32(b, 1020) end) == 0) + assert(ecall(function() buffer.readi32(b, 1021) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writei32(b, 1020, 0) end) + assert(ecall(function() buffer.writei32(b, 1021, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -100000, 0) end) == "buffer access out of bounds") + + -- f32 + assert(call(function() return buffer.readf32(b, 1020) end) == 0) + assert(ecall(function() buffer.readf32(b, 1021) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writef32(b, 1020, 0) end) + assert(ecall(function() buffer.writef32(b, 1021, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -100000, 0) end) == "buffer access out of bounds") + + -- f64 + assert(call(function() return buffer.readf64(b, 1016) end) == 0) + assert(ecall(function() buffer.readf64(b, 1017) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writef64(b, 1016, 0) end) + assert(ecall(function() buffer.writef64(b, 1017, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -100000, 0) end) == "buffer access out of bounds") + + -- string + assert(call(function() return buffer.readstring(b, 1016, 8) end) == "\0\0\0\0\0\0\0\0") + assert(ecall(function() buffer.readstring(b, 1017, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -1, -8) end) == "invalid argument #3 to 'readstring' (size)") + assert(ecall(function() buffer.readstring(b, -100000, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -100000, 8) end) == "buffer access out of bounds") + + call(function() buffer.writestring(b, 1016, "abcdefgh") end) + assert(ecall(function() buffer.writestring(b, 1017, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -1, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -100000, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, 100, "abcd", -5) end) == "invalid argument #4 to 'writestring' (count)") + assert(ecall(function() buffer.writestring(b, 100, "abcd", 50) end) == "string length overflow") + + -- copy + assert(ecall(function() buffer.copy(b, 30, b, 200, 1000) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, 200, -5) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, 2000, 10) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, -1, 10) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, -10, 10) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, -100000, 10) end) == "buffer access out of bounds") + + local b2 = buffer.create(1024) + assert(ecall(function() buffer.copy(b, -200, b, 200, 200) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 825, b, 200, 200) end) == "buffer access out of bounds") +end + +boundchecks() + +local function boundchecksnonconst(size, minus1, minusbig, intmax) + local b = buffer.create(size) + + assert(call(function() return buffer.readi8(b, size-1) end) == 0) + assert(ecall(function() buffer.readi8(b, size) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writei8(b, size-1, 0) end) + assert(ecall(function() buffer.writei8(b, size, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, minusbig, 0) end) == "buffer access out of bounds") + + -- i16 + assert(call(function() return buffer.readi16(b, size-2) end) == 0) + assert(ecall(function() buffer.readi16(b, size-1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, minusbig) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax-1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax-2) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax+1) end) == "buffer access out of bounds") + + call(function() buffer.writei16(b, size-2, 0) end) + assert(ecall(function() buffer.writei16(b, size-1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, minusbig, 0) end) == "buffer access out of bounds") + + -- i32 + assert(call(function() return buffer.readi32(b, size-4) end) == 0) + assert(ecall(function() buffer.readi32(b, size-3) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writei32(b, size-4, 0) end) + assert(ecall(function() buffer.writei32(b, size-3, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, minusbig, 0) end) == "buffer access out of bounds") + + -- f32 + assert(call(function() return buffer.readf32(b, size-4) end) == 0) + assert(ecall(function() buffer.readf32(b, size-3) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writef32(b, size-4, 0) end) + assert(ecall(function() buffer.writef32(b, size-3, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, minusbig, 0) end) == "buffer access out of bounds") + + -- f64 + assert(call(function() return buffer.readf64(b, size-8) end) == 0) + assert(ecall(function() buffer.readf64(b, size-7) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writef64(b, size-8, 0) end) + assert(ecall(function() buffer.writef64(b, size-7, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, minusbig, 0) end) == "buffer access out of bounds") + + -- string + assert(call(function() return buffer.readstring(b, size-8, 8) end) == "\0\0\0\0\0\0\0\0") + assert(ecall(function() buffer.readstring(b, size-7, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, minus1, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, minusbig, 8) end) == "buffer access out of bounds") + + call(function() buffer.writestring(b, size-8, "abcdefgh") end) + assert(ecall(function() buffer.writestring(b, size-7, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, minus1, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, minusbig, "abcdefgh") end) == "buffer access out of bounds") +end + +boundchecksnonconst(1024, -1, -100000, 0x7fffffff) + +local function boundcheckssmall() + local b = buffer.create(1) + + assert(call(function() return buffer.readi8(b, 0) end) == 0) + assert(ecall(function() buffer.readi8(b, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -1) end) == "buffer access out of bounds") + + call(function() buffer.writei8(b, 0, 0) end) + assert(ecall(function() buffer.writei8(b, 1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -1, 0) end) == "buffer access out of bounds") + + -- i16 + assert(ecall(function() buffer.readi16(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -2) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -2, 0) end) == "buffer access out of bounds") + + -- i32 + assert(ecall(function() buffer.readi32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -4) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -4, 0) end) == "buffer access out of bounds") + + -- f32 + assert(ecall(function() buffer.readf32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -4) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -4, 0) end) == "buffer access out of bounds") + + -- f64 + assert(ecall(function() buffer.readf64(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -7, 0) end) == "buffer access out of bounds") + + -- string + assert(ecall(function() buffer.readstring(b, 0, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -1, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -8, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, 0, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -1, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -7, "abcdefgh") end) == "buffer access out of bounds") +end + +boundcheckssmall() + +local function boundcheckssmallnonconst(zero, one, minus1, minus2, minus4, minus7, minus8) + local b = buffer.create(1) + + assert(call(function() return buffer.readi8(b, 0) end) == 0) + assert(ecall(function() buffer.readi8(b, one) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, minus1) end) == "buffer access out of bounds") + + call(function() buffer.writei8(b, 0, 0) end) + assert(ecall(function() buffer.writei8(b, one, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, minus1, 0) end) == "buffer access out of bounds") + + -- i16 + assert(ecall(function() buffer.readi16(b, zero) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, minus2) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, zero, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, minus2, 0) end) == "buffer access out of bounds") + + -- i32 + assert(ecall(function() buffer.readi32(b, zero) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, minus4) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, zero, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, minus4, 0) end) == "buffer access out of bounds") + + -- f32 + assert(ecall(function() buffer.readf32(b, zero) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, minus4) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, zero, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, minus4, 0) end) == "buffer access out of bounds") + + -- f64 + assert(ecall(function() buffer.readf64(b, zero) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, minus8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, zero, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, minus7, 0) end) == "buffer access out of bounds") + + -- string + assert(ecall(function() buffer.readstring(b, zero, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, minus1, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, minus8, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, zero, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, minus1, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, minus7, "abcdefgh") end) == "buffer access out of bounds") +end + +boundcheckssmallnonconst(0, 1, -1, -2, -4, -7, -8) + +local function boundchecksempty() + local b = buffer.create(0) -- useless, but probably more generic + + assert(ecall(function() buffer.readi8(b, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, 1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -1, 0) end) == "buffer access out of bounds") + + assert(ecall(function() buffer.readi16(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, 0, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, 0, 8) end) == "buffer access out of bounds") +end + +boundchecksempty() + +local function intuint() + local b = buffer.create(32) + + buffer.writeu32(b, 0, 0xffffffff) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) + + buffer.writei32(b, 0, -1) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) + + buffer.writei16(b, 0, 65535) + buffer.writei16(b, 2, -1) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) + + buffer.writeu16(b, 0, 65535) + buffer.writeu16(b, 2, -1) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) +end + +intuint() + +local function intuinttricky() + local b = buffer.create(32) + + buffer.writeu8(b, 0, 0xffffffff) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == 255) + assert(buffer.readu16(b, 0) == 255) + assert(buffer.readi32(b, 0) == 255) + assert(buffer.readu32(b, 0) == 255) + + buffer.writeu16(b, 0, 0xffffffff) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == 65535) + assert(buffer.readu32(b, 0) == 65535) + + buffer.writei32(b, 8, 0xffffffff) + buffer.writeu32(b, 12, 0xffffffff) + assert(buffer.readstring(b, 8, 4) == buffer.readstring(b, 12, 4)) + + buffer.writei32(b, 8, -2147483648) + buffer.writeu32(b, 12, 0x80000000) + assert(buffer.readstring(b, 8, 4) == buffer.readstring(b, 12, 4)) +end + +intuinttricky() + +local function fromtostring() + local b = buffer.fromstring("1234567890") + assert(buffer.tostring(b) == "1234567890") + + buffer.writestring(b, 4, "xyz") + assert(buffer.tostring(b) == "1234xyz890") + + local b2 = buffer.fromstring("abcd\0ef") + assert(buffer.tostring(b2) == "abcd\0ef") +end + +fromtostring() + +local function fill() + local b = buffer.create(10) + + buffer.fill(b, 0, 0x61) + assert(buffer.tostring(b) == "aaaaaaaaaa") + + buffer.fill(b, 0, 0x62, 5) + assert(buffer.tostring(b) == "bbbbbaaaaa") + + buffer.fill(b, 4, 0x63) + assert(buffer.tostring(b) == "bbbbcccccc") + + buffer.fill(b, 6, 0x64, 3) + assert(buffer.tostring(b) == "bbbbccdddc") + + buffer.fill(b, 2, 0xffffff65, 8) + assert(buffer.tostring(b) == "bbeeeeeeee") + + -- out of bounds + assert(ecall(function() buffer.fill(b, -10, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 11, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 0, 1, 11) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 5, 1, 6) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 5, 1, -1) end) == "buffer access out of bounds") +end + +fill() + +local function misc(t16) + local b = buffer.create(1000) + + assert(select('#', buffer.writei32(b, 10, 40)) == 0) + assert(select('#', buffer.writef32(b, 20, 40.0)) == 0) + + -- some extra operation to place '#t16' into a linear block + t16[1] = 10 + t16[15] = 20 + + buffer.writei32(b, #t16, 10) + assert(buffer.readi32(b, 16) == 10) + + buffer.writeu8(b, 100, 0xff) + buffer.writeu8(b, 110, 0x80) + assert(buffer.readu32(b, 100) == 255) + assert(buffer.readu32(b, 110) == 128) + buffer.writeu16(b, 200, 0xffff) + buffer.writeu16(b, 210, 0x8000) + assert(buffer.readu32(b, 200) == 65535) + assert(buffer.readu32(b, 210) == 32768) +end + +misc(table.create(16, 0)) + +local function testslowcalls() + getfenv() + + simple_byte_reads() + offset_byte_reads(5) + offset_byte_reads(30) + simple_float_reinterpret() + simple_double_reinterpret() + simple_string_ops() + createchecks() + boundchecks() + boundchecksnonconst(1024, -1, -100000, 0x7fffffff) + boundcheckssmall() + boundcheckssmallnonconst(0, 1, -1, -2, -4, -7, -8) + boundchecksempty() + intuint() + intuinttricky() + fromtostring() + fill() + misc(table.create(16, 0)) +end + +testslowcalls() + +return('OK') diff --git a/tests/conformance/calls.lua b/tests/conformance/calls.lua index 621a921aa..6555f93e1 100644 --- a/tests/conformance/calls.lua +++ b/tests/conformance/calls.lua @@ -236,4 +236,12 @@ if not limitedstack then assert(not err and string.find(msg, "error")) end +-- testing deep nested calls with a large thread stack +do + function recurse(n, ...) return n <= 1 and (1 + #{...}) or recurse(n-1, table.unpack(table.create(4000, 1))) + 1 end + + local ok, msg = pcall(recurse, 19000) + assert(not ok and string.find(msg, "not enough memory")) +end + return('OK') diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index 7b0573546..10dc322fb 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -363,46 +363,12 @@ local x = coroutine.create (function () end) assert(not coroutine.resume(x)) + -- overwrite previous position of local `a' assert(not coroutine.resume(x, 1, 1, 1, 1, 1, 1, 1)) assert(_G.f() == 11) assert(_G.f() == 12) - -if not T then - (Message or print)('\a\n >>> testC not active: skipping yield/hook tests <<<\n\a') -else - - local turn - - function fact (t, x) - assert(turn == t) - if x == 0 then return 1 - else return x*fact(t, x-1) - end - end - - local A,B,a,b = 0,0,0,0 - - local x = coroutine.create(function () - T.setyhook("", 2) - A = fact("A", 10) - end) - - local y = coroutine.create(function () - T.setyhook("", 3) - B = fact("B", 11) - end) - - while A==0 or B==0 do - if A==0 then turn = "A"; T.resume(x) end - if B==0 then turn = "B"; T.resume(y) end - end - - assert(B/A == 11) -end - - -- leaving a pending coroutine open _X = coroutine.wrap(function () local a = 10 diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.lua index f133501f1..b8b88478a 100644 --- a/tests/conformance/constructs.lua +++ b/tests/conformance/constructs.lua @@ -237,4 +237,19 @@ repeat i = i+1 until i==c +-- validate continue upvalue close behavior +local function check_connected(writer, reader) + writer(1) + assert(reader() == 1) + return true +end + +repeat + local value = nil + local function write(n) + value = n + end + continue +until check_connected(write, function() return value end) + return 'OK' diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index dc73948b6..8e2103d91 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -18,6 +18,34 @@ assert(os.date(string.rep("%d", 1000), t) == assert(os.date(string.rep("%", 200)) == string.rep("%", 100)) assert(os.date("", -1) == nil) +assert(os.time({ year = 1969, month = 12, day = 31, hour = 23, min = 59, sec = 59}) == nil) -- just before start +assert(os.time({ year = 1970, month = 1, day = 1, hour = 0, min = 0, sec = 0}) == 0) -- start +assert(os.time({ year = 3000, month = 12, day = 31, hour = 23, min = 59, sec = 59}) == 32535215999) -- just before Windows max range +assert(os.time({ year = 1970, month = 1, day = 1, hour = 0, min = 0, sec = -1}) == nil) -- going before using time fields + +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = 0, sec = 0}) == 960595200) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = 0, sec = -86400}) == 960508800) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = 0, sec = 86400}) == 960681600) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = -600, sec = 0}) == 960559200) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = 600, sec = 0}) == 960631200) +assert(os.time({ year = 2000, month = 6, day = 10, hour = -600, min = 0, sec = 0}) == 958435200) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 600, min = 0, sec = 0}) == 962755200) +assert(os.time({ year = 2000, month = 6, day = -100, hour = 0, min = 0, sec = 0}) == 951091200) +assert(os.time({ year = 2000, month = 6, day = 1000, hour = 0, min = 0, sec = 0}) == 1046131200) +assert(os.time({ year = 2000, month = -60, day = 10, hour = 0, min = 0, sec = 0}) == 787017600) +assert(os.time({ year = 2000, month = 60, day = 10, hour = 0, min = 0, sec = 0}) == 1102636800) + +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = 0, sec = -86400000}) == 874195200) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = 0, sec = 86400000}) == 1046995200) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = -600000, sec = 0}) == 924595200) +assert(os.time({ year = 2000, month = 6, day = 10, hour = 0, min = 600000, sec = 0}) == 996595200) +assert(os.time({ year = 2100, month = 6, day = 10, hour = -600000, min = 0, sec = 0}) == 1956268800) +assert(os.time({ year = 2100, month = 6, day = 10, hour = 600000, min = 0, sec = 0}) == 6276268800) +assert(os.time({ year = 2100, month = 6, day = -10000, hour = 0, min = 0, sec = 0}) == 3251404800) +assert(os.time({ year = 2100, month = 6, day = 100000, hour = 0, min = 0, sec = 0}) == 12755404800) +assert(os.time({ year = 2100, month = -600, day = 10, hour = 0, min = 0, sec = 0}) == 2522707200) +assert(os.time({ year = 2100, month = 600, day = 10, hour = 0, min = 0, sec = 0}) == 5678380800) + local function checkDateTable (t) local D = os.date("!*t", t) assert(os.time(D) == t) diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 0c8cc2d87..e044ea454 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -111,4 +111,31 @@ end testlinedefined() +-- don't leave garbage on the other thread +local wrapped1 = coroutine.create(function() + local thread = coroutine.create(function(target) + for i = 1, 100 do pcall(debug.info, target, 0, "llf") end + return 123 + end) + + local success, res = coroutine.resume(thread, coroutine.running()) + assert(success) + assert(res == 123) +end) + +coroutine.resume(wrapped1) + +local wrapped2 = coroutine.create(function() + local thread = coroutine.create(function(target) + for i = 1, 100 do pcall(debug.info, target, 0, "ff") end + return 123 + end) + + local success, res = coroutine.resume(thread, coroutine.running()) + assert(success) + assert(res == 123) +end) + +coroutine.resume(wrapped2) + return 'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index c773013b7..0980703a1 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -69,4 +69,34 @@ end breakpointSetFromMetamethod() +-- break inside function with non-monotonic line info +local function cond(a) + if a then + print('a') + else + print('not a') + end +end + +breakpoint(77) + +pcall(cond, nil) -- prevent inlining + +local function continueLocals() + repeat + local x = tostring(game) + do continue end + local a1, a2, a3, a4, a5, a6 + until pcall( + function() + print("1") + print("2") + end, nil + ) or true +end + +breakpoint(93) + +pcall(continueLocals, nil) -- prevent inlining + return 'OK' diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 94314c3fb..4ee801f02 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -107,6 +107,7 @@ t.__add = f("add") t.__sub = f("sub") t.__mul = f("mul") t.__div = f("div") +t.__idiv = f("idiv") t.__mod = f("mod") t.__unm = f("unm") t.__pow = f("pow") @@ -128,6 +129,8 @@ assert(a*a == a) assert(cap[0] == "mul" and cap[1] == a and cap[2] == a and cap[3]==nil) assert(a/0 == a) assert(cap[0] == "div" and cap[1] == a and cap[2] == 0 and cap[3]==nil) +assert(a//0 == a) +assert(cap[0] == "idiv" and cap[1] == a and cap[2] == 0 and cap[3]==nil) assert(a%2 == a) assert(cap[0] == "mod" and cap[1] == a and cap[2] == 2 and cap[3]==nil) assert(-a == a) diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua index d4b7c80a4..86712be57 100644 --- a/tests/conformance/interrupt.lua +++ b/tests/conformance/interrupt.lua @@ -1,20 +1,111 @@ -- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details print("testing interrupts") -function foo() - for i=1,10 do end - return +-- this function will be called by C code with a special interrupt handler that validates hit locations +function test() + function foo() + for i=1,10 do end + return + end + + foo() + + function bar() + local i = 0 + while i < 10 do + i += i + 1 + end + end + + bar() + + function baz() + end + + baz() +end + +-- these functions will be called by C code with a special interrupt handler that terminates after a few invocations +function infloop1() + while true do end +end + +function infloop2() + while true do continue end +end + +function infloop3() + repeat until false +end + +function infloop4() + repeat continue until false +end + +function infloop5() + for i=0,0,0 do end +end + +function infloop6() + for i=0,0,0 do continue end end -foo() +function infloop7() + for i=1,math.huge do end +end + +function infloop8() + for i=1,math.huge do continue end +end + +function infloop9() + -- technically not a loop, but an exponentially recursive function + local function boom() + boom() + boom() + end + boom() +end -function bar() - local i = 0 - while i < 10 do - i += i + 1 - end +function infloop10() + for l0=4096,0,0 do + repeat + continue + until function() end + end end -bar() +local haystack = string.rep("x", 100) +local pattern = string.rep("x?", 100) .. string.rep("x", 100) + +function strhang1() + string.find(haystack, pattern) +end + +function strhang2() + string.match(haystack, pattern) +end + +function strhang3() + string.gsub(haystack, pattern, "%0") +end + +function strhang4() + for k, v in string.gmatch(haystack, pattern) do + end +end + +function strhang5() + local x = string.rep('x', 1000) + string.match(x, string.rep('x.*', 100) .. 'y') +end + +function strhangpcall() + for i = 1,100 do + local status, msg = pcall(string.find, haystack, pattern) + assert(status == false) + assert(msg == "timeout") + end +end return "OK" diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index e2f68e654..97c444624 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -61,6 +61,7 @@ assert(1111111111111111-1111111111111110== 1000.00e-03) -- 1234567890123456 assert(1.1 == '1.'+'.1') assert('1111111111111111'-'1111111111111110' == tonumber" +0.001e+3 \n\t") +assert(10000000000000001 == 10000000000000000) function eq (a,b,limit) if not limit then limit = 10E-10 end @@ -82,6 +83,7 @@ assert(not(1>1) and not(1>2) and (2>1)) assert(not('a'>'a') and not('a'>'b') and ('b'>'a')) assert((1>=1) and not(1>=2) and (2>=1)) assert(('a'>='a') and not('a'>='b') and ('b'>='a')) +assert((unk and unk > 0) == nil) -- validate precedence between and and > -- testing mod operator assert(-4%3 == 2) @@ -188,6 +190,26 @@ do -- testing NaN assert(a[NaN] == nil) end +-- extra NaN tests, hidden in a function +do + function neq(a) return a ~= a end + function eq(a) return a == a end + function lt(a) return a < a end + function le(a) return a <= a end + function gt(a) return a > a end + function ge(a) return a >= a end + + local NaN -- to avoid constant folding + NaN = 10e500 - 10e400 + + assert(neq(NaN)) + assert(not eq(NaN)) + assert(not lt(NaN)) + assert(not le(NaN)) + assert(not gt(NaN)) + assert(not ge(NaN)) +end + -- require "checktable" -- stat(a) @@ -235,23 +257,65 @@ assert(flag); assert(select(2, pcall(math.random, 1, 2, 3)):match("wrong number of arguments")) +-- argument count +function nothing() end + +assert(pcall(math.abs) == false) +assert(pcall(function() return math.abs(nothing()) end) == false) + -- min/max assert(math.min(1) == 1) assert(math.min(1, 2) == 1) assert(math.min(1, 2, -1) == -1) assert(math.min(1, -1, 2) == -1) +assert(math.min(1, -1, 2, -2) == -2) assert(math.max(1) == 1) assert(math.max(1, 2) == 2) assert(math.max(1, 2, -1) == 2) assert(math.max(1, -1, 2) == 2) +assert(math.max(1, -1, 2, -2) == 2) + +local ma, mb, mc, md + +assert(pcall(function() + ma = 1 + mb = -1 + mc = 2 + md = -2 +end) == true) + +-- min/max without contant-folding +assert(math.min(ma) == 1) +assert(math.min(ma, mc) == 1) +assert(math.min(ma, mc, mb) == -1) +assert(math.min(ma, mb, mc) == -1) +assert(math.min(ma, mb, mc, md) == -2) +assert(math.max(ma) == 1) +assert(math.max(ma, mc) == 2) +assert(math.max(ma, mc, mb) == 2) +assert(math.max(ma, mb, mc) == 2) +assert(math.max(ma, mb, mc, md) == 2) + +local inf = math.huge * 2 +local nan = 0 / 0 + +assert(math.min(nan, 2) ~= math.min(nan, 2)) +assert(math.min(1, nan) == 1) +assert(math.max(nan, 2) ~= math.max(nan, 2)) +assert(math.max(1, nan) == 1) + +local function noinline(x, ...) local s, r = pcall(function(y) return y end, x) return r end -- noise assert(math.noise(0.5) == 0) assert(math.noise(0.5, 0.5) == -0.25) assert(math.noise(0.5, 0.5, -0.5) == 0.125) +assert(math.noise(455.7204209769105, 340.80410508750134, 121.80087666537628) == 0.5010709762573242) -local inf = math.huge * 2 -local nan = 0 / 0 +assert(math.noise(noinline(0.5)) == 0) +assert(math.noise(noinline(0.5), 0.5) == -0.25) +assert(math.noise(noinline(0.5), 0.5, -0.5) == 0.125) +assert(math.noise(noinline(455.7204209769105), 340.80410508750134, 121.80087666537628) == 0.5010709762573242) -- sign assert(math.sign(0) == 0) @@ -261,10 +325,12 @@ assert(math.sign(inf) == 1) assert(math.sign(-inf) == -1) assert(math.sign(nan) == 0) -assert(math.min(nan, 2) ~= math.min(nan, 2)) -assert(math.min(1, nan) == 1) -assert(math.max(nan, 2) ~= math.max(nan, 2)) -assert(math.max(1, nan) == 1) +assert(math.sign(noinline(0)) == 0) +assert(math.sign(noinline(42)) == 1) +assert(math.sign(noinline(-42)) == -1) +assert(math.sign(noinline(inf)) == 1) +assert(math.sign(noinline(-inf)) == -1) +assert(math.sign(noinline(nan)) == 0) -- clamp assert(math.clamp(-1, 0, 1) == 0) @@ -272,6 +338,11 @@ assert(math.clamp(0.5, 0, 1) == 0.5) assert(math.clamp(2, 0, 1) == 1) assert(math.clamp(4, 0, 0) == 0) +assert(math.clamp(noinline(-1), 0, 1) == 0) +assert(math.clamp(noinline(0.5), 0, 1) == 0.5) +assert(math.clamp(noinline(2), 0, 1) == 1) +assert(math.clamp(noinline(4), 0, 0) == 0) + -- round assert(math.round(0) == 0) assert(math.round(0.4) == 0) @@ -281,6 +352,19 @@ assert(math.round(-0.4) == 0) assert(math.round(-0.5) == -1) assert(math.round(-3.5) == -4) assert(math.round(math.huge) == math.huge) +assert(math.round(0.49999999999999994) == 0) +assert(math.round(-0.49999999999999994) == 0) + +assert(math.round(noinline(0)) == 0) +assert(math.round(noinline(0.4)) == 0) +assert(math.round(noinline(0.5)) == 1) +assert(math.round(noinline(3.5)) == 4) +assert(math.round(noinline(-0.4)) == 0) +assert(math.round(noinline(-0.5)) == -1) +assert(math.round(noinline(-3.5)) == -4) +assert(math.round(noinline(math.huge)) == math.huge) +assert(math.round(noinline(0.49999999999999994)) == 0) +assert(math.round(noinline(-0.49999999999999994)) == 0) -- fmod assert(math.fmod(3, 2) == 1) @@ -288,13 +372,55 @@ assert(math.fmod(-3, 2) == -1) assert(math.fmod(3, -2) == 1) assert(math.fmod(-3, -2) == -1) +assert(math.fmod(noinline(3), 2) == 1) +assert(math.fmod(noinline(-3), 2) == -1) +assert(math.fmod(noinline(3), -2) == 1) +assert(math.fmod(noinline(-3), -2) == -1) + -- pow assert(math.pow(2, 0) == 1) assert(math.pow(2, 2) == 4) assert(math.pow(4, 0.5) == 2) assert(math.pow(-2, 2) == 4) + +assert(math.pow(noinline(2), 0) == 1) +assert(math.pow(noinline(2), 2) == 4) +assert(math.pow(noinline(4), 0.5) == 2) +assert(math.pow(noinline(-2), 2) == 4) + +-- map +assert(math.map(0, -1, 1, 0, 2) == 1) +assert(math.map(1, 1, 4, 0, 2) == 0) +assert(math.map(2.5, 1, 4, 0, 2) == 1) +assert(math.map(4, 1, 4, 0, 2) == 2) +assert(math.map(1, 1, 4, 2, 0) == 2) +assert(math.map(2.5, 1, 4, 2, 0) == 1) +assert(math.map(4, 1, 4, 2, 0) == 0) +assert(math.map(1, 4, 1, 2, 0) == 0) +assert(math.map(2.5, 4, 1, 2, 0) == 1) +assert(math.map(4, 4, 1, 2, 0) == 2) +assert(math.map(-8, 0, 4, 0, 2) == -4) +assert(math.map(16, 0, 4, 0, 2) == 8) + assert(tostring(math.pow(-2, 0.5)) == "nan") +-- test that fastcalls return correct number of results +assert(select('#', math.floor(1.4)) == 1) +assert(select('#', math.ceil(1.6)) == 1) +assert(select('#', math.sqrt(9)) == 1) +assert(select('#', math.deg(9)) == 1) +assert(select('#', math.rad(9)) == 1) +assert(select('#', math.sin(1.5)) == 1) +assert(select('#', math.atan2(1.5, 0.5)) == 1) +assert(select('#', math.modf(1.5)) == 2) +assert(select('#', math.frexp(1.5)) == 2) + +-- test that fastcalls that return variadic results return them correctly in variadic position +assert(select(1, math.modf(1.5)) == 1) +assert(select(2, math.modf(1.5)) == 0.5) +assert(select(1, math.frexp(1.5)) == 0.75) +assert(select(2, math.frexp(1.5)) == 1) + -- most of the tests above go through fastcall path -- to make sure the basic implementations are also correct we test these functions with string->number coercions assert(math.abs("-4") == 4) @@ -316,7 +442,7 @@ assert(math.log10("10") == 1) assert(math.log("0") == -inf) assert(math.log("8", 2) == 3) assert(math.log("10", 10) == 1) -assert(math.log("9", 3) == 2) +assert(math.log("16", 4) == 2) assert(math.max("1", 2) == 2) assert(math.max(2, "1") == 2) assert(math.max(1, 2, "3") == 3) @@ -339,9 +465,4 @@ assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) --- test that fastcalls return correct number of results -assert(select('#', math.floor(1.4)) == 1) -assert(select('#', math.ceil(1.6)) == 1) -assert(select('#', math.sqrt(9)) == 1) - return('OK') diff --git a/tests/conformance/move.lua b/tests/conformance/move.lua index 27a96ffc8..bb613157f 100644 --- a/tests/conformance/move.lua +++ b/tests/conformance/move.lua @@ -64,6 +64,28 @@ do a = table.move({[minI] = 100}, minI, minI, maxI) eqT(a, {[minI] = 100, [maxI] = 100}) + + -- hash part skips array slice + a = {} + table.move({[-1] = 1, [0] = 2, [1] = 3, [2] = 4}, -1, 3, 1, a) + eqT(a, {[1] = 1, [2] = 2, [3] = 3, [4] = 4}) + + a = {} + table.move({[-1] = 1, [0] = 2, [1] = 3, [2] = 4, [10] = 5, [100] = 6, [1000] = 7}, -1, 3, 1, a) + eqT(a, {[1] = 1, [2] = 2, [3] = 3, [4] = 4}) + + -- moving ranges containing nil values into tables with values + a = {1, 2, 3, 4, 5} + table.move({10}, 1, 3, 2, a) + eqT(a, {1, 10, nil, nil, 5}) + + a = {1, 2, 3, 4, 5} + table.move({10}, -1, 1, 2, a) + eqT(a, {1, nil, nil, 10, 5}) + + a = {[-1000] = 1, [1000] = 2, [1] = 3} + table.move({10}, -1000, 1000, -1000, a) + eqT(a, {10}) end checkerror("too many", table.move, {}, 0, maxI, 1) diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua new file mode 100644 index 000000000..038450137 --- /dev/null +++ b/tests/conformance/native.lua @@ -0,0 +1,516 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing native code generation") + +assert((function(x, y) + -- trigger a linear sequence + local t1 = x + 2 + local t2 = x - 7 + + local a = x * 10 + local b = a + 1 + a = y -- update 'a' version + local t = {} -- call to allocate table forces a spill + local c = x * 10 + return c, b, t, t1, t2 +end)(5, 10) == 50) + +assert((function(x) + local oops -- split to prevent inlining + function oops() + end + + -- x is checked to be a number here; we can not execute a reentry from oops() because optimizer assumes this holds until return + local y = math.abs(x) + oops() + return y * x +end)("42") == 1764) + +local function fuzzfail1(...) + repeat + _ = nil + until not {} + for _ in ... do + for l0=_,_ do + end + return + end +end + +local function fuzzfail2() + local _ + do + repeat + _ = typeof(_),{_=_,} + _ = _(_._) + until _ + end +end + +assert(pcall(fuzzfail2) == false) + +local function fuzzfail3() + function _(...) + _({_,_,true,},{...,},_,not _) + end + _() +end + +assert(pcall(fuzzfail3) == false) + +local function fuzzfail4() + local _ = setmetatable({},setmetatable({_=_,},_)) + return _(_:_()) +end + +assert(pcall(fuzzfail4) == false) + +local function fuzzfail5() + local _ = bit32.band + _(_(_,0),_) + _(_,_) +end + +assert(pcall(fuzzfail5) == false) + +local function fuzzfail6(_) + return bit32.extract(_,671088640,_) +end + +assert(pcall(fuzzfail6, 1) == false) + +local function fuzzfail7(_) + return bit32.extract(_,_,671088640) +end + +assert(pcall(fuzzfail7, 1) == false) + +local function fuzzfail8(...) + local _ = _,_ + _.n0,_,_,_,_,_,_,_,_._,_,_,_[...],_,_,_ = nil + _,n0,_,_,_,_,_,_,_,_,l0,_,_,_,_ = nil + function _() + end + _._,_,_,_,_,_,_,_,_,_,_[...],_,n0[l0],_ = nil + _[...],_,_,_,_,_,_,_,_()[_],_,_,_,_,_ = _(),... +end + +assert(pcall(fuzzfail8) == false) + +local function fuzzfail9() + local _ = bit32.bor + local x = _(_(_,_),_(_,_),_(-16834560,_),_(_(- _,-2130706432)),- _),_(_(_,_),_(-16834560,-2130706432)) +end + +assert(pcall(fuzzfail9) == false) + +local function fuzzfail10() + local _ + _ = false,if _ then _ else _ + _ = not _ + l0,_[l0] = not _ +end + +assert(pcall(fuzzfail10) == false) + +local function fuzzfail11(x, ...) + return bit32.arshift(bit32.bnot(x),(...)) +end + +assert(fuzzfail11(0xffff0000, 8) == 0xff) + +local function fuzzfail12() + _,_,_,_,_,_,_,_ = not _, not _, not _, not _, not _, not _, not _, not _ +end + +assert(pcall(fuzzfail12) == true) + +local function fuzzfail13() + _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ = not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _ +end + +assert(pcall(fuzzfail13) == true) + +local function fuzzfail14() + for l0=771751936,_ do + for l0=771751936,0 do + while 538970624 do + end + end + end +end + +assert(pcall(fuzzfail14) == false) + +local function fuzzfail15() + local a + if a then + repeat until a + else + local b = `{a}` + a = nil + end +end + +assert(pcall(fuzzfail15) == true) + +local function fuzzfail16() + _ = {[{[2]=77,_=_,[2]=_,}]=not _,} + _ = {77,[2]=11008,[2]=_,[0]=_,} +end + +assert(pcall(fuzzfail16) == true) + +local function fuzzfail17() + return bit32.extract(1293942816,1293942816) +end + +assert(pcall(fuzzfail17) == false) + +local function fuzzfail18() + return bit32.extract(7890276,0) +end + +assert(pcall(fuzzfail18) == true) +assert(fuzzfail18() == 0) + +local function fuzzfail19() + local _ = 2 + _ += _ + _ = _,_ >= _,{_ >= _,_ >= _,_(),} + + local _ = 2 + do + _ = assert({n0=_,_,n0=_,}),{_={_[_()],},_,} + end +end + +assert(pcall(fuzzfail19) == false) + +local function fuzzfail20() + assert(true) + assert(false,(_),true) + _ = nil +end + +assert(pcall(fuzzfail20) == false) + +local function fuzzfail21(...) + local _ = assert,_ + if _ then else return _ / _ end + _(_) + _(_,_) + assert(...,_) + _((not _),_) + _(true,_ / _) + _(_,_()) + return _ +end + +assert(pcall(fuzzfail21) == false) + +local function fuzzfail22(...) + local _ = {false,},true,...,l0 + while _ do + _ = true,{unpack(0,_),},l0 + _.n126 = nil + _ = {not _,_=not _,n0=_,_,n0=not _,},_ < _ + return _ > _ + end + return `""` +end + +assert(pcall(fuzzfail22) == false) + +local function fuzzfail23(...) + local _ = {false,},_,...,l0 + while _ do + _ = true,{unpack(_),},l0 + _ = {{[_]=nil,_=not _,_,true,_=nil,},not _,not _,_,bxor=- _,} + do end + break + end + do end + local _ = _,true + do end + local _ = _,true +end + +assert(pcall(fuzzfail23) == false) + +local function arraySizeInv1() + local t = {1, 2, nil, nil, nil, nil, nil, nil, nil, true} + + table.insert(t, 3) + + return t[10] +end + +assert(arraySizeInv1() == true) + +local function arraySizeInv2() + local t = {1, 2, nil, nil, nil, nil, nil, nil, nil, true} + + local u = {a = t} + table.insert(u.a, 3) -- aliased modifiction of 't' register through other value + + return t[10] +end + +assert(arraySizeInv2() == true) + +local function nilInvalidatesSlot() + local function tabs() + local t = { x=1, y=2, z=3 } + setmetatable(t, { __index = function(t, k) return 42 end }) + return t, t + end + + local t1, t2 = tabs() + + for i=1,2 do + local a = t1.x + t2.x = nil + local b = t1.x + t2.x = 1 + assert(a == 1 and b == 42) + end +end + +nilInvalidatesSlot() + +local function arraySizeOpt1(a) + a[1] += 2 + a[1] *= 3 + + table.insert(a, 3) + table.insert(a, 4) + table.insert(a, 5) + table.insert(a, 6) + + a[1] += 4 + a[1] *= 5 + + return a[1] + a[5] +end + +assert(arraySizeOpt1({1}) == 71) + +local function arraySizeOpt2(a, i) + a[i] += 2 + a[i] *= 3 + + table.insert(a, 3) + table.insert(a, 4) + table.insert(a, 5) + table.insert(a, 6) + + a[i] += 4 + a[i] *= 5 + + return a[i] + a[5] +end + +assert(arraySizeOpt2({1}, 1) == 71) + +function deadLoopBody(n) + local r = 0 + if n and false then + for i = 1, n do + r += 1 + end + end + return r +end + +assert(deadLoopBody(5) == 0) + +function arrayIndexingSpecialNumbers1(a, b, c) + local arr = table.create(100000) + arr[a] = 9 + arr[b-1] = 80 + arr[b] = 700 + arr[b+1] = 6000 + arr[c-1] = 50000 + arr[c] = 400000 + arr[c+1] = 3000000 + + return arr[1] + arr[255] + arr[256] + arr[257] + arr[65535] + arr[65536] + arr[65537] +end + +assert(arrayIndexingSpecialNumbers1(1, 256, 65536) == 3456789) + +function loopIteratorProtocol(a, t) + local sum = 0 + + do + local a, b, c, d, e, f, g = {}, {}, {}, {}, {}, {}, {} + end + + for k, v in ipairs(t) do + if k == 10 then sum += math.abs('-8') end + + sum += k + end + + return sum +end + +assert(loopIteratorProtocol(0, table.create(100, 5)) == 5058) + +function valueTrackingIssue1() + local b = buffer.create(1) + buffer.writeu8(b, 0, 0) + local v1 + + local function closure() + assert(type(b) == "buffer") -- b is the first upvalue + v1 = nil -- v1 is the second upvalue + + -- prevent inlining + for i = 1, 100 do print(`{b} is {b}`) end + end + + closure() +end + +valueTrackingIssue1() + +local function vec3compsum(a: vector) + return a.X + a.Y + a.Z +end + +assert(vec3compsum(vector(1, 2, 4)) == 7.0) + +local function vec3add(a: vector, b: vector) return a + b end +local function vec3sub(a: vector, b: vector) return a - b end +local function vec3mul(a: vector, b: vector) return a * b end +local function vec3div(a: vector, b: vector) return a / b end +local function vec3neg(a: vector) return -a end + +assert(vec3add(vector(10, 20, 40), vector(1, 0, 2)) == vector(11, 20, 42)) +assert(vec3sub(vector(10, 20, 40), vector(1, 0, 2)) == vector(9, 20, 38)) +assert(vec3mul(vector(10, 20, 40), vector(1, 0, 2)) == vector(10, 0, 80)) +assert(vec3div(vector(10, 20, 40), vector(1, 0, 2)) == vector(10, math.huge, 20)) +assert(vec3neg(vector(10, 20, 40)) == vector(-10, -20, -40)) + +local function vec3mulnum(a: vector, b: number) return a * b end +local function vec3mulconst(a: vector) return a * 4 end + +assert(vec3mulnum(vector(10, 20, 40), 4) == vector(40, 80, 160)) +assert(vec3mulconst(vector(10, 20, 40), 4) == vector(40, 80, 160)) + +local function bufferbounds(zero) + local b1 = buffer.create(1) + local b2 = buffer.create(2) + local b4 = buffer.create(4) + local b8 = buffer.create(8) + local b10 = buffer.create(10) + + -- only one valid position and size for a 1 byte buffer + buffer.writei8(b1, zero + 0, buffer.readi8(b1, zero + 0)) + buffer.writeu8(b1, zero + 0, buffer.readu8(b1, zero + 0)) + + -- 2 byte buffer + buffer.writei8(b2, zero + 0, buffer.readi8(b2, zero + 0)) + buffer.writeu8(b2, zero + 0, buffer.readu8(b2, zero + 0)) + buffer.writei8(b2, zero + 1, buffer.readi8(b2, zero + 1)) + buffer.writeu8(b2, zero + 1, buffer.readu8(b2, zero + 1)) + buffer.writei16(b2, zero + 0, buffer.readi16(b2, zero + 0)) + buffer.writeu16(b2, zero + 0, buffer.readu16(b2, zero + 0)) + + -- 4 byte buffer + buffer.writei8(b4, zero + 0, buffer.readi8(b4, zero + 0)) + buffer.writeu8(b4, zero + 0, buffer.readu8(b4, zero + 0)) + buffer.writei8(b4, zero + 3, buffer.readi8(b4, zero + 3)) + buffer.writeu8(b4, zero + 3, buffer.readu8(b4, zero + 3)) + buffer.writei16(b4, zero + 0, buffer.readi16(b4, zero + 0)) + buffer.writeu16(b4, zero + 0, buffer.readu16(b4, zero + 0)) + buffer.writei16(b4, zero + 2, buffer.readi16(b4, zero + 2)) + buffer.writeu16(b4, zero + 2, buffer.readu16(b4, zero + 2)) + buffer.writei32(b4, zero + 0, buffer.readi32(b4, zero + 0)) + buffer.writeu32(b4, zero + 0, buffer.readu32(b4, zero + 0)) + buffer.writef32(b4, zero + 0, buffer.readf32(b4, zero + 0)) + + -- 8 byte buffer + buffer.writei8(b8, zero + 0, buffer.readi8(b8, zero + 0)) + buffer.writeu8(b8, zero + 0, buffer.readu8(b8, zero + 0)) + buffer.writei8(b8, zero + 7, buffer.readi8(b8, zero + 7)) + buffer.writeu8(b8, zero + 7, buffer.readu8(b8, zero + 7)) + buffer.writei16(b8, zero + 0, buffer.readi16(b8, zero + 0)) + buffer.writeu16(b8, zero + 0, buffer.readu16(b8, zero + 0)) + buffer.writei16(b8, zero + 6, buffer.readi16(b8, zero + 6)) + buffer.writeu16(b8, zero + 6, buffer.readu16(b8, zero + 6)) + buffer.writei32(b8, zero + 0, buffer.readi32(b8, zero + 0)) + buffer.writeu32(b8, zero + 0, buffer.readu32(b8, zero + 0)) + buffer.writef32(b8, zero + 0, buffer.readf32(b8, zero + 0)) + buffer.writei32(b8, zero + 4, buffer.readi32(b8, zero + 4)) + buffer.writeu32(b8, zero + 4, buffer.readu32(b8, zero + 4)) + buffer.writef32(b8, zero + 4, buffer.readf32(b8, zero + 4)) + buffer.writef64(b8, zero + 0, buffer.readf64(b8, zero + 0)) + + -- 'any' size buffer + buffer.writei8(b10, zero + 0, buffer.readi8(b10, zero + 0)) + buffer.writeu8(b10, zero + 0, buffer.readu8(b10, zero + 0)) + buffer.writei8(b10, zero + 9, buffer.readi8(b10, zero + 9)) + buffer.writeu8(b10, zero + 9, buffer.readu8(b10, zero + 9)) + buffer.writei16(b10, zero + 0, buffer.readi16(b10, zero + 0)) + buffer.writeu16(b10, zero + 0, buffer.readu16(b10, zero + 0)) + buffer.writei16(b10, zero + 8, buffer.readi16(b10, zero + 8)) + buffer.writeu16(b10, zero + 8, buffer.readu16(b10, zero + 8)) + buffer.writei32(b10, zero + 0, buffer.readi32(b10, zero + 0)) + buffer.writeu32(b10, zero + 0, buffer.readu32(b10, zero + 0)) + buffer.writef32(b10, zero + 0, buffer.readf32(b10, zero + 0)) + buffer.writei32(b10, zero + 6, buffer.readi32(b10, zero + 6)) + buffer.writeu32(b10, zero + 6, buffer.readu32(b10, zero + 6)) + buffer.writef32(b10, zero + 6, buffer.readf32(b10, zero + 6)) + buffer.writef64(b10, zero + 0, buffer.readf64(b10, zero + 0)) + buffer.writef64(b10, zero + 2, buffer.readf64(b10, zero + 2)) + + assert(is_native()) +end + +bufferbounds(0) + +function deadStoreChecks1() + local a = 1.0 + local b = 0.0 + + local function update() + b += a + for i = 1, 100 do print(`{b} is {b}`) end + end + + update() + a = 10 + update() + a = 100 + update() + + return b +end + +assert(deadStoreChecks1() == 111) + +local function extramath1(a) + return type(math.sign(a)) +end + +assert(extramath1(2) == "number") +assert(extramath1("2") == "number") + +local function extramath2(a) + return type(math.modf(a)) +end + +assert(extramath2(2) == "number") +assert(extramath2("2") == "number") + +local function extramath3(a) + local b, c = math.modf(a) + return type(c) +end + +assert(extramath3(2) == "number") +assert(extramath3("2") == "number") + +return('OK') diff --git a/tests/conformance/native_types.lua b/tests/conformance/native_types.lua new file mode 100644 index 000000000..639ce80b8 --- /dev/null +++ b/tests/conformance/native_types.lua @@ -0,0 +1,91 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing native code generation with type annotations") + +function call(fn, ...) + local ok, res = pcall(fn, ...) + assert(ok) + return res +end + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub(err:find(": ") + 2, #err) +end + +local function add(a: number, b: number, native: boolean) + assert(native == is_native()) + return a + b +end + +call(add, 1, 3, true) +ecall(add, nil, 2, false) + +local function isnil(x: nil) + assert(is_native()) + return not x +end + +call(isnil, nil) +ecall(isnil, 2) + +local function isany(x: any, y: number) + assert(is_native()) + return not not x +end + +call(isany, nil, 1) +call(isany, 2, 1) +call(isany, {}, 1) + +local function optstring(s: string?) + assert(is_native()) + return if s then s..'2' else '3' +end + +assert(call(optstring, nil) == '3') +assert(call(optstring, 'two: ') == 'two: 2') +ecall(optstring, 2) + +local function checktable(a: {x:number}) assert(is_native()) end +local function checkfunction(a: () -> ()) assert(is_native()) end +local function checkthread(a: thread) assert(is_native()) end +local function checkuserdata(a: userdata) assert(is_native()) end +local function checkvector(a: vector) assert(is_native()) end +local function checkbuffer(a: buffer) assert(is_native()) end +local function checkoptbuffer(a: buffer?) assert(is_native()) end + +call(checktable, {}) +ecall(checktable, 2) + +call(checkfunction, function() end) +ecall(checkfunction, 2) + +call(checkthread, coroutine.create(function() end)) +ecall(checkthread, 2) + +call(checkuserdata, newproxy()) +ecall(checkuserdata, 2) + +call(checkvector, vector(1, 2, 3)) +ecall(checkvector, 2) + +call(checkbuffer, buffer.create(10)) +ecall(checkbuffer, 2) +call(checkoptbuffer, buffer.create(10)) +call(checkoptbuffer, nil) +ecall(checkoptbuffer, 2) + +local function mutation_causes_bad_exit(a: number, count: number, sum: number) + repeat + a = 's' + sum += count + pcall(function() end) + count -= 1 + until count == 0 + return sum +end + +assert(call(mutation_causes_bad_exit, 5, 10, 0) == 55) + +return('OK') diff --git a/tests/conformance/native_userdata.lua b/tests/conformance/native_userdata.lua new file mode 100644 index 000000000..b1b2a1033 --- /dev/null +++ b/tests/conformance/native_userdata.lua @@ -0,0 +1,42 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing userdata') + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +local function realmad(a: vec2, b: vec2, c: vec2): vec2 + return -c + a * b; +end + +local function dm(s: vec2, t: vec2, u: vec2) + local x = s:Dot(t) + assert(x == 13) + + local t = u:Min(s) + assert(t.X == 5) + assert(t.Y == 4) +end + +local s: vec2 = vec2(5, 4) +local t: vec2 = vec2(1, 2) +local u: vec2 = vec2(10, 20) + +local x: vec2 = realmad(s, t, u) + +assert(x.X == -5) +assert(x.Y == -12) + +dm(s, t, u) + +local function mu(v: vec2) + assert(v.Magnitude == 2) + assert(v.Unit.X == 0) + assert(v.Unit.Y == 1) +end + +mu(vec2(0, 2)) + +return 'OK' diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index 969209fc4..265c397bd 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -63,7 +63,7 @@ if not limitedstack then function stackover() return pcall(stackover) end local res = {pcall(stackover)} - assert(#res == 200) + assert(#res == 200) -- stack limit (MAXCCALLS) is 200 end -- yield tests @@ -161,4 +161,111 @@ checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse(" -- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) +-- simulate OOM and make sure we can catch it with pcall or xpcall +checkresults({ false, "not enough memory" }, pcall(function() table.create(1e6) end)) +checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) return e end)) +checkresults({ false, "oops" }, xpcall(function() table.create(1e6) end, function(e) return "oops" end)) +checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end)) +checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end)) + +co = coroutine.create(function() table.create(1e6) end) +coroutine.resume(co) +checkresults({ false, "not enough memory" }, coroutine.close(co)) + +-- ensure that pcall and xpcall close upvalues when handling error +local upclo +local function uptest(y) + local a, b, c, d = 1, 2, 3, 4 + upclo = function() + local t = table.pack("a", "b", "d", "e", "f", "g", "h", "i") + assert(a == 1) + assert(b == 2) + assert(c == 3) + assert(d == 4) + a, b, c, d = table.unpack(t) + return "ok" + end + if y then + y() + end + error("oops") +end + +-- ensure that pcall and xpcall close upvalues when handling error (immediate) +do + upclo = nil + pcall(uptest, nil) + assert(upclo() == "ok") +end + +do + upclo = nil + xpcall(uptest, function(err) end, nil) + assert(upclo() == "ok") +end + +do + upclo = nil + xpcall(uptest, function(err) return "e", "e", "e", "e", "e", "e", "e", "e" end, nil) + assert(upclo() == "ok") +end + +-- ensure that pcall and xpcall close upvalues when handling error (deferred) +do + upclo = nil + checkresults({"yield", "return", "ok"}, colog(function() + pcall(uptest, coroutine.yield) + return upclo() + end)) +end + +do + upclo = nil + checkresults({"yield", "return", "ok"}, colog(function() + xpcall(uptest, function(err) end, coroutine.yield) + return upclo() + end)) +end + +do + upclo = nil + checkresults({"yield", "return", "ok"}, colog(function() + xpcall(uptest, function(err) return "e", "e", "e", "e", "e", "e", "e", "e" end, coroutine.yield) + return upclo() + end)) +end + +-- also cover an edge case where xpcall's error handler may want access to the upvalues (immediate + deferred) +do + local ur + upclo = nil + xpcall(uptest, function(err) + ur = upclo() + end, nil) + assert(ur == "ok") +end + +do + local ur + upclo = nil + checkresults({"yield", "return"}, colog(function() + xpcall(uptest, function(err) + ur = upclo() + end, coroutine.yield) + end)) + assert(ur == "ok") +end + +-- test stack overflow from a thread that had an error, recovered, and subsequently called coroutine.resume again +if not limitedstack then + local count = 0 + local function foo() + count += 1 + pcall(1) -- create an error + coroutine.wrap(foo)() -- call another coroutine + end + checkerror(pcall(foo)) -- triggers C stack overflow + assert(count + 1 == 200) -- stack limit (MAXCCALLS) is 200, -1 for first pcall +end + return 'OK' diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.lua index 95940e111..3c2c20dd4 100644 --- a/tests/conformance/sort.lua +++ b/tests/conformance/sort.lua @@ -2,6 +2,47 @@ -- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes print"testing sort" +function checksort(t, f, ...) + assert(#t == select('#', ...)) + local copy = table.clone(t) + table.sort(copy, f) + for i=1,#t do assert(copy[i] == select(i, ...)) end +end + +-- basic edge cases +checksort({}, nil) +checksort({1}, nil, 1) + +-- small inputs +checksort({1, 2}, nil, 1, 2) +checksort({2, 1}, nil, 1, 2) + +checksort({1, 2, 3}, nil, 1, 2, 3) +checksort({2, 1, 3}, nil, 1, 2, 3) +checksort({1, 3, 2}, nil, 1, 2, 3) +checksort({3, 2, 1}, nil, 1, 2, 3) +checksort({3, 1, 2}, nil, 1, 2, 3) + +-- "large" input +checksort({3, 8, 1, 7, 10, 2, 5, 4, 9, 6}, nil, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) +checksort({"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}, nil, "Apr", "Aug", "Dec", "Feb", "Jan", "Jul", "Jun", "Mar", "May", "Nov", "Oct", "Sep") + +-- duplicates +checksort({3, 1, 1, 7, 1, 3, 5, 1, 9, 3}, nil, 1, 1, 1, 1, 3, 3, 3, 5, 7, 9) + +-- predicates +checksort({3, 8, 1, 7, 10, 2, 5, 4, 9, 6}, function (a, b) return a > b end, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) + +-- can't sort readonly tables +assert(pcall(table.sort, table.freeze({2, 1})) == false) + +-- first argument must be a table, second argument must be nil or function +assert(pcall(table.sort) == false) +assert(pcall(table.sort, "abc") == false) +assert(pcall(table.sort, {}, 42) == false) +assert(pcall(table.sort, {}, {}) == false) + +-- legacy Lua tests function check (a, f) f = f or function (x,y) return x // IsDebuggerPresent +#include // IsDebuggerPresent +#endif + +#if defined(CODEGEN_TARGET_X64) +#include #endif #ifdef __APPLE__ @@ -21,8 +27,13 @@ #include #endif +#include // TODO: remove with LuauTypeSolverRelease #include +#include + +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + // Indicates if verbose output is enabled; can be overridden via --verbose // Currently, this enables output from 'print', but other verbose output could be enabled eventually. bool verbose = false; @@ -44,6 +55,9 @@ static bool skipFastFlag(const char* flagName) if (strncmp(flagName, "Debug", 5) == 0) return true; + if (strcmp(flagName, "StudioReportLuauAny2") == 0) + return true; + return false; } @@ -59,6 +73,25 @@ static bool debuggerPresent() int ret = sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, nullptr, 0); // debugger is attached if the P_TRACED flag is set return ret == 0 && (info.kp_proc.p_flag & P_TRACED) != 0; +#elif defined(__linux__) + FILE* st = fopen("/proc/self/status", "r"); + if (!st) + return false; // assume no debugger is attached. + + int tpid = 0; + char buf[256]; + + while (fgets(buf, sizeof(buf), st)) + { + if (strncmp(buf, "TracerPid:\t", 11) == 0) + { + tpid = atoi(buf + 11); + break; + } + } + + fclose(st); + return tpid != 0; #else return false; // assume no debugger is attached. #endif @@ -67,7 +100,7 @@ static bool debuggerPresent() static int testAssertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) - LUAU_DEBUGBREAK(); + return 1; // LUAU_ASSERT will trigger LUAU_DEBUGBREAK for a more convenient debugging experience ADD_FAIL_AT(file, line, "Assertion failed: ", std::string(expr)); return 1; @@ -142,8 +175,10 @@ struct BoostLikeReporter : doctest::IReporter } void log_message(const doctest::MessageData& md) override - { // - printf("%s(%d): ERROR: %s\n", md.m_file, md.m_line, md.m_string.c_str()); + { + const char* severity = (md.m_severity & doctest::assertType::is_warn) ? "WARNING" : "ERROR"; + + printf("%s(%d): %s: %s\n", md.m_file, md.m_line, severity, md.m_string.c_str()); } // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator @@ -151,6 +186,93 @@ struct BoostLikeReporter : doctest::IReporter void test_case_skipped(const doctest::TestCaseData&) override {} }; +struct TeamCityReporter : doctest::IReporter +{ + const doctest::TestCaseData* currentTest = nullptr; + + TeamCityReporter(const doctest::ContextOptions& in) {} + + void report_query(const doctest::QueryData&) override {} + + void test_run_start() override {} + + void test_run_end(const doctest::TestRunStats& /*in*/) override {} + + void test_case_start(const doctest::TestCaseData& in) override + { + currentTest = ∈ + printf("##teamcity[testStarted name='%s: %s' captureStandardOutput='true']\n", in.m_test_suite, in.m_name); + } + + // called when a test case is reentered because of unfinished subcases + void test_case_reenter(const doctest::TestCaseData& /*in*/) override {} + + void test_case_end(const doctest::CurrentTestCaseStats& in) override + { + printf( + "##teamcity[testMetadata testName='%s: %s' name='total_asserts' type='number' value='%d']\n", + currentTest->m_test_suite, + currentTest->m_name, + in.numAssertsCurrentTest + ); + printf( + "##teamcity[testMetadata testName='%s: %s' name='failed_asserts' type='number' value='%d']\n", + currentTest->m_test_suite, + currentTest->m_name, + in.numAssertsFailedCurrentTest + ); + printf( + "##teamcity[testMetadata testName='%s: %s' name='runtime' type='number' value='%f']\n", + currentTest->m_test_suite, + currentTest->m_name, + in.seconds + ); + + if (!in.testCaseSuccess) + printf("##teamcity[testFailed name='%s: %s']\n", currentTest->m_test_suite, currentTest->m_name); + + printf("##teamcity[testFinished name='%s: %s']\n", currentTest->m_test_suite, currentTest->m_name); + } + + void test_case_exception(const doctest::TestCaseException& in) override + { + printf( + "##teamcity[testFailed name='%s: %s' message='Unhandled exception' details='%s']\n", + currentTest->m_test_suite, + currentTest->m_name, + in.error_string.c_str() + ); + } + + void subcase_start(const doctest::SubcaseSignature& /*in*/) override {} + void subcase_end() override {} + + void log_assert(const doctest::AssertData& ad) override + { + if (!ad.m_failed) + return; + + if (ad.m_decomp.size()) + fprintf(stderr, "%s(%d): ERROR: %s (%s)\n", ad.m_file, ad.m_line, ad.m_expr, ad.m_decomp.c_str()); + else + fprintf(stderr, "%s(%d): ERROR: %s\n", ad.m_file, ad.m_line, ad.m_expr); + } + + void log_message(const doctest::MessageData& md) override + { + const char* severity = (md.m_severity & doctest::assertType::is_warn) ? "WARNING" : "ERROR"; + bool isError = md.m_severity & (doctest::assertType::is_require | doctest::assertType::is_check); + fprintf(isError ? stderr : stdout, "%s(%d): %s: %s\n", md.m_file, md.m_line, severity, md.m_string.c_str()); + } + + void test_case_skipped(const doctest::TestCaseData& in) override + { + printf("##teamcity[testIgnored name='%s: %s' captureStandardOutput='false']\n", in.m_test_suite, in.m_name); + } +}; + +REGISTER_REPORTER("teamcity", 1, TeamCityReporter); + template using FValueResult = std::pair; @@ -182,7 +304,7 @@ static FValueResult parseFFlag(std::string_view view) auto [name, value] = parseFValueHelper(view); bool state = value ? *value == "true" : true; if (value && value != "true" && value != "false") - std::cerr << "Ignored '" << name << "' because '" << *value << "' is not a valid FFlag state." << std::endl; + fprintf(stderr, "Ignored '%s' because '%s' is not a valid flag state\n", name.c_str(), value->c_str()); return {name, state}; } @@ -228,8 +350,21 @@ static void setFastFlags(const std::vector& flags) } } +// This function performs system/architecture specific initialization prior to running tests. +static void initSystem() +{ +#if defined(CODEGEN_TARGET_X64) + // Some unit tests make use of denormalized numbers. So flags to flush to zero or treat denormals as zero + // must be disabled for expected behavior. + _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); +#endif +} + int main(int argc, char** argv) { + initSystem(); + Luau::assertHandler() = testAssertionHandler; doctest::registerReporter("boost", 0, true); @@ -245,9 +380,7 @@ int main(int argc, char** argv) if (skipFastFlag(flag->name)) continue; - if (flag->dynamic) - std::cout << 'D'; - std::cout << "FFlag" << flag->name << std::endl; + printf("%sFFlag%s\n", flag->dynamic ? "D" : "", flag->name); } return 0; @@ -267,7 +400,7 @@ int main(int argc, char** argv) if (doctest::parseIntOption(argc, argv, "-O", doctest::option_int, level)) { if (level < 0 || level > 2) - std::cerr << "Optimization level must be between 0 and 2 inclusive." << std::endl; + fprintf(stderr, "Optimization level must be between 0 and 2 inclusive\n"); else optimizationLevel = level; } @@ -282,6 +415,12 @@ int main(int argc, char** argv) printf("Using RNG seed %u\n", *randomSeed); } + // New Luau type solver uses a temporary scheme where fixes are made under a single version flag + // When flags are enabled, new solver is enabled with all new features and fixes + // When it's disabled, this value should have no effect (all uses under a new solver) + // Flag setup argument can still be used to override this to a specific value if desired + DFInt::LuauTypeSolverRelease.value = std::numeric_limits::max(); + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -308,6 +447,14 @@ int main(int argc, char** argv) } } + // These callbacks register unit tests that need runtime support to be + // correctly set up. Running them here means that all command line flags + // have been parsed, fast flags have been set, and we've potentially already + // exited. Once doctest::Context::run is invoked, the test list will be + // picked up from global state. + for (Luau::RegisterCallback cb : Luau::getRegisterCallbacks()) + cb(); + int result = context.run(); if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) { diff --git a/tests/require/with_config/.luaurc b/tests/require/with_config/.luaurc new file mode 100644 index 000000000..7e7abf18a --- /dev/null +++ b/tests/require/with_config/.luaurc @@ -0,0 +1,6 @@ +{ + "aliases": { + "dep": "this_should_be_overwritten_by_child_luaurc", + "otherdep": "src/other_dependency" + } +} diff --git a/tests/require/with_config/GlobalLuauLibraries/global_library.luau b/tests/require/with_config/GlobalLuauLibraries/global_library.luau new file mode 100644 index 000000000..0508e0bd1 --- /dev/null +++ b/tests/require/with_config/GlobalLuauLibraries/global_library.luau @@ -0,0 +1 @@ +return {"result from global_library"} diff --git a/tests/require/with_config/ProjectLuauLibraries/library.luau b/tests/require/with_config/ProjectLuauLibraries/library.luau new file mode 100644 index 000000000..9470401b3 --- /dev/null +++ b/tests/require/with_config/ProjectLuauLibraries/library.luau @@ -0,0 +1 @@ +return {"result from library"} diff --git a/tests/require/with_config/src/.luaurc b/tests/require/with_config/src/.luaurc new file mode 100644 index 000000000..90c6b646d --- /dev/null +++ b/tests/require/with_config/src/.luaurc @@ -0,0 +1,6 @@ +{ + "aliases": { + "dep": "dependency", + "subdir": "subdirectory" + } +} diff --git a/tests/require/with_config/src/alias_requirer.luau b/tests/require/with_config/src/alias_requirer.luau new file mode 100644 index 000000000..4375a7835 --- /dev/null +++ b/tests/require/with_config/src/alias_requirer.luau @@ -0,0 +1 @@ +return require("@dep") diff --git a/tests/require/with_config/src/dependency.luau b/tests/require/with_config/src/dependency.luau new file mode 100644 index 000000000..07466f429 --- /dev/null +++ b/tests/require/with_config/src/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/with_config/src/directory_alias_requirer.luau b/tests/require/with_config/src/directory_alias_requirer.luau new file mode 100644 index 000000000..3b19d4ffa --- /dev/null +++ b/tests/require/with_config/src/directory_alias_requirer.luau @@ -0,0 +1 @@ +return(require("@subdir/subdirectory_dependency")) diff --git a/tests/require/with_config/src/other_dependency.luau b/tests/require/with_config/src/other_dependency.luau new file mode 100644 index 000000000..8c582dc22 --- /dev/null +++ b/tests/require/with_config/src/other_dependency.luau @@ -0,0 +1 @@ +return {"result from other_dependency"} diff --git a/tests/require/with_config/src/parent_alias_requirer.luau b/tests/require/with_config/src/parent_alias_requirer.luau new file mode 100644 index 000000000..a8e8de094 --- /dev/null +++ b/tests/require/with_config/src/parent_alias_requirer.luau @@ -0,0 +1 @@ +return require("@otherdep") diff --git a/tests/require/with_config/src/subdirectory/subdirectory_dependency.luau b/tests/require/with_config/src/subdirectory/subdirectory_dependency.luau new file mode 100644 index 000000000..8bbd0bebd --- /dev/null +++ b/tests/require/with_config/src/subdirectory/subdirectory_dependency.luau @@ -0,0 +1 @@ +return {"result from subdirectory_dependency"} diff --git a/tests/require/without_config/ambiguous/directory/dependency.luau b/tests/require/without_config/ambiguous/directory/dependency.luau new file mode 100644 index 000000000..07466f429 --- /dev/null +++ b/tests/require/without_config/ambiguous/directory/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous/directory/dependency/init.luau b/tests/require/without_config/ambiguous/directory/dependency/init.luau new file mode 100644 index 000000000..07466f429 --- /dev/null +++ b/tests/require/without_config/ambiguous/directory/dependency/init.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous/file/dependency.lua b/tests/require/without_config/ambiguous/file/dependency.lua new file mode 100644 index 000000000..07466f429 --- /dev/null +++ b/tests/require/without_config/ambiguous/file/dependency.lua @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous/file/dependency.luau b/tests/require/without_config/ambiguous/file/dependency.luau new file mode 100644 index 000000000..07466f429 --- /dev/null +++ b/tests/require/without_config/ambiguous/file/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous_directory_requirer.luau b/tests/require/without_config/ambiguous_directory_requirer.luau new file mode 100644 index 000000000..e46be806b --- /dev/null +++ b/tests/require/without_config/ambiguous_directory_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./ambiguous/directory/dependency") +result[#result+1] = "required into module" +return result diff --git a/tests/require/without_config/ambiguous_file_requirer.luau b/tests/require/without_config/ambiguous_file_requirer.luau new file mode 100644 index 000000000..8e3a576d3 --- /dev/null +++ b/tests/require/without_config/ambiguous_file_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./ambiguous/file/dependency") +result[#result+1] = "required into module" +return result diff --git a/tests/require/without_config/dependency.luau b/tests/require/without_config/dependency.luau new file mode 100644 index 000000000..07466f429 --- /dev/null +++ b/tests/require/without_config/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/lua/init.lua b/tests/require/without_config/lua/init.lua new file mode 100644 index 000000000..7c28b735f --- /dev/null +++ b/tests/require/without_config/lua/init.lua @@ -0,0 +1 @@ +return {"result from init.lua"} diff --git a/tests/require/without_config/lua_dependency.lua b/tests/require/without_config/lua_dependency.lua new file mode 100644 index 000000000..aec2d82b5 --- /dev/null +++ b/tests/require/without_config/lua_dependency.lua @@ -0,0 +1 @@ +return {"result from lua_dependency"} diff --git a/tests/require/without_config/luau/init.lua b/tests/require/without_config/luau/init.lua new file mode 100644 index 000000000..7e3680b20 --- /dev/null +++ b/tests/require/without_config/luau/init.lua @@ -0,0 +1 @@ +return {"wrong file"} diff --git a/tests/require/without_config/luau/init.luau b/tests/require/without_config/luau/init.luau new file mode 100644 index 000000000..72463463d --- /dev/null +++ b/tests/require/without_config/luau/init.luau @@ -0,0 +1 @@ +return {"result from init.luau"} diff --git a/tests/require/without_config/module.luau b/tests/require/without_config/module.luau new file mode 100644 index 000000000..1d1393ff8 --- /dev/null +++ b/tests/require/without_config/module.luau @@ -0,0 +1,3 @@ +local result = require("./dependency") +result[#result+1] = "required into module" +return result diff --git a/tools/codegenstat.py b/tools/codegenstat.py new file mode 100644 index 000000000..9370cc740 --- /dev/null +++ b/tools/codegenstat.py @@ -0,0 +1,58 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given the output of --compile=codegenverbose in stdin, this script outputs statistics about bytecode/IR + +import sys +import re +from collections import defaultdict + +count_bc = defaultdict(int) +count_ir = defaultdict(int) +count_asm = defaultdict(int) +count_irasm = defaultdict(int) + +# GETTABLEKS R10 R1 K18 ['s'] +# L1: DIV R14 R13 R3 +re_bc = re.compile(r'^(?:L\d+: )?([A-Z_]+) ') +# # CHECK_SLOT_MATCH %178, K3, bb_fallback_37 +# # %175 = LOAD_TAG R15 +re_ir = re.compile(r'^# (?:%\d+ = )?([A-Z_]+) ') +# cmp w14,#5 +re_asm = re.compile(r'^ ([a-z.]+) ') + +current_ir = None + +for line in sys.stdin.buffer.readlines(): + line = line.decode('utf-8', errors='ignore').rstrip() + + if m := re_asm.match(line): + count_asm[m[1]] += 1 + if current_ir: + count_irasm[current_ir] += 1 + elif m := re_ir.match(line): + count_ir[m[1]] += 1 + current_ir = m[1] + elif m := re_bc.match(line): + count_bc[m[1]] += 1 + +def display(name, counts, limit=None, extra=None): + items = sorted(counts.items(), key=lambda p: p[1], reverse=True) + total = 0 + for k,v in items: + total += v + shown = 0 + print(name) + for i, (k,v) in enumerate(items): + if i == limit: + if shown < total: + print(f' {"Others":25}: {total-shown} ({(total-shown)/total*100:.1f}%)') + break + print(f' {k:25}: {v} ({v/total*100:.1f}%){"; "+extra(k) if extra else ""}') + shown += v + print() + +display("Bytecode", count_bc, limit=20) +display("IR", count_ir, limit=20) +display("Assembly", count_asm, limit=10) +display("IR->Assembly", count_irasm, limit=30, extra=lambda op: f'{count_irasm[op] / count_ir[op]:.1f} insn/op') diff --git a/tools/codesizeprediction.py b/tools/codesizeprediction.py new file mode 100644 index 000000000..ba877efe4 --- /dev/null +++ b/tools/codesizeprediction.py @@ -0,0 +1,185 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# NOTE: This script is experimental. This script uses a linear regression to construct a model for predicting native +# code size from bytecode. Some initial work has been done to analyze a large corpus of Luau scripts, and while for +# most functions the model predicts the native code size quite well (+/-25%), there are many cases where the predicted +# size is off by as much as 13x. Notably, the predicted size is generally better for smaller functions and worse for +# larger functions. Therefore, in its current form this analysis is probably not suitable for use as a basis for +# compilation heuristics. A nonlinear model may produce better results. The script here exists as a foundation for +# further exploration. + + +import json +import glob +from pathlib import Path +import pandas as pd +import numpy as np +from sklearn.linear_model import LinearRegression +import matplotlib.pyplot as plt +import argparse + + +def readStats(statsFileGlob): + '''Reads files matching the supplied glob. + Files should be generated by the Compile.cpp CLI''' + + statsFiles = glob.glob(statsFileGlob, recursive=True) + + print("Reading %s files." % len(statsFiles)) + + df_dict = { + "statsFile": [], + "script": [], + "name": [], + "line": [], + "bcodeCount": [], + "irCount": [], + "asmCount": [], + "bytecodeSummary": [] + } + + for statsFile in statsFiles: + stats = json.loads(Path(statsFile).read_text()) + for script, filestats in stats.items(): + for funstats in filestats["lowerStats"]["functions"]: + df_dict["statsFile"].append(statsFile) + df_dict["script"].append(script) + df_dict["name"].append(funstats["name"]) + df_dict["line"].append(funstats["line"]) + df_dict["bcodeCount"].append(funstats["bcodeCount"]) + df_dict["irCount"].append(funstats["irCount"]) + df_dict["asmCount"].append(funstats["asmCount"]) + df_dict["bytecodeSummary"].append( + tuple(funstats["bytecodeSummary"][0])) + + return pd.DataFrame.from_dict(df_dict) + + +def addFunctionCount(df): + df2 = df.drop_duplicates(subset=['asmCount', 'bytecodeSummary'], ignore_index=True).groupby( + ['bytecodeSummary']).size().reset_index(name='functionCount') + return df.merge(df2, on='bytecodeSummary', how='left') + +# def deduplicateDf(df): +# return df.drop_duplicates(subset=['bcodeCount', 'asmCount', 'bytecodeSummary'], ignore_index=True) + + +def randomizeDf(df): + return df.sample(frac=1) + + +def splitSeq(seq): + n = len(seq) // 2 + return (seq[:n], seq[n:]) + + +def trainAsmSizePredictor(df): + XTrain, XValidate = splitSeq( + np.array([list(seq) for seq in df.bytecodeSummary])) + YTrain, YValidate = splitSeq(np.array(df.asmCount)) + + reg = LinearRegression( + positive=True, fit_intercept=False).fit(XTrain, YTrain) + YPredict1 = reg.predict(XTrain) + YPredict2 = reg.predict(XValidate) + + trainRmse = np.sqrt(np.mean((np.array(YPredict1) - np.array(YTrain))**2)) + predictRmse = np.sqrt( + np.mean((np.array(YPredict2) - np.array(YValidate))**2)) + + print(f"Score: {reg.score(XTrain, YTrain)}") + print(f"Training RMSE: {trainRmse}") + print(f"Prediction RMSE: {predictRmse}") + print(f"Model Intercept: {reg.intercept_}") + print(f"Model Coefficients:\n{reg.coef_}") + + df.loc[:, 'asmCountPredicted'] = np.concatenate( + (YPredict1, YPredict2)).round().astype(int) + df['usedForTraining'] = np.concatenate( + (np.repeat(True, YPredict1.size), np.repeat(False, YPredict2.size))) + df['diff'] = df['asmCountPredicted'] - df['asmCount'] + df['diffPerc'] = (100 * df['diff']) / df['asmCount'] + df.loc[(df["diffPerc"] == np.inf), 'diffPerc'] = 0.0 + df['diffPerc'] = df['diffPerc'].round() + + return (reg, df) + + +def saveModel(reg, file): + f = open(file, "w") + f.write(f"Intercept: {reg.intercept_}\n") + f.write(f"Coefficients: \n{reg.coef_}\n") + f.close() + + +def bcodeVsAsmPlot(df, plotFile=None, minBcodeCount=None, maxBcodeCount=None): + if minBcodeCount is None: + minBcodeCount = df.bcodeCount.min() + if maxBcodeCount is None: + maxBcodeCount = df.bcodeCount.max() + + subDf = df[(df.bcodeCount <= maxBcodeCount) & + (df.bcodeCount >= minBcodeCount)] + + plt.scatter(subDf.bcodeCount, subDf.asmCount) + plt.title("ASM variation by Bytecode") + plt.xlabel("Bytecode Instruction Count") + plt.ylabel("ASM Instruction Count") + + if plotFile is not None: + plt.savefig(plotFile) + + return plt + + +def predictionErrorPlot(df, plotFile=None, minPerc=None, maxPerc=None, bins=200): + if minPerc is None: + minPerc = df['diffPerc'].min() + if maxPerc is None: + maxPerc = df['diffPerc'].max() + + plotDf = df[(df["usedForTraining"] == False) & ( + df["diffPerc"] >= minPerc) & (df["diffPerc"] <= maxPerc)] + + plt.hist(plotDf["diffPerc"], bins=bins) + plt.title("Prediction Error Distribution") + plt.xlabel("Prediction Error %") + plt.ylabel("Function Count") + + if plotFile is not None: + plt.savefig(plotFile) + + return plt + + +def parseArgs(): + parser = argparse.ArgumentParser( + prog='codesizeprediction.py', + description='Constructs a linear regression model to predict native instruction count from bytecode opcode distribution') + parser.add_argument("fileglob", + help="glob pattern for stats files to be used for training") + parser.add_argument("modelfile", + help="text file to save model details") + parser.add_argument("--nativesizefig", + help="path for saving the plot showing the variation of native code size with bytecode") + parser.add_argument("--predictionerrorfig", + help="path for saving the plot showing the distribution of prediction error") + return parser.parse_args() + + +if __name__ == "__main__": + args = parseArgs() + + df0 = readStats(args.fileglob) + df1 = addFunctionCount(df0) + df2 = randomizeDf(df1) + + plt = bcodeVsAsmPlot(df2, args.nativesizefig, 0, 100) + plt.show() + + (reg, df4) = trainAsmSizePredictor(df2) + saveModel(reg, args.modelfile) + + plt = predictionErrorPlot(df4, args.predictionerrorfig, -200, 200) + plt.show() diff --git a/tools/faillist.txt b/tools/faillist.txt index c68312985..e69de29bb 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,271 +0,0 @@ -AnnotationTests.too_many_type_params -AstQuery.last_argument_function_call_type -AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method -AstQuery::getDocumentationSymbolAtPosition.overloaded_fn -AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop -AutocompleteTest.autocomplete_response_perf1 -BuiltinTests.aliased_string_format -BuiltinTests.assert_removes_falsy_types -BuiltinTests.assert_removes_falsy_types2 -BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type -BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy -BuiltinTests.bad_select_should_not_crash -BuiltinTests.dont_add_definitions_to_persistent_types -BuiltinTests.gmatch_definition -BuiltinTests.match_capture_types -BuiltinTests.match_capture_types2 -BuiltinTests.math_max_checks_for_numbers -BuiltinTests.select_slightly_out_of_range -BuiltinTests.select_way_out_of_range -BuiltinTests.set_metatable_needs_arguments -BuiltinTests.setmetatable_should_not_mutate_persisted_types -BuiltinTests.sort_with_bad_predicate -BuiltinTests.string_format_as_method -BuiltinTests.string_format_correctly_ordered_types -BuiltinTests.string_format_report_all_type_errors_at_correct_positions -BuiltinTests.string_format_tostring_specifier_type_constraint -BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload -BuiltinTests.table_pack -BuiltinTests.table_pack_reduce -BuiltinTests.table_pack_variadic -DefinitionTests.class_definition_overload_metamethods -DefinitionTests.class_definition_string_props -FrontendTest.environments -FrontendTest.nocheck_cycle_used_by_checked -GenericsTests.apply_type_function_nested_generics2 -GenericsTests.better_mismatch_error_messages -GenericsTests.bound_tables_do_not_clone_original_fields -GenericsTests.check_mutual_generic_functions -GenericsTests.correctly_instantiate_polymorphic_member_functions -GenericsTests.do_not_infer_generic_functions -GenericsTests.generic_argument_count_too_few -GenericsTests.generic_argument_count_too_many -GenericsTests.generic_functions_should_be_memory_safe -GenericsTests.generic_type_pack_parentheses -GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments -GenericsTests.infer_generic_function_function_argument_2 -GenericsTests.infer_generic_function_function_argument_3 -GenericsTests.infer_generic_function_function_argument_overloaded -GenericsTests.infer_generic_lib_function_function_argument -GenericsTests.instantiated_function_argument_names -GenericsTests.no_stack_overflow_from_quantifying -GenericsTests.self_recursive_instantiated_param -IntersectionTypes.table_intersection_write_sealed -IntersectionTypes.table_intersection_write_sealed_indirect -IntersectionTypes.table_write_sealed_indirect -ModuleTests.clone_self_property -NonstrictModeTests.for_in_iterator_variables_are_any -NonstrictModeTests.function_parameters_are_any -NonstrictModeTests.inconsistent_module_return_types_are_ok -NonstrictModeTests.inconsistent_return_types_are_ok -NonstrictModeTests.infer_nullary_function -NonstrictModeTests.infer_the_maximum_number_of_values_the_function_could_return -NonstrictModeTests.inline_table_props_are_also_any -NonstrictModeTests.local_tables_are_not_any -NonstrictModeTests.locals_are_any_by_default -NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon -NonstrictModeTests.parameters_having_type_any_are_optional -NonstrictModeTests.table_props_are_any -ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal -ProvisionalTests.bail_early_if_unification_is_too_complicated -ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack -ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean -ProvisionalTests.free_options_cannot_be_unified_together -ProvisionalTests.generic_type_leak_to_module_interface_variadic -ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns -ProvisionalTests.setmetatable_constrains_free_type_into_free_table -ProvisionalTests.specialization_binds_with_prototypes_too_early -ProvisionalTests.table_insert_with_a_singleton_argument -ProvisionalTests.typeguard_inference_incomplete -RefinementTest.type_guard_can_filter_for_intersection_of_tables -RefinementTest.type_narrow_to_vector -RefinementTest.typeguard_cast_free_table_to_vector -RefinementTest.typeguard_in_assert_position -RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table -RuntimeLimits.typescript_port_of_Result_type -TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible -TableTests.accidentally_checked_prop_in_opposite_branch -TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode -TableTests.casting_tables_with_props_into_table_with_indexer3 -TableTests.casting_tables_with_props_into_table_with_indexer4 -TableTests.checked_prop_too_early -TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode -TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar -TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_suggest_exact_match_keys -TableTests.error_detailed_metatable_prop -TableTests.expected_indexer_from_table_union -TableTests.expected_indexer_value_type_extra -TableTests.expected_indexer_value_type_extra_2 -TableTests.explicitly_typed_table -TableTests.explicitly_typed_table_with_indexer -TableTests.found_like_key_in_table_function_call -TableTests.found_like_key_in_table_property_access -TableTests.found_multiple_like_keys -TableTests.fuzz_table_unify_instantiated_table -TableTests.generic_table_instantiation_potential_regression -TableTests.give_up_after_one_metatable_index_look_up -TableTests.indexer_on_sealed_table_must_unify_with_free_table -TableTests.indexing_from_a_table_should_prefer_properties_when_possible -TableTests.inequality_operators_imply_exactly_matching_types -TableTests.infer_array_2 -TableTests.inferred_return_type_of_free_table -TableTests.instantiate_table_cloning_3 -TableTests.leaking_bad_metatable_errors -TableTests.less_exponential_blowup_please -TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred -TableTests.mixed_tables_with_implicit_numbered_keys -TableTests.nil_assign_doesnt_hit_indexer -TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr -TableTests.only_ascribe_synthetic_names_at_module_scope -TableTests.oop_polymorphic -TableTests.quantify_even_that_table_was_never_exported_at_all -TableTests.quantify_metatables_of_metatables_of_table -TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table -TableTests.result_is_always_any_if_lhs_is_any -TableTests.result_is_bool_for_equality_operators_if_lhs_is_any -TableTests.right_table_missing_key2 -TableTests.shared_selfs -TableTests.shared_selfs_from_free_param -TableTests.shared_selfs_through_metatables -TableTests.table_call_metamethod_basic -TableTests.table_simple_call -TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors -TableTests.table_unification_4 -TableTests.used_colon_instead_of_dot -TableTests.used_dot_instead_of_colon -ToString.named_metatable_toStringNamedFunction -ToString.toStringDetailed2 -ToString.toStringErrorPack -ToString.toStringNamedFunction_generic_pack -ToString.toStringNamedFunction_map -TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType -TryUnifyTests.result_of_failed_typepack_unification_is_constrained -TryUnifyTests.typepack_unification_should_trim_free_tails -TypeAliases.cannot_create_cyclic_type_with_unknown_module -TypeAliases.generic_param_remap -TypeAliases.mismatched_generic_type_param -TypeAliases.mutually_recursive_types_restriction_not_ok_1 -TypeAliases.mutually_recursive_types_restriction_not_ok_2 -TypeAliases.mutually_recursive_types_swapsies_not_ok -TypeAliases.recursive_types_restriction_not_ok -TypeAliases.report_shadowed_aliases -TypeAliases.type_alias_local_mutation -TypeAliases.type_alias_local_rename -TypeAliases.type_alias_of_an_imported_recursive_generic_type -TypeInfer.check_type_infer_recursion_count -TypeInfer.checking_should_not_ice -TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error -TypeInfer.dont_report_type_errors_within_an_AstExprError -TypeInfer.dont_report_type_errors_within_an_AstStatError -TypeInfer.fuzz_free_table_type_change_during_index_check -TypeInfer.globals -TypeInfer.globals2 -TypeInfer.infer_assignment_value_types_mutable_lval -TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict -TypeInfer.no_stack_overflow_from_isoptional -TypeInfer.no_stack_overflow_from_isoptional2 -TypeInfer.tc_after_error_recovery_no_replacement_name_in_error -TypeInfer.type_infer_recursion_limit_no_ice -TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferClasses.class_type_mismatch_with_name_conflict -TypeInferClasses.classes_without_overloaded_operators_cannot_be_added -TypeInferClasses.index_instance_property -TypeInferClasses.optional_class_field_access_error -TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties -TypeInferClasses.warn_when_prop_almost_matches -TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types -TypeInferFunctions.cannot_hoist_interior_defns_into_signature -TypeInferFunctions.check_function_before_lambda_that_uses_it -TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists -TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site -TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict -TypeInferFunctions.function_cast_error_uses_correct_language -TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 -TypeInferFunctions.function_decl_non_self_unsealed_overwrite -TypeInferFunctions.function_does_not_return_enough_values -TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer -TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict -TypeInferFunctions.improved_function_arg_mismatch_errors -TypeInferFunctions.infer_anonymous_function_arguments -TypeInferFunctions.infer_that_function_does_not_return_a_table -TypeInferFunctions.luau_subtyping_is_np_hard -TypeInferFunctions.no_lossy_function_type -TypeInferFunctions.occurs_check_failure_in_function_return_type -TypeInferFunctions.record_matching_overload -TypeInferFunctions.report_exiting_without_return_nonstrict -TypeInferFunctions.report_exiting_without_return_strict -TypeInferFunctions.return_type_by_overload -TypeInferFunctions.too_few_arguments_variadic -TypeInferFunctions.too_few_arguments_variadic_generic -TypeInferFunctions.too_few_arguments_variadic_generic2 -TypeInferFunctions.too_many_arguments_error_location -TypeInferFunctions.too_many_return_values_in_parentheses -TypeInferFunctions.too_many_return_values_no_function -TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values -TypeInferLoops.for_in_loop_with_next -TypeInferLoops.for_in_with_generic_next -TypeInferLoops.loop_iter_metamethod_ok_with_inference -TypeInferLoops.loop_iter_no_indexer_nonstrict -TypeInferLoops.loop_iter_trailing_nil -TypeInferLoops.properly_infer_iteratee_is_a_free_table -TypeInferLoops.unreachable_code_after_infinite_loop -TypeInferModules.custom_require_global -TypeInferModules.do_not_modify_imported_types_5 -TypeInferModules.module_type_conflict -TypeInferModules.module_type_conflict_instantiated -TypeInferModules.type_error_of_unknown_qualified_type -TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.methods_are_topologically_sorted -TypeInferOperators.CallAndOrOfFunctions -TypeInferOperators.CallOrOfFunctions -TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable -TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators -TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_metatable -TypeInferOperators.compound_assign_mismatch_metatable -TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops -TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators -TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown -TypeInferOperators.operator_eq_completely_incompatible -TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection -TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs -TypeInferOperators.typecheck_unary_len_error -TypeInferOperators.UnknownGlobalCompoundAssign -TypeInferOperators.unrelated_classes_cannot_be_compared -TypeInferOperators.unrelated_primitives_cannot_be_compared -TypeInferPrimitives.CheckMethodsOfNumber -TypeInferPrimitives.string_index -TypeInferUnknownNever.assign_to_global_which_is_never -TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators -TypeInferUnknownNever.math_operators_and_never -TypePackTests.detect_cyclic_typepacks2 -TypePackTests.pack_tail_unification_check -TypePackTests.type_alias_backwards_compatible -TypePackTests.type_alias_default_type_errors -TypePackTests.type_alias_type_packs_errors -TypePackTests.unify_variadic_tails_in_arguments -TypePackTests.variadic_packs -TypeSingletons.function_call_with_singletons -TypeSingletons.function_call_with_singletons_mismatch -TypeSingletons.indexing_on_union_of_string_singletons -TypeSingletons.no_widening_from_callsites -TypeSingletons.return_type_of_f_is_not_widened -TypeSingletons.table_properties_type_error_escapes -TypeSingletons.taking_the_length_of_union_of_string_singleton -TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton -TypeSingletons.widening_happens_almost_everywhere -UnionTypes.index_on_a_union_type_with_missing_property -UnionTypes.optional_assignment_errors -UnionTypes.optional_call_error -UnionTypes.optional_field_access_error -UnionTypes.optional_index_error -UnionTypes.optional_iteration -UnionTypes.optional_length_error -UnionTypes.optional_missing_key_error_details -UnionTypes.optional_union_follow -UnionTypes.optional_union_functions -UnionTypes.optional_union_members -UnionTypes.optional_union_methods -UnionTypes.table_union_write_indirect diff --git a/tools/flag-bisect.py b/tools/flag-bisect.py index 01f3ef7ce..55663a789 100644 --- a/tools/flag-bisect.py +++ b/tools/flag-bisect.py @@ -135,7 +135,7 @@ def add_argument_parsers(parser): interestness_parser.add_argument('--auto', dest='mode', action='store_const', const=InterestnessMode.AUTO, default=InterestnessMode.AUTO, help='Automatically figure out which one of --pass or --fail should be used') interestness_parser.add_argument('--fail', dest='mode', action='store_const', const=InterestnessMode.FAIL, - help='You want this if omitting --fflags=true causes tests to fail') + help='You want this if passing --fflags=true causes tests to fail') interestness_parser.add_argument('--pass', dest='mode', action='store_const', const=InterestnessMode.PASS, help='You want this if passing --fflags=true causes tests to pass') interestness_parser.add_argument('--timeout', dest='timeout', type=int, default=0, metavar='SECONDS', diff --git a/tools/fuzz/fuzzer-postprocess.py b/tools/fuzz/fuzzer-postprocess.py new file mode 100644 index 000000000..a517c245a --- /dev/null +++ b/tools/fuzz/fuzzer-postprocess.py @@ -0,0 +1,175 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +import argparse +import jinja2 +import multiprocessing +import os +import shutil +import subprocess +import sys + + +def is_crash(reproducer_name: str) -> bool: + return reproducer_name.startswith("crash-") or reproducer_name.startswith("oom-") + + +class CrashReport: + def __init__(self, args, crash_id): + self.id = crash_id + self.args = args + self.crash_root = os.path.join(args.output_directory, crash_id) + + def trace(self) -> str: + trace_path = os.path.join(self.crash_root, "trace.txt") + + if os.path.exists(trace_path): + with open(os.path.join(self.crash_root, "trace.txt"), "r") as trace_file: + return trace_file.read() + else: + return None + + def modules(self) -> str: + with open(os.path.join(self.crash_root, "modules.txt"), "r") as modules_file: + return modules_file.read() + + def artifact_link(self) -> str: + return f"{self.args.artifact_root}/{self.id}/minimized_reproducer" + + +class MetaValue: + def __init__(self, name, value): + self.name = name + self.value = value + self.link = None + + +def minimize_crash(args, reproducer, workdir): + if not is_crash(os.path.basename(reproducer)): + # Not actually a crash, so no minimization is actually possible. + return + + print( + f"Minimizing reproducer {os.path.basename(reproducer)} for {args.minimize_for} seconds.") + + reproducer_absolute = os.path.abspath(reproducer) + + artifact = os.path.join(workdir, "minimized_reproducer") + minimize_result = subprocess.run([args.executable, "-detect_leaks=0", "-minimize_crash=1", + f"-exact_artifact_path={artifact}", f"-max_total_time={args.minimize_for}", reproducer_absolute], cwd=workdir, stdout=sys.stdout if args.verbose else subprocess.DEVNULL, stderr=sys.stderr if args.verbose else subprocess.DEVNULL) + + if minimize_result.returncode != 0: + print( + f"Minimize process exited with code {minimize_result.returncode}; minimization failed.") + return + + if os.path.exists(artifact): + print( + f"Minimized {os.path.basename(reproducer)} from {os.path.getsize(reproducer)} bytes to {os.path.getsize(artifact)}.") + + +def process_crash(args, reproducer): + crash_id = os.path.basename(reproducer) + crash_output = os.path.join(args.output_directory, crash_id) + print(f"Processing reproducer {crash_id}.") + + print(f"Output will be stored in {crash_output}.") + if os.path.exists(crash_output): + print(f"Contents of {crash_output} will be discarded.") + shutil.rmtree(crash_output, ignore_errors=True) + + os.makedirs(crash_output) + shutil.copyfile(reproducer, os.path.join(crash_output, "original_reproducer")) + shutil.copyfile(reproducer, os.path.join( + crash_output, "minimized_reproducer")) + + minimize_crash(args, reproducer, crash_output) + + if is_crash(crash_id): + trace_result = subprocess.run([args.executable, os.path.join( + crash_output, "minimized_reproducer")], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + trace_text = trace_result.stdout + + with open(os.path.join(crash_output, "trace.txt"), "w") as trace_file: + trace_file.write(trace_text) + + modules_result = subprocess.run([args.prototest, os.path.join( + crash_output, "minimized_reproducer"), "-detect_leaks=0", "-verbosity=0"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + modules_text = modules_result.stdout + + module_index_of = modules_text.index("Module") + modules_text = modules_text[module_index_of:] + + with open(os.path.join(crash_output, "modules.txt"), "w") as modules_file: + modules_file.write(modules_text) + + return CrashReport(args, crash_id) + + +def process_crashes(args): + crash_names = sorted(os.listdir(args.source_directory)) + with multiprocessing.Pool(args.workers) as pool: + crashes = [(args, os.path.join(args.source_directory, c)) for c in crash_names] + crashes = pool.starmap(process_crash, crashes) + print(f"Processed {len(crashes)} crashes.") + return crashes + + +def generate_report(crashes, meta): + env = jinja2.Environment( + loader=jinja2.PackageLoader("fuzzer-postprocess"), + autoescape=jinja2.select_autoescape() + ) + + template = env.get_template("index.html") + with open("fuzz-report.html", "w") as report_file: + report_file.write(template.render( + crashes=crashes, + meta=meta, + )) + + +def __main__(): + parser = argparse.ArgumentParser() + parser.add_argument("--source_directory", required=True) + parser.add_argument("--output_directory", required=True) + parser.add_argument("--executable", required=True) + parser.add_argument("--prototest", required=True) + parser.add_argument("--minimize_for", required=True) + parser.add_argument("--artifact_root", required=True) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--workers", action="store", type=int, default=4) + meta_group = parser.add_argument_group( + "metadata", description="Report metadata to attach.") + meta_group.add_argument("--meta.values", nargs="*", + help="Any metadata to attach, in the form name=value. Multiple values may be specified.", dest="metadata_values", default=[]) + meta_group.add_argument("--meta.urls", nargs="*", + help="URLs to attach to metadata, in the form name=url. Multiple values may be specified. A value must also be specified with --meta.values.", dest="metadata_urls", default=[]) + args = parser.parse_args() + + meta_values = dict() + for pair in args.metadata_values: + components = pair.split("=", 1) + name = components[0] + value = components[1] + + meta_values[name] = MetaValue(name, value) + + for pair in args.metadata_urls: + components = pair.split("=", 1) + name = components[0] + url = components[1] + + if name in meta_values: + meta_values[name].link = url + else: + print(f"Metadata {name} has URL {url} but no value specified.") + + meta_values = sorted(list(meta_values.values()), key=lambda x: x.name) + + crashes = process_crashes(args) + generate_report(crashes, meta_values) + + +if __name__ == "__main__": + __main__() diff --git a/tools/fuzz/fuzzfilter.py b/tools/fuzz/fuzzfilter.py new file mode 100644 index 000000000..07140be3d --- /dev/null +++ b/tools/fuzz/fuzzfilter.py @@ -0,0 +1,113 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a fuzzer binary and a list of crashing programs, this tool collects unique crash reasons and prints reproducers. + +import argparse +import multiprocessing +import os +import re +import subprocess +import sys + + +def is_crash(reproducer_name: str) -> bool: + return reproducer_name.startswith("crash-") or reproducer_name.startswith("oom-") + + +class Reproducer: + def __init__(self, file, reason, fingerprint): + self.file = file + self.reason = reason + self.fingerprint = fingerprint + + +def get_crash_reason(binary, file, remove_passing): + res = subprocess.run( + [binary, file], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) + + if res.returncode == 0: + if remove_passing: + print(f"Warning: {binary} {file} returned 0; removing from result set.", file=sys.stderr) + os.remove(file) + else: + print(f"Warning: {binary} {file} returned 0", file=sys.stderr) + + return None + + err = res.stderr.decode("utf-8") + + if (pos := err.find("ERROR: AddressSanitizer:")) != -1: + return err[pos:] + + if (pos := err.find("ERROR: libFuzzer:")) != -1: + return err[pos:] + + print(f"Warning: {binary} {file} returned unrecognized error {err} with exit code {res.returncode}", file=sys.stderr) + return None + + +def get_crash_fingerprint(reason): + # Due to ASLR addresses are different every time, so we filter them out + reason = re.sub(r"0x[0-9a-f]+", "0xXXXX", reason) + return reason + + +parser = argparse.ArgumentParser() +parser.add_argument("binary") +parser.add_argument("files", action="append", default=[]) +parser.add_argument("--remove-duplicates", action="store_true") +parser.add_argument("--remove-passing", action="store_true") +parser.add_argument("--workers", action="store", default=1, type=int) +parser.add_argument("--verbose", "-v", action="count", default=0, dest="verbosity") + +args = parser.parse_args() + +def process_file(file): + reason = get_crash_reason(args.binary, file, args.remove_passing) + if reason is None: + return None + + fingerprint = get_crash_fingerprint(reason) + return Reproducer(file, reason, fingerprint) + + +filter_targets = [] +if len(args.files) == 1: + for root, dirs, files in os.walk(args.files[0]): + for file in files: + if not is_crash(file): + continue + + filter_targets.append(os.path.join(root, file)) +else: + filter_targets = args.files + +if __name__ == "__main__": + multiprocessing.freeze_support() + + with multiprocessing.Pool(processes = args.workers) as pool: + print(f"Processing {len(filter_targets)} reproducers across {args.workers} workers.") + reproducers = [r for r in pool.map(process_file, filter_targets) if r is not None] + + seen = set() + for index, reproducer in enumerate(reproducers): + if reproducer.fingerprint in seen: + if sys.stdout.isatty(): + print("-\|/"[index % 4], end="\r") + + if args.remove_duplicates: + if args.verbosity >= 1: + print(f"Removing duplicate reducer {reproducer.file}.") + os.remove(reproducer.file) + + continue + + seen.add(reproducer.fingerprint) + if args.verbosity >= 2: + print(f"Reproducer: {args.binary} {reproducer.file}") + print(f"Output: {reproducer.reason}") + + print(f"Total unique crashes: {len(seen)}") + if args.remove_duplicates: + print(f"Duplicate reproducers have been removed.") diff --git a/tools/fuzz/requirements.txt b/tools/fuzz/requirements.txt new file mode 100644 index 000000000..297ba3246 --- /dev/null +++ b/tools/fuzz/requirements.txt @@ -0,0 +1,2 @@ +Jinja2==3.1.4 +MarkupSafe==2.1.3 diff --git a/tools/fuzz/templates/index.html b/tools/fuzz/templates/index.html new file mode 100644 index 000000000..bd4114cad --- /dev/null +++ b/tools/fuzz/templates/index.html @@ -0,0 +1,132 @@ + + + + + + + Luau Fuzzer Report + + + +
+
+

Fuzzer Report

+
+
+ + + + {% for crash in crashes %} +
+

{{ crash.id }}

+

Download reproducer artifact

+ {% if crash.trace() %} +
Trace
+
+
{{ crash.trace() }}
+
+ {% endif %} +
Module set
+
+
{{ crash.modules() }}
+
+
+
+ {% endfor %} +
+ + diff --git a/tools/heuristicstat.py b/tools/heuristicstat.py new file mode 100644 index 000000000..efd9f8c87 --- /dev/null +++ b/tools/heuristicstat.py @@ -0,0 +1,167 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +import argparse +import json +from collections import Counter +import pandas as pd +## needed for 'to_markdown' method for pandas data frame +import tabulate + + +def getArgs(): + parser = argparse.ArgumentParser(description='Analyze compiler statistics') + parser.add_argument('--bytecode-bin-factor', dest='bytecodeBinFactor',default=10,help='Bytecode bin size as a multiple of 1000 (10 by default)') + parser.add_argument('--block-bin-factor', dest='blockBinFactor',default=1,help='Block bin size as a multiple of 1000 (1 by default)') + parser.add_argument('--block-instruction-bin-factor', dest='blockInstructionBinFactor',default=1,help='Block bin size as a multiple of 1000 (1 by default)') + parser.add_argument('statsFile', help='stats.json file generated by running luau-compile') + args = parser.parse_args() + return args + +def readStats(statsFile): + with open(statsFile) as f: + stats = json.load(f) + + scripts = [] + functionCounts = [] + bytecodeLengths = [] + blockPreOptCounts = [] + blockPostOptCounts = [] + maxBlockInstructionCounts = [] + + for path, fileStat in stats.items(): + scripts.append(path) + functionCounts.append(fileStat['lowerStats']['totalFunctions'] - fileStat['lowerStats']['skippedFunctions']) + bytecodeLengths.append(fileStat['bytecode']) + blockPreOptCounts.append(fileStat['lowerStats']['blocksPreOpt']) + blockPostOptCounts.append(fileStat['lowerStats']['blocksPostOpt']) + maxBlockInstructionCounts.append(fileStat['lowerStats']['maxBlockInstructions']) + + stats_df = pd.DataFrame({ + 'Script': scripts, + 'FunctionCount': functionCounts, + 'BytecodeLength': bytecodeLengths, + 'BlockPreOptCount': blockPreOptCounts, + 'BlockPostOptCount': blockPostOptCounts, + 'MaxBlockInstructionCount': maxBlockInstructionCounts + }) + + return stats_df + + +def analyzeBytecodeStats(stats_df, config): + binFactor = config.bytecodeBinFactor + divisor = binFactor * 1000 + totalScriptCount = len(stats_df.index) + + lengthLabels = [] + scriptCounts = [] + scriptPercs = [] + + counter = Counter() + + for index, row in stats_df.iterrows(): + value = row['BytecodeLength'] + factor = int(value / divisor) + counter[factor] += 1 + + for factor, scriptCount in sorted(counter.items()): + left = factor * binFactor + right = left + binFactor + lengthLabel = '{left}K-{right}K'.format(left=left, right=right) + lengthLabels.append(lengthLabel) + scriptCounts.append(scriptCount) + scriptPerc = round(scriptCount * 100 / totalScriptCount, 1) + scriptPercs.append(scriptPerc) + + bcode_df = pd.DataFrame({ + 'BytecodeLength': lengthLabels, + 'ScriptCount': scriptCounts, + 'ScriptPerc': scriptPercs + }) + + return bcode_df + + +def analyzeBlockStats(stats_df, config, field): + binFactor = config.blockBinFactor + divisor = binFactor * 1000 + totalScriptCount = len(stats_df.index) + + blockLabels = [] + scriptCounts = [] + scriptPercs = [] + + counter = Counter() + + for index, row in stats_df.iterrows(): + value = row[field] + factor = int(value / divisor) + counter[factor] += 1 + + for factor, scriptCount in sorted(counter.items()): + left = factor * binFactor + right = left + binFactor + blockLabel = '{left}K-{right}K'.format(left=left, right=right) + blockLabels.append(blockLabel) + scriptCounts.append(scriptCount) + scriptPerc = round((scriptCount * 100) / totalScriptCount, 1) + scriptPercs.append(scriptPerc) + + block_df = pd.DataFrame({ + field: blockLabels, + 'ScriptCount': scriptCounts, + 'ScriptPerc': scriptPercs + }) + + return block_df + +def analyzeMaxBlockInstructionStats(stats_df, config): + binFactor = config.blockInstructionBinFactor + divisor = binFactor * 1000 + totalScriptCount = len(stats_df.index) + + blockLabels = [] + scriptCounts = [] + scriptPercs = [] + + counter = Counter() + + for index, row in stats_df.iterrows(): + value = row['MaxBlockInstructionCount'] + factor = int(value / divisor) + counter[factor] += 1 + + for factor, scriptCount in sorted(counter.items()): + left = factor * binFactor + right = left + binFactor + blockLabel = '{left}K-{right}K'.format(left=left, right=right) + blockLabels.append(blockLabel) + scriptCounts.append(scriptCount) + scriptPerc = round((scriptCount * 100) / totalScriptCount, 1) + scriptPercs.append(scriptPerc) + + block_df = pd.DataFrame({ + 'MaxBlockInstructionCount': blockLabels, + 'ScriptCount': scriptCounts, + 'ScriptPerc': scriptPercs + }) + + return block_df + +if __name__ == '__main__': + config = getArgs() + + stats_df = readStats(config.statsFile) + + bcode_df = analyzeBytecodeStats(stats_df, config) + print(bcode_df.to_markdown()) + + block_df = analyzeBlockStats(stats_df, config, 'BlockPreOptCount') + print(block_df.to_markdown()) + + block_df = analyzeBlockStats(stats_df, config, 'BlockPostOptCount') + print(block_df.to_markdown()) + + block_df = analyzeMaxBlockInstructionStats(stats_df, config) + print(block_df.to_markdown()) diff --git a/tools/lldb_formatters.lldb b/tools/lldb_formatters.lldb index f10faa94e..d8e659e95 100644 --- a/tools/lldb_formatters.lldb +++ b/tools/lldb_formatters.lldb @@ -1,4 +1,10 @@ type synthetic add -x "^Luau::detail::DenseHashTable<.*>$" -l lldb_formatters.DenseHashTableSyntheticChildrenProvider +type synthetic add -x "^Luau::DenseHashMap<.*>$" -l lldb_formatters.DenseHashMapSyntheticChildrenProvider +type synthetic add -x "^Luau::DenseHashSet<.*>$" -l lldb_formatters.DenseHashSetSyntheticChildrenProvider + +type summary add -x "^Luau::DenseHashMap<.*>$" --summary-string "count = ${var.impl.count}" +type summary add -x "^Luau::DenseHashSet<.*>$" --summary-string "count = ${var.impl.count}" + type summary add "Luau::Symbol" -F lldb_formatters.luau_symbol_summary type synthetic add -x "^Luau::Variant<.+>$" -l lldb_formatters.LuauVariantSyntheticChildrenProvider @@ -6,5 +12,14 @@ type summary add -x "^Luau::Variant<.+>$" -F lldb_formatters.luau_variant_summar type synthetic add -x "^Luau::AstArray<.+>$" -l lldb_formatters.AstArraySyntheticChildrenProvider +type summary add -x "^Luau::NotNull<.+>$" --summary-string "${*var.ptr}" + type summary add --summary-string "${var.line}:${var.column}" Luau::Position type summary add --summary-string "${var.begin}-${var.end}" Luau::Location + +type summary add --summary-string "${var.ty} (${var%S})" Luau::TypeId Luau::TypePackId + +type summary add Luau::TypePath::Property -F lldb_formatters.luau_typepath_property_summary +type summary add --summary-string "[${var.index}]" Luau::TypePath::Index + +type summary add -x "^Luau::TryPair<.+>$" --summary-string "(${var.first%T}, ${var.second%T})" diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py index 19fc0f54c..7e957c75d 100644 --- a/tools/lldb_formatters.py +++ b/tools/lldb_formatters.py @@ -1,7 +1,11 @@ # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +import lldb + # HACK: LLDB's python API doesn't afford anything helpful for getting at variadic template parameters. # We're forced to resort to parsing names as strings. + + def templateParams(s): depth = 0 start = s.find("<") + 1 @@ -170,9 +174,153 @@ def has_children(self): return True +class DenseHashMapSyntheticChildrenProvider: + fixed_names = ["count", "capacity"] + max_expand_children = 100 + max_expand_capacity = 1000 + + def __init__(self, valobj, internal_dict): + self.valobj = valobj + self.count = 0 + self.capacity = 0 + + def num_children(self): + return min(self.max_expand_children, self.count) + len(self.fixed_names) + + def get_child_index(self, name): + try: + if name in self.fixed_names: + return self.fixed_names.index(name) + + return -1 + except Exception as e: + print("get_child_index exception", e, name) + return -1 + + def get_child_at_index(self, index): + try: + if index < len(self.fixed_names): + fixed_name = self.fixed_names[index] + impl_child = self.valobj.GetValueForExpressionPath( + f".impl.{fixed_name}") + + return self.valobj.CreateValueFromData(fixed_name, impl_child.GetData(), impl_child.GetType()) + else: + index -= len(self.fixed_names) + + empty_key_valobj = self.valobj.GetValueForExpressionPath( + f".impl.empty_key") + key_type = empty_key_valobj.GetType().GetCanonicalType().GetName() + skipped = 0 + + for slot in range(0, min(self.max_expand_capacity, self.capacity)): + slot_pair = self.valobj.GetValueForExpressionPath( + f".impl.data[{slot}]") + slot_key_valobj = slot_pair.GetChildMemberWithName("first") + + eq_test_valobj = self.valobj.EvaluateExpression( + f"*(reinterpret_cast({empty_key_valobj.AddressOf().GetValueAsUnsigned()})) == *(reinterpret_cast({slot_key_valobj.AddressOf().GetValueAsUnsigned()}))") + if eq_test_valobj.GetValue() == "true": + continue + + # Skip over previous occupied slots. + if index > skipped: + skipped += 1 + continue + + return self.valobj.CreateValueFromData(f"[{index}]", slot_pair.GetData(), slot_pair.GetType()) + + except Exception as e: + print("get_child_at_index error", e, index) + + def update(self): + try: + self.capacity = self.count = self.valobj.GetValueForExpressionPath( + ".impl.capacity").GetValueAsUnsigned() + self.count = self.valobj.GetValueForExpressionPath( + ".impl.count").GetValueAsUnsigned() + except Exception as e: + print("update error", e) + + def has_children(self): + return True + + +class DenseHashSetSyntheticChildrenProvider: + fixed_names = ["count", "capacity"] + max_expand_children = 100 + max_expand_capacity = 1000 + + def __init__(self, valobj, internal_dict): + self.valobj = valobj + self.count = 0 + self.capacity = 0 + + def num_children(self): + return min(self.max_expand_children, self.count) + len(self.fixed_names) + + def get_child_index(self, name): + try: + if name in self.fixed_names: + return self.fixed_names.index(name) + + return -1 + except Exception as e: + print("get_child_index exception", e, name) + return -1 + + def get_child_at_index(self, index): + try: + if index < len(self.fixed_names): + fixed_name = self.fixed_names[index] + impl_child = self.valobj.GetValueForExpressionPath( + f".impl.{fixed_name}") + + return self.valobj.CreateValueFromData(fixed_name, impl_child.GetData(), impl_child.GetType()) + else: + index -= len(self.fixed_names) + + empty_key_valobj = self.valobj.GetValueForExpressionPath( + f".impl.empty_key") + key_type = empty_key_valobj.GetType().GetCanonicalType().GetName() + skipped = 0 + + for slot in range(0, min(self.max_expand_capacity, self.capacity)): + slot_valobj = self.valobj.GetValueForExpressionPath( + f".impl.data[{slot}]") + + eq_test_valobj = self.valobj.EvaluateExpression( + f"*(reinterpret_cast({empty_key_valobj.AddressOf().GetValueAsUnsigned()})) == *(reinterpret_cast({slot_valobj.AddressOf().GetValueAsUnsigned()}))") + if eq_test_valobj.GetValue() == "true": + continue + + # Skip over previous occupied slots. + if index > skipped: + skipped += 1 + continue + + return self.valobj.CreateValueFromData(f"[{index}]", slot_valobj.GetData(), slot_valobj.GetType()) + + except Exception as e: + print("get_child_at_index error", e, index) + + def update(self): + try: + self.capacity = self.count = self.valobj.GetValueForExpressionPath( + ".impl.capacity").GetValueAsUnsigned() + self.count = self.valobj.GetValueForExpressionPath( + ".impl.count").GetValueAsUnsigned() + except Exception as e: + print("update error", e) + + def has_children(self): + return True + + def luau_symbol_summary(valobj, internal_dict, options): local = valobj.GetChildMemberWithName("local") - global_ = valobj.GetChildMemberWithName("global").GetChildMemberWithName("value") + global_ = valobj.GetChildMemberWithName( + "global").GetChildMemberWithName("value") if local.GetValueAsUnsigned() != 0: return f'local {local.GetChildMemberWithName("name").GetChildMemberWithName("value").GetSummary()}' @@ -209,7 +357,33 @@ def get_child_at_index(self, index): print("get_child_index error:", e) def update(self): - self.size = self.valobj.GetChildMemberWithName("size").GetValueAsUnsigned() + self.size = self.valobj.GetChildMemberWithName( + "size").GetValueAsUnsigned() def has_children(self): return True + + +def luau_typepath_property_summary(valobj, internal_dict, options): + name = valobj.GetChildMemberWithName("name").GetSummary() + result = "[" + + read_write = False + try: + fflag_valobj = valobj.GetFrame().GetValueForVariablePath( + "FFlag::LuauSolverV2::value") + + read_write = fflag_valobj.GetValue() == "true" + except Exception as e: + print("luau_typepath_property_summary error:", e) + + if read_write: + is_read = valobj.GetChildMemberWithName("isRead").GetValue() == "true" + if is_read: + result += "read " + else: + result += "write " + + result += name + result += "]" + return result diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py deleted file mode 100644 index 16de45dcc..000000000 --- a/tools/lvmexecute_split.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/python3 -# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - -# This code can be used to split lvmexecute.cpp VM switch into separate functions for use as native code generation fallbacks -import sys -import re - -input = sys.stdin.readlines() - -inst = "" - -header = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#pragma once - -#include - -struct lua_State; -struct Closure; -typedef uint32_t Instruction; -typedef struct lua_TValue TValue; -typedef TValue* StkId; - -""" - -source = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#include "Fallbacks.h" -#include "FallbacksProlog.h" - -""" - -function = "" -signature = "" - -includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"] - -state = 0 - -# parse with the state machine -for line in input: - # find the start of an instruction - if state == 0: - match = re.match("\s+VM_CASE\((LOP_[A-Z_0-9]+)\)", line) - - if match: - inst = match[1] - signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" - function = signature + "\n" - function += "{\n" - function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" - state = 1 - - # first line of the instruction which is "{" - elif state == 1: - assert(line == " {\n") - state = 2 - - # find the end of an instruction - elif state == 2: - # remove jumps back into the native code - if line == "#if LUA_CUSTOM_EXECUTION\n": - state = 3 - continue - - if line[0] == ' ': - finalline = line[12:-1] + "\n" - else: - finalline = line - - finalline = finalline.replace("VM_NEXT();", "return pc;"); - finalline = finalline.replace("goto exit;", "return NULL;"); - finalline = finalline.replace("return;", "return NULL;"); - - function += finalline - match = re.match(" }", line) - - if match: - # break is not supported - if inst == "LOP_BREAK": - function = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)\n" - function += "{\n LUAU_ASSERT(!\"Unsupported deprecated opcode\");\n LUAU_UNREACHABLE();\n}\n" - # handle fallthrough - elif inst == "LOP_NAMECALL": - function = function[:-len(finalline)] - function += " return pc;\n}\n" - - if inst in includeInsts: - header += signature + ";\n" - source += function + "\n" - - state = 0 - - # skip LUA_CUSTOM_EXECUTION code blocks - elif state == 3: - if line == "#endif\n": - state = 4 - continue - - # skip extra line - elif state == 4: - state = 2 - -# make sure we found the ending -assert(state == 0) - -with open("Fallbacks.h", "w") as fp: - fp.writelines(header) - -with open("Fallbacks.cpp", "w") as fp: - fp.writelines(source) diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis index ca66cbe2c..74cc18fec 100644 --- a/tools/natvis/Analysis.natvis +++ b/tools/natvis/Analysis.natvis @@ -6,70 +6,70 @@ - {{ typeId=0, value={*($T1*)storage} }} - {{ typeId=1, value={*($T2*)storage} }} - {{ typeId=2, value={*($T3*)storage} }} - {{ typeId=3, value={*($T4*)storage} }} - {{ typeId=4, value={*($T5*)storage} }} - {{ typeId=5, value={*($T6*)storage} }} - {{ typeId=6, value={*($T7*)storage} }} - {{ typeId=7, value={*($T8*)storage} }} - {{ typeId=8, value={*($T9*)storage} }} - {{ typeId=9, value={*($T10*)storage} }} - {{ typeId=10, value={*($T11*)storage} }} - {{ typeId=11, value={*($T12*)storage} }} - {{ typeId=12, value={*($T13*)storage} }} - {{ typeId=13, value={*($T14*)storage} }} - {{ typeId=14, value={*($T15*)storage} }} - {{ typeId=15, value={*($T16*)storage} }} - {{ typeId=16, value={*($T17*)storage} }} - {{ typeId=17, value={*($T18*)storage} }} - {{ typeId=18, value={*($T19*)storage} }} - {{ typeId=19, value={*($T20*)storage} }} - {{ typeId=20, value={*($T21*)storage} }} - {{ typeId=21, value={*($T22*)storage} }} - {{ typeId=22, value={*($T23*)storage} }} - {{ typeId=23, value={*($T24*)storage} }} - {{ typeId=24, value={*($T25*)storage} }} - {{ typeId=25, value={*($T26*)storage} }} - {{ typeId=26, value={*($T27*)storage} }} - {{ typeId=27, value={*($T28*)storage} }} - {{ typeId=28, value={*($T29*)storage} }} - {{ typeId=29, value={*($T30*)storage} }} - {{ typeId=30, value={*($T31*)storage} }} - {{ typeId=31, value={*($T32*)storage} }} - {{ typeId=32, value={*($T33*)storage} }} - {{ typeId=33, value={*($T34*)storage} }} - {{ typeId=34, value={*($T35*)storage} }} - {{ typeId=35, value={*($T36*)storage} }} - {{ typeId=36, value={*($T37*)storage} }} - {{ typeId=37, value={*($T38*)storage} }} - {{ typeId=38, value={*($T39*)storage} }} - {{ typeId=39, value={*($T40*)storage} }} - {{ typeId=40, value={*($T41*)storage} }} - {{ typeId=41, value={*($T42*)storage} }} - {{ typeId=42, value={*($T43*)storage} }} - {{ typeId=43, value={*($T44*)storage} }} - {{ typeId=44, value={*($T45*)storage} }} - {{ typeId=45, value={*($T46*)storage} }} - {{ typeId=46, value={*($T47*)storage} }} - {{ typeId=47, value={*($T48*)storage} }} - {{ typeId=48, value={*($T49*)storage} }} - {{ typeId=49, value={*($T50*)storage} }} - {{ typeId=50, value={*($T51*)storage} }} - {{ typeId=51, value={*($T52*)storage} }} - {{ typeId=52, value={*($T53*)storage} }} - {{ typeId=53, value={*($T54*)storage} }} - {{ typeId=54, value={*($T55*)storage} }} - {{ typeId=55, value={*($T56*)storage} }} - {{ typeId=56, value={*($T57*)storage} }} - {{ typeId=57, value={*($T58*)storage} }} - {{ typeId=58, value={*($T59*)storage} }} - {{ typeId=59, value={*($T60*)storage} }} - {{ typeId=60, value={*($T61*)storage} }} - {{ typeId=61, value={*($T62*)storage} }} - {{ typeId=62, value={*($T63*)storage} }} - {{ typeId=63, value={*($T64*)storage} }} + {{ {"$T1"}: {*($T1*)storage} }} + {{ {"$T2"}: {*($T2*)storage} }} + {{ {"$T3"}: {*($T3*)storage} }} + {{ {"$T4"}: {*($T4*)storage} }} + {{ {"$T5"}: {*($T5*)storage} }} + {{ {"$T6"}: {*($T6*)storage} }} + {{ {"$T7"}: {*($T7*)storage} }} + {{ {"$T8"}: {*($T8*)storage} }} + {{ {"$T9"}: {*($T9*)storage} }} + {{ {"$T10"}: {*($T10*)storage} }} + {{ {"$T11"}: {*($T11*)storage} }} + {{ {"$T12"}: {*($T12*)storage} }} + {{ {"$T13"}: {*($T13*)storage} }} + {{ {"$T14"}: {*($T14*)storage} }} + {{ {"$T15"}: {*($T15*)storage} }} + {{ {"$T16"}: {*($T16*)storage} }} + {{ {"$T17"}: {*($T17*)storage} }} + {{ {"$T18"}: {*($T18*)storage} }} + {{ {"$T19"}: {*($T19*)storage} }} + {{ {"$T20"}: {*($T20*)storage} }} + {{ {"$T21"}: {*($T21*)storage} }} + {{ {"$T22"}: {*($T22*)storage} }} + {{ {"$T23"}: {*($T23*)storage} }} + {{ {"$T24"}: {*($T24*)storage} }} + {{ {"$T25"}: {*($T25*)storage} }} + {{ {"$T26"}: {*($T26*)storage} }} + {{ {"$T27"}: {*($T27*)storage} }} + {{ {"$T28"}: {*($T28*)storage} }} + {{ {"$T29"}: {*($T29*)storage} }} + {{ {"$T30"}: {*($T30*)storage} }} + {{ {"$T31"}: {*($T31*)storage} }} + {{ {"$T32"}: {*($T32*)storage} }} + {{ {"$T33"}: {*($T33*)storage} }} + {{ {"$T34"}: {*($T34*)storage} }} + {{ {"$T35"}: {*($T35*)storage} }} + {{ {"$T36"}: {*($T36*)storage} }} + {{ {"$T37"}: {*($T37*)storage} }} + {{ {"$T38"}: {*($T38*)storage} }} + {{ {"$T39"}: {*($T39*)storage} }} + {{ {"$T40"}: {*($T40*)storage} }} + {{ {"$T41"}: {*($T41*)storage} }} + {{ {"$T42"}: {*($T42*)storage} }} + {{ {"$T43"}: {*($T43*)storage} }} + {{ {"$T44"}: {*($T44*)storage} }} + {{ {"$T45"}: {*($T45*)storage} }} + {{ {"$T46"}: {*($T46*)storage} }} + {{ {"$T47"}: {*($T47*)storage} }} + {{ {"$T48"}: {*($T48*)storage} }} + {{ {"$T49"}: {*($T49*)storage} }} + {{ {"$T50"}: {*($T50*)storage} }} + {{ {"$T51"}: {*($T51*)storage} }} + {{ {"$T52"}: {*($T52*)storage} }} + {{ {"$T53"}: {*($T53*)storage} }} + {{ {"$T54"}: {*($T54*)storage} }} + {{ {"$T55"}: {*($T55*)storage} }} + {{ {"$T56"}: {*($T56*)storage} }} + {{ {"$T57"}: {*($T57*)storage} }} + {{ {"$T58"}: {*($T58*)storage} }} + {{ {"$T59"}: {*($T59*)storage} }} + {{ {"$T60"}: {*($T60*)storage} }} + {{ {"$T61"}: {*($T61*)storage} }} + {{ {"$T62"}: {*($T62*)storage} }} + {{ {"$T63"}: {*($T63*)storage} }} + {{ {"$T64"}: {*($T64*)storage} }} typeId *($T1*)storage diff --git a/tools/natvis/CodeGen.natvis b/tools/natvis/CodeGen.natvis index 5ff6e1432..2c34b4862 100644 --- a/tools/natvis/CodeGen.natvis +++ b/tools/natvis/CodeGen.natvis @@ -1,45 +1,46 @@ - - noreg - rip + + noreg + rip - al - cl - dl - bl + al + cl + dl + bl - eax - ecx - edx - ebx - esp - ebp - esi - edi - e{(int)index,d}d + eax + ecx + edx + ebx + esp + ebp + esi + edi + e{(int)index,d}d - rax - rcx - rdx - rbx - rsp - rbp - rsi - rdi - r{(int)index,d} + rax + rcx + rdx + rbx + rsp + rbp + rsi + rdi + r{(int)index,d} - xmm{(int)index,d} + xmm{(int)index,d} - ymm{(int)index,d} + ymm{(int)index,d} - + {base} {memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}] {memSize,en} ptr[{index}*{(int)scale,d} + {imm}] {memSize,en} ptr[{base} + {imm}] + {memSize,en} ptr[{base} + {imm}] {memSize,en} ptr[{imm}] {imm} @@ -53,4 +54,13 @@ + + none + R{index&0xff}-v{index >> 8} + R{index&0xff} + K{index} + UP{index} + %{index} + + diff --git a/tools/natvis/Common.natvis b/tools/natvis/Common.natvis new file mode 100644 index 000000000..fe3a96d59 --- /dev/null +++ b/tools/natvis/Common.natvis @@ -0,0 +1,27 @@ + + + + + + count + capacity + + capacity + data + + + + + + + impl + + + + + + impl + + + + diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index e45e4e283..59bc43c46 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -1,10 +1,10 @@ - + nil {(bool)value.b} - lightuserdata {value.p} + lightuserdata {(uintptr_t)value.p,h} tag: {extra[0]} number = {value.n} vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} {value.gc->ts} @@ -12,6 +12,7 @@ function {value.gc->cl,view(short)} userdata {value.gc->u} thread {value.gc->th} + buffer {value.gc->buf} size={value.gc->buf.len} proto {value.gc->p} upvalue {value.gc->uv} deadkey @@ -24,8 +25,10 @@ value.gc->cl value.gc->u value.gc->th + value.gc->buf value.gc->p value.gc->uv + extra[0] fixed ({(int)value.gc->gch.marked}) black ({(int)value.gc->gch.marked}) @@ -36,10 +39,10 @@ - + nil {(bool)value.b} - lightuserdata {value.p} + lightuserdata {(uintptr_t)value.p,h} tag: {extra[0]} number = {value.n} vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} {value.gc->ts} @@ -47,6 +50,7 @@ function {value.gc->cl,view(short)} userdata {value.gc->u} thread {value.gc->th} + buffer {value.gc->buf} size={value.gc->buf.len} proto {value.gc->p} upvalue {value.gc->uv} deadkey @@ -59,19 +63,21 @@ value.gc->cl value.gc->u value.gc->th + value.gc->buf value.gc->p value.gc->uv + extra[0] next - + {key,na} = {val} --- - + table metatable @@ -97,7 +103,7 @@ - + @@ -125,7 +131,7 @@ - + {c.f,na} {l.p,na} {c} @@ -133,11 +139,11 @@ invalid - + {data,s} - + @@ -157,7 +163,7 @@ =[C] {cl().c.f,na} - + {ci,na} thread @@ -205,7 +211,7 @@ - + {source->data,sb}:{linedefined} function {debugname->data,sb} [{(int)numparams} arg, {(int)nups} upval] {source->data,sb}:{linedefined} [{(int)numparams} arg, {(int)nups} upval] @@ -260,7 +266,7 @@ - + {(lua_Type)tt} diff --git a/tools/perfgraph.py b/tools/perfgraph.py index 94c57cc77..1f6ecc2f1 100644 --- a/tools/perfgraph.py +++ b/tools/perfgraph.py @@ -60,29 +60,35 @@ def getDuration(nodes, nid): node = nodes[nid - 1] total = node['TotalDuration'] - for cid in node['NodeIds']: - total -= nodes[cid - 1]['TotalDuration'] + if 'NodeIds' in node: + for cid in node['NodeIds']: + total -= nodes[cid - 1]['TotalDuration'] return total def getFunctionKey(fn): - return fn['Source'] + "," + fn['Name'] + "," + str(fn['Line']) + source = fn['Source'] if 'Source' in fn else '' + name = fn['Name'] if 'Name' in fn else '' + line = str(fn['Line']) if 'Line' in fn else '-1' + + return source + "," + name + "," + line def recursivelyBuildNodeTree(nodes, functions, parent, fid, nid): ninfo = nodes[nid - 1] finfo = functions[fid - 1] child = parent.child(getFunctionKey(finfo)) - child.source = finfo['Source'] - child.function = finfo['Name'] - child.line = int(finfo['Line']) if finfo['Line'] > 0 else 0 + child.source = finfo['Source'] if 'Source' in finfo else '' + child.function = finfo['Name'] if 'Name' in finfo else '' + child.line = int(finfo['Line']) if 'Line' in finfo and finfo['Line'] > 0 else 0 child.ticks = getDuration(nodes, nid) - assert(len(ninfo['FunctionIds']) == len(ninfo['NodeIds'])) + if 'FunctionIds' in ninfo: + assert(len(ninfo['FunctionIds']) == len(ninfo['NodeIds'])) - for i in range(0, len(ninfo['FunctionIds'])): - recursivelyBuildNodeTree(nodes, functions, child, ninfo['FunctionIds'][i], ninfo['NodeIds'][i]) + for i in range(0, len(ninfo['FunctionIds'])): + recursivelyBuildNodeTree(nodes, functions, child, ninfo['FunctionIds'][i], ninfo['NodeIds'][i]) return @@ -104,10 +110,11 @@ def nodeFromJSONV2(dump): child.function = name child.ticks = getDuration(nodes, nid) - assert(len(node['FunctionIds']) == len(node['NodeIds'])) + if 'FunctionIds' in node: + assert(len(node['FunctionIds']) == len(node['NodeIds'])) - for i in range(0, len(node['FunctionIds'])): - recursivelyBuildNodeTree(nodes, functions, child, node['FunctionIds'][i], node['NodeIds'][i]) + for i in range(0, len(node['FunctionIds'])): + recursivelyBuildNodeTree(nodes, functions, child, node['FunctionIds'][i], node['NodeIds'][i]) return root diff --git a/tools/stackdbg.py b/tools/stackdbg.py new file mode 100644 index 000000000..de656c607 --- /dev/null +++ b/tools/stackdbg.py @@ -0,0 +1,94 @@ +#!usr/bin/python3 +""" +To use this command, simply run the command: +`command script import /path/to/your/game-engine/Client/Luau/tools/stackdbg.py` +in the `lldb` interpreter. You can also add it to your .lldbinit file to have it be +automatically imported. + +If using vscode, you can add the above command to your launch.json under `preRunCommands` for the appropriate target. For example: +{ + "name": "Luau.UnitTest", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build/ninja/common-tests/noopt/Luau/Luau.UnitTest", + "preRunCommands": [ + "command script import ${workspaceFolder}/Client/Luau/tools/stackdbg.py" + ], +} + +Once this is loaded, +`(lldb) help stack` +or +`(lldb) stack -h +or +`(lldb) stack --help + +can get you started +""" + +import lldb +import functools +import argparse +import shlex + +# Dumps the collected frame data +def dump(collected): + for (frame_name, size_in_kb, live_size_kb, variables) in collected: + print(f'{frame_name}, locals: {size_in_kb}kb, fp-sp: {live_size_kb}kb') + for (var_name, var_size, variable_obj) in variables: + print(f' {var_name}, {var_size} bytes') + +def dbg_stack_pressure(frame, frames_to_show = 5, sort_frames = False, vars_to_show = 5, sort_vars = True): + totalKb = 0 + collect = [] + for f in frame.thread: + frame_name = f.GetFunctionName() + variables = [ (v.GetName(), v.GetByteSize(), v) for v in f.get_locals() ] + if sort_vars: + variables.sort(key = lambda x: x[1], reverse = True) + size_in_kb = functools.reduce(lambda x,y : x + y[1], variables, 0) / 1024 + + fp = f.GetFP() + sp = f.GetSP() + live_size_kb = round((fp - sp) / 1024, 2) + + size_in_kb = round(size_in_kb, 2) + totalKb += size_in_kb + collect.append((frame_name, size_in_kb, live_size_kb, variables[:vars_to_show])) + if sort_frames: + collect.sort(key = lambda x: x[1], reverse = True) + + print("******************** Report Stack Usage ********************") + totalMb = round(totalKb / 1024, 2) + print(f'{len(frame.thread)} stack frames used {totalMb}MB') + dump(collect[:frames_to_show]) + +def stack(debugger, command, result, internal_dict): + """ + usage: [-h] [-f FRAMES] [-fd] [-v VARS] [-vd] + + optional arguments: + -h, --help show this help message and exit + -f FRAMES, --frames FRAMES + How many stack frames to display + -fd, --sort_frames Sort frames + -v VARS, --vars VARS How many variables per frame to display + -vd, --sort_vars Sort frames + """ + + frame = debugger.GetSelectedTarget().GetProcess().GetSelectedThread().GetSelectedFrame() + args = shlex.split(command) + argparser = argparse.ArgumentParser(allow_abbrev = True) + argparser.add_argument("-f", "--frames", required=False, help="How many stack frames to display", default=5, type=int) + argparser.add_argument("-fd", "--sort_frames", required=False, help="Sort frames in descending order of stack usage", action="store_true", default=False) + argparser.add_argument("-v", "--vars", required=False, help="How many variables per frame to display", default=5, type=int) + argparser.add_argument("-vd", "--sort_vars", required=False, help="Sort locals in descending order of stack usage ", action="store_true", default=False) + + args = argparser.parse_args(args) + dbg_stack_pressure(frame, frames_to_show=args.frames, sort_frames=args.sort_frames, vars_to_show=args.vars, sort_vars=args.sort_vars) + +# Initialization code to add commands +def __lldb_init_module(debugger, internal_dict): + debugger.HandleCommand('command script add -f stackdbg.stack stack') + print("The 'stack' python command has been installed and is ready for use.") + diff --git a/tools/test_dcr.py b/tools/test_dcr.py index d30490b30..3de92b375 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -6,9 +6,17 @@ import subprocess as sp import sys import xml.sax as x -import colorama as c -c.init() +try: + import colorama as c +except ImportError: + class c: + class Fore: + RED='' + RESET='' + GREEN='' +else: + c.init() SCRIPT_PATH = os.path.split(sys.argv[0])[0] FAIL_LIST_PATH = os.path.join(SCRIPT_PATH, "faillist.txt") @@ -107,6 +115,12 @@ def main(): action="store_true", help="Write a new faillist.txt after running tests.", ) + parser.add_argument( + "--ts", + dest="suite", + action="store", + help="Only run a specific suite." + ) parser.add_argument("--randomize", action="store_true", help="Pick a random seed") @@ -122,17 +136,18 @@ def main(): failList = loadFailList() - commandLine = [ - args.path, - "--reporters=xml", - "--fflags=true,DebugLuauDeferredConstraintResolution=true", - ] + flags = ["true", "LuauSolverV2"] + + commandLine = [args.path, "--reporters=xml", "--fflags=" + ",".join(flags)] if args.random_seed: commandLine.append("--random-seed=" + str(args.random_seed)) elif args.randomize: commandLine.append("--randomize") + if args.suite: + commandLine.append(f'--ts={args.suite}') + print_stderr(">", " ".join(commandLine)) p = sp.Popen( @@ -140,6 +155,8 @@ def main(): stdout=sp.PIPE, ) + assert p.stdout + handler = Handler(failList) if args.dump: