diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..ca5ce76041 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: ❓ Ask a question + url: https://github.com/open-policy-agent/feedback/discussions + about: Community Support Forum \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index d12e73bf32..96b1e0ee0f 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -10,11 +10,11 @@ assignees: '' -## What part of OPA would you like to see improved? +## What is the underlying problem you're trying to solve? ## Describe the ideal solution diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2661ab4292..48759a3b76 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -27,3 +27,27 @@ For more information on contributing to OPA see: for high-level contributing guidelines and development setup. --> + +### Why are the changes in this PR needed? + + + +### What are the changes in this PR? + + + +### Notes to assist PR review: + + + +### Further comments: + + diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 713957a516..cbe78421db 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -30,20 +30,20 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - id: go_version name: Read go version - run: echo "::set-output name=go_version::$(cat .go-version)" + run: echo "go_version=$(cat .go-version)" >> $GITHUB_OUTPUT - name: Install Go (${{ steps.go_version.outputs.go_version }}) - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: ${{ steps.go_version.outputs.go_version }} # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v1 + uses: github/codeql-action/init@v2 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. 
@@ -59,4 +59,4 @@ jobs: make build - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/nightly.yaml b/.github/workflows/nightly.yaml index d6181a02eb..1655af9c63 100644 --- a/.github/workflows/nightly.yaml +++ b/.github/workflows/nightly.yaml @@ -8,10 +8,10 @@ on: jobs: race-detector: name: Go Race Detector - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Test with Race Detector run: CGO_ENABLED=1 make ci-go-race-detector @@ -27,28 +27,22 @@ jobs: native-fuzzer: name: Go Fuzzer (native) - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - id: go_version name: Read go version - run: echo "::set-output name=go_version::$(cat .go-version)" + run: echo "go_version=$(cat .go-version)" >> $GITHUB_OUTPUT - name: Install Go (${{ steps.go_version.outputs.go_version }}) - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: ${{ steps.go_version.outputs.go_version }} - - name: Install gotip - run: | - go install golang.org/dl/gotip@latest - gotip download - gotip version - - - name: gotip test -fuzz - run: gotip test ./ast -fuzz FuzzParseStatementsAndCompileModules -fuzztime 1h -v -run '^$' + - name: go test -fuzz + run: go test ./ast -fuzz FuzzParseStatementsAndCompileModules -fuzztime 1h -v -run '^$' - name: Dump crashers if: ${{ failure() }} @@ -63,26 +57,18 @@ jobs: status: ${{ job.status }} fields: repo,workflow - fuzzer: - name: Go Fuzzer - runs-on: ubuntu-latest + go-perf: + name: Go Perf + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 - - - name: Run go-fuzz - run: make ci-go-check-fuzz + uses: actions/checkout@v3 - - name: Dump crashers - if: ${{ failure() }} - run: find build/fuzzer/workdir/crashers -name '*.quoted' -print -exec cat {} \; - - - name: Upload Workdir - if: ${{ failure() }} - uses: actions/upload-artifact@v2 - with: - name: workdir - path: ./build/fuzzer/workdir + - name: Benchmark Test Golang + run: make ci-go-perf + timeout-minutes: 30 + env: + DOCKER_RUNNING: 0 - name: Slack Notification uses: 8398a7/action-slack@v3 @@ -95,10 +81,10 @@ jobs: go-proxy-check: name: Go mod check - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Vendor without proxy run: make check-go-module @@ -112,3 +98,87 @@ jobs: with: status: ${{ job.status }} fields: repo,workflow + + trivy-scan-image: + name: Trivy security scan image + runs-on: ubuntu-22.04 + steps: + - name: Checkout code # needed for .trivyignore file + uses: actions/checkout@v3 + + - run: "docker pull openpolicyagent/opa:edge-static" + + # Equivalent to: + # $ trivy image openpolicyagent/opa:edge-static + - name: Run Trivy scan on image + uses: aquasecurity/trivy-action@0.10.0 + with: + image-ref: 'openpolicyagent/opa:edge-static' + format: table + exit-code: '1' + ignore-unfixed: true + vuln-type: os,library + severity: CRITICAL,HIGH + + - name: Slack Notification + uses: 8398a7/action-slack@v3 + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_NOTIFICATION_WEBHOOK }} + if: ${{ failure() && env.SLACK_WEBHOOK_URL }} + with: + status: ${{ job.status }} + fields: repo,workflow + + trivy-scan-repo: + name: Trivy security scan repo + runs-on: ubuntu-22.04 + steps: + - name: Checkout code + uses: actions/checkout@v3 + + # Equivalent to: 
+ # $ trivy fs . + - name: Run Trivy scan on repo + uses: aquasecurity/trivy-action@0.10.0 + with: + scan-type: fs + format: table + exit-code: '1' + ignore-unfixed: true + skip-dirs: vendor/,internal/gqlparser/validator/imported/ + severity: CRITICAL,HIGH + + - name: Slack Notification + uses: 8398a7/action-slack@v3 + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_NOTIFICATION_WEBHOOK }} + if: ${{ failure() && env.SLACK_WEBHOOK_URL }} + with: + status: ${{ job.status }} + fields: repo,workflow + + govulncheck: + name: Go vulnerability check + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - id: go_version + name: Read go version + run: echo "go_version=$(cat .go-version)" >> $GITHUB_OUTPUT + + - name: Install Go (${{ steps.go_version.outputs.go_version }}) + uses: actions/setup-go@v4 + with: + go-version: ${{ steps.go_version.outputs.go_version }} + + - run: go install golang.org/x/vuln/cmd/govulncheck@latest + - run: govulncheck ./... + + - name: Slack Notification + uses: 8398a7/action-slack@v3 + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_NOTIFICATION_WEBHOOK }} + if: ${{ failure() && env.SLACK_WEBHOOK_URL }} + with: + status: ${{ job.status }} + fields: repo,workflow diff --git a/.github/workflows/post-merge.yaml b/.github/workflows/post-merge.yaml index 89cb3ea92b..1acb7f5d05 100644 --- a/.github/workflows/post-merge.yaml +++ b/.github/workflows/post-merge.yaml @@ -7,16 +7,16 @@ on: jobs: generate: - name: Sync Generated Code - runs-on: ubuntu-18.04 + name: Sync Generated Code and Docs + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: - token: ${{ secrets.GH_PUSH_TOKEN }} + token: ${{ secrets.GH_PUSH_TOKEN }} # required to push to protected branch below - name: Generate - run: make clean generate + run: make clean generate docs-generate-cli-docs - name: Commit & Push shell: bash @@ -45,13 +45,34 @@ jobs: echo "No generated changes to push!" fi + AUTHOR=cli-docs-updater + git config user.name ${AUTHOR} + git config user.email ${AUTHOR}@github.com + + # Prevent looping if the build was non-deterministic.. + CAN_PUSH=1 + if [[ "$(git log -1 --pretty=format:'%an')" == "${AUTHOR}" ]]; then + CAN_PUSH=0 + fi + + if ./build/commit-cli-docs.sh; then + if [[ "${CAN_PUSH}" == "1" ]]; then + git push + else + echo "Previous commit was auto-generated -- Aborting!" + exit 1 + fi + else + echo "No generated changes to push!" 
+ fi + code-coverage: name: Update Go Test Coverage - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Unit Test Golang run: make ci-go-test-coverage @@ -59,11 +80,11 @@ jobs: release-build: name: Release Build (linux, windows) - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Build Linux and Windows run: make ci-go-ci-build-linux ci-go-ci-build-linux-static ci-go-ci-build-windows @@ -71,8 +92,15 @@ jobs: env: TELEMETRY_URL: ${{ secrets.TELEMETRY_URL }} + - name: Build Linux arm64 + run: make ci-go-ci-build-linux-static + timeout-minutes: 30 + env: + GOARCH: arm64 + TELEMETRY_URL: ${{ secrets.TELEMETRY_URL }} + - name: Upload binaries - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 if: always() with: name: binaries @@ -84,14 +112,14 @@ jobs: needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - id: go_version name: Read go version - run: echo "::set-output name=go_version::$(cat .go-version)" + run: echo "go_version=$(cat .go-version)" >> $GITHUB_OUTPUT - name: Install Go (${{ steps.go_version.outputs.go_version }}) - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: ${{ steps.go_version.outputs.go_version }} @@ -102,7 +130,7 @@ jobs: TELEMETRY_URL: ${{ secrets.TELEMETRY_URL }} - name: Upload binaries (darwin) - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 if: always() with: name: binaries @@ -110,22 +138,25 @@ jobs: deploy-edge: name: Push Edge Release - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 needs: [release-build, release-build-darwin] steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Test run: make ci-release-test timeout-minutes: 60 - name: Download release binaries - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: binaries path: _release + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Deploy OPA Edge env: DOCKER_USER: ${{ secrets.DOCKER_USER }} @@ -140,11 +171,11 @@ jobs: deploy-wasm-builder: name: Deploy WASM Builder - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Build and Push opa-wasm-builder env: diff --git a/.github/workflows/post-release.yaml b/.github/workflows/post-release.yaml index 8705e12cec..942f778a4c 100644 --- a/.github/workflows/post-release.yaml +++ b/.github/workflows/post-release.yaml @@ -4,14 +4,17 @@ on: release: types: [published] +permissions: + contents: none + jobs: kick-netlify: name: Kick Netlify - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - name: Trigger Netlify Deploy env: NETLIFY_BUILD_HOOK_URL: ${{ secrets.NETLIFY_BUILD_HOOK_URL }} if: ${{ env.NETLIFY_BUILD_HOOK_URL }} run: | - curl --fail --request POST -d {} ${{ env.NETLIFY_BUILD_HOOK_URL }} \ No newline at end of file + curl --fail --request POST -d {} ${{ env.NETLIFY_BUILD_HOOK_URL }} diff --git a/.github/workflows/post-tag.yaml b/.github/workflows/post-tag.yaml index c31aa58218..a73102d6da 100644 --- a/.github/workflows/post-tag.yaml +++ b/.github/workflows/post-tag.yaml @@ -8,10 +8,10 @@ on: jobs: generate: name: Generate Code - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + 
uses: actions/checkout@v3 with: token: ${{ secrets.GH_PUSH_TOKEN }} @@ -20,11 +20,11 @@ jobs: release-build: name: Release Build (linux, windows) - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Build Linux and Windows run: make ci-go-ci-build-linux ci-go-ci-build-linux-static ci-go-ci-build-windows @@ -32,8 +32,15 @@ env: TELEMETRY_URL: ${{ secrets.TELEMETRY_URL }} + - name: Build Linux arm64 + run: make ci-go-ci-build-linux-static + timeout-minutes: 30 + env: + GOARCH: arm64 + TELEMETRY_URL: ${{ secrets.TELEMETRY_URL }} + - name: Upload binaries - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 if: always() with: name: binaries @@ -45,25 +52,25 @@ needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - id: go_version name: Read go version - run: echo "::set-output name=go_version::$(cat .go-version)" + run: echo "go_version=$(cat .go-version)" >> $GITHUB_OUTPUT - name: Install Go (${{ steps.go_version.outputs.go_version }}) - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: ${{ steps.go_version.outputs.go_version }} - name: Build Darwin - run: make ci-build-darwin + run: make ci-build-darwin ci-build-darwin-arm64-static timeout-minutes: 30 env: TELEMETRY_URL: ${{ secrets.TELEMETRY_URL }} - name: Upload binaries (darwin) - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 if: always() with: name: binaries @@ -72,21 +79,24 @@ build: name: Push Latest Release needs: [release-build, release-build-darwin] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Set TAG_NAME in Environment # Subsequent jobs will have the computed tag name run: echo "TAG_NAME=${GITHUB_REF##*/}" >> $GITHUB_ENV - name: Download release binaries - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: binaries path: _release + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Build and Deploy OPA Docker Images id: build-and-deploy env: diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index c8921e4e49..12ac7377a0 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -2,21 +2,27 @@ name: PR Check on: [pull_request] +# When a new revision is pushed to a PR, cancel all in-progress CI runs for that +# PR. See https://docs.github.com/en/actions/using-jobs/using-concurrency +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: # All jobs essentially re-create the `ci-release-test` make target, but are split # up for parallel runners for faster PR feedback and a nicer UX. 
generate: name: Generate Code - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Generate run: make clean generate - name: Upload generated artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: generated path: | @@ -24,7 +30,7 @@ jobs: capabilities.json go-build: - name: Go Build (${{ matrix.os }}) + name: Go Build (${{ matrix.os }}${{ matrix.arch && format(' {0}', matrix.arch) || '' }}) runs-on: ${{ matrix.run }} needs: generate strategy: @@ -32,39 +38,46 @@ jobs: matrix: include: - os: linux - run: ubuntu-18.04 + run: ubuntu-22.04 targets: ci-go-ci-build-linux ci-go-ci-build-linux-static + arch: amd64 + - os: linux + run: ubuntu-22.04 + targets: ci-go-ci-build-linux-static + arch: arm64 - os: windows - run: ubuntu-18.04 + run: ubuntu-22.04 targets: ci-go-ci-build-windows - os: darwin run: macos-latest targets: ci-build-darwin ci-build-darwin-arm64-static steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - id: go_version name: Read go version - run: echo "::set-output name=go_version::$(cat .go-version)" + run: echo "go_version=$(cat .go-version)" >> $GITHUB_OUTPUT - name: Install Go (${{ steps.go_version.outputs.go_version }}) - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: ${{ steps.go_version.outputs.go_version }} if: matrix.os == 'darwin' - name: Download generated artifacts - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: generated - name: Build run: make ${{ matrix.targets }} + env: + GOARCH: ${{ matrix.arch }} timeout-minutes: 30 - name: Upload binaries - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 if: always() with: name: binaries @@ -79,24 +92,24 @@ jobs: matrix: include: - os: linux - run: ubuntu-18.04 + run: ubuntu-22.04 - os: darwin run: macos-latest steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - id: go_version name: Read go version - run: echo "::set-output name=go_version::$(cat .go-version)" + run: echo "go_version=$(cat .go-version)" >> $GITHUB_OUTPUT - name: Install Go (${{ steps.go_version.outputs.go_version }}) - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: ${{ steps.go_version.outputs.go_version }} - name: Download generated artifacts - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: generated @@ -104,34 +117,12 @@ jobs: run: make test-coverage timeout-minutes: 30 - go-perf: - name: Go Perf - runs-on: ubuntu-18.04 - steps: - - name: Check out code - uses: actions/checkout@v2 - - - name: Benchmark Test Golang - run: make ci-go-perf - timeout-minutes: 30 - - go-quick-fuzz: - name: Go quick fuzz - runs-on: ubuntu-18.04 - steps: - - name: Check out code - uses: actions/checkout@v2 - - - name: Run fuzz check (3m) - run: make ci-go-check-fuzz FUZZ_TIME=180 - timeout-minutes: 30 - go-lint: name: Go Lint - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Golang Style and Lint Check run: make check @@ -139,63 +130,108 @@ jobs: wasm: name: WASM - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 + needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - - name: Build and Test WASM + - name: Check PR for changes to Wasm + uses: dorny/paths-filter@v2 + id: changes + with: + filters: | + wasm: + - 
Makefile + - 'wasm/**' + - 'ast/**' + - 'internal/compiler/**' + - 'internal/planner/**' + - 'internal/wasm/**' + - 'test/wasm/**' + - 'test/cases/**' + + - name: Download generated artifacts + uses: actions/download-artifact@v3 + with: + name: generated + if: steps.changes.outputs.wasm == 'true' + + - name: Build and Test Wasm run: make ci-wasm timeout-minutes: 15 + if: steps.changes.outputs.wasm == 'true' + + - name: Build and Test Wasm SDK + run: make ci-go-wasm-sdk-e2e-test + timeout-minutes: 30 + if: steps.changes.outputs.wasm == 'true' + env: + DOCKER_RUNNING: 0 check-generated: name: Check Generated - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 + needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 + + - name: Download generated artifacts + uses: actions/download-artifact@v3 + with: + name: generated - name: Check Working Copy run: make ci-check-working-copy timeout-minutes: 15 - - wasm-go-sdk-e2e: - name: OPA Wasm SDK e2e - runs-on: ubuntu-18.04 - steps: - - name: Check out code - uses: actions/checkout@v2 - - - name: Build and Test Wasm SDK - run: make ci-go-wasm-sdk-e2e-test - timeout-minutes: 30 + env: + DOCKER_RUNNING: 0 race-detector: name: Go Race Detector - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 + needs: generate steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 + + - name: Download generated artifacts + uses: actions/download-artifact@v3 + with: + name: generated - name: Test with Race Detector run: make ci-go-race-detector + env: + DOCKER_RUNNING: 0 smoke-test-docker-images: name: docker image smoke test - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 needs: go-build steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v2 + with: + platforms: arm64 - name: Download release binaries - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: binaries path: _release - - name: Test images + - name: Test amd64 images + run: make ci-image-smoke-test + + - name: Test arm64 images run: make ci-image-smoke-test + env: + GOARCH: arm64 smoke-test-binaries: runs-on: ${{ matrix.os }} @@ -203,9 +239,9 @@ jobs: strategy: matrix: include: - - os: ubuntu-latest + - os: ubuntu-22.04 exec: opa_linux_amd64 - - os: ubuntu-latest + - os: ubuntu-22.04 exec: opa_linux_amd64_static wasm: disabled - os: macos-latest @@ -215,10 +251,10 @@ jobs: steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Download release binaries - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: binaries path: _release @@ -230,47 +266,60 @@ jobs: run: make ci-binary-smoke-test-wasm BINARY=${{ matrix.exec }} if: matrix.wasm != 'disabled' - nodejs-wasm-example: - name: npm-opa-wasm - runs-on: ubuntu-latest - needs: go-build + go-version-build: + name: Go compat build/test + needs: generate + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, macos-latest] + version: ["1.18", "1.19"] steps: - - name: Download release binaries - uses: actions/download-artifact@v2 + - uses: actions/checkout@v3 + - name: Download generated artifacts + uses: actions/download-artifact@v3 with: - name: binaries - path: _release - - - name: Prepare OPA - run: | - ln -s _release/*/opa_linux_amd64 opa - chmod +x opa - echo $(pwd) >> $GITHUB_PATH + name: generated + - uses: actions/setup-go@v4 + with: + go-version: ${{ 
matrix.version }} + - run: make build + env: + DOCKER_RUNNING: 0 + - run: make go-test + env: + DOCKER_RUNNING: 0 + + + # Run PR metadata against Rego policies + rego-check-pr: + name: Rego PR checks + runs-on: ubuntu-22.04 + steps: + - name: Checkout code + uses: actions/checkout@v3 - - name: Check out npm-opa-wasm - uses: actions/checkout@v2 + - name: Download OPA + uses: open-policy-agent/setup-opa@v2 with: - repository: open-policy-agent/npm-opa-wasm - path: npm-opa-wasm + version: edge - - name: Run npm-opa-wasm nodejs-app examples - run: | - npm install - ./e2e.sh - working-directory: npm-opa-wasm + - name: Test policies + run: opa test build/policy - go-version-build: - name: Go compat builds - runs-on: ubuntu-latest - strategy: - matrix: - include: - - version: "1.16" - - version: "1.15" - steps: - - name: Check out code - uses: actions/checkout@v2 + - name: Ensure proper formatting + run: opa fmt --list --fail build/policy - - name: Build - run: make ci-go-ci-build-linux GOVERSION=${{ matrix.version }} - timeout-minutes: 30 + - name: Run policy checks on changed files + run: | + curl --silent --fail --header 'Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' -o files.json \ + https://api.github.com/repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/files + opa eval --bundle build/policy/ --format values --input files.json \ + --fail-defined 'data.files.deny[message]' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Show input on failure + run: opa eval --input files.json --format pretty input + if: ${{ failure() }} diff --git a/.gitignore b/.gitignore index faf4e41a76..e6429ce649 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ policies # man pages man + +# generated when running local website build +docs/website/.hugo_build.lock diff --git a/.go-version b/.go-version index b48f322609..0bd54efd31 100644 --- a/.go-version +++ b/.go-version @@ -1 +1 @@ -1.17 +1.20.4 diff --git a/.golangci.yaml b/.golangci.yaml index 6286815ef3..5ac97ced90 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -18,11 +18,11 @@ linters: - gofmt - goimports - unused - - varcheck - - deadcode - misspell + - tenv - typecheck - - structcheck - staticcheck - gosimple + - prealloc + - unconvert # - gosec # too many false positives diff --git a/.trivyignore b/.trivyignore new file mode 100644 index 0000000000..74dc0da0a3 --- /dev/null +++ b/.trivyignore @@ -0,0 +1,35 @@ +# We're not directly using nor running these dependencies, and hence they're not applicable +# +# * github.com/satori/go.uuid - used by a dependency that imports containerd... the containerd import +# used here is not vulnerable though. 
+CVE-2021-3538 + +# * go.etcd.io/etcd - we don't run etcd as part of the OPA deployment +CVE-2018-1098 +CVE-2018-1099 + +# * k8s.io/kubernetes - we don't run kubernetes as part of the OPA deployment +CVE-2019-1002101 +CVE-2019-11250 +CVE-2019-11253 +CVE-2019-11254 +CVE-2020-8552 +CVE-2020-8554 +CVE-2020-8555 +CVE-2020-8557 +CVE-2020-8558 +CVE-2020-8559 +CVE-2020-8561 +CVE-2020-8562 +CVE-2020-8563 +CVE-2020-8564 +CVE-2020-8565 +CVE-2021-25735 +CVE-2021-25740 +CVE-2021-25741 + +# * github.com/emicklei/go-restful - we don't use its code in our handlers +CVE-2022-1996 + +# github.com/dgrijalva/jwt-go -- vulnerable version used by docker/distribution above +CVE-2020-26160 diff --git a/ADOPTERS.md b/ADOPTERS.md index 91c94c9cc6..059842c276 100644 --- a/ADOPTERS.md +++ b/ADOPTERS.md @@ -6,6 +6,17 @@ This is a list of organizations that have spoken publicly about their adoption or production users that have added themselves (in alphabetical order): +* [2U, Inc](https://2u.com) has incorporated OPA into their SDLC for both Terraform and Kubernetes deployments. + Shift left! + +* [Appsflyer](https://www.appsflyer.com/) uses OPA to make consistent + authorization decisions by hundreds of microservices for UI and API data + access. All authorization decisions are delegated to OPA that is deployed as a + central service. The decisions are driven by flexible policy rules that take + into consideration data privacy regulations and policies, data consents and + application level access permissions. For more information, see the [Appsflyer + Engineering Blog post](https://medium.com/appsflyer/authorization-solution-for-microservices-architecture-a2ac0c3c510b). + * [Atlassian](https://www.atlassian.com/) uses OPA in a heterogeneous cloud environment for microservice API authorization. OPA is deployed per-host and inside of their Slauth (AAA) system. Policies are tagged and categorized @@ -31,7 +42,7 @@ production users that have added themselves (in alphabetical order): * [Capital One](https://www.capitalone.com/) uses OPA to enforce a variety of admission control policies across their Kubernetes clusters including image - registry whitelisting, label requirements, resource requirements, container + registry allowlisting, label requirements, resource requirements, container privileges, etc. For more information see this talk from [KubeCon US 2018](https://www.youtube.com/watch?v=CDDsjMOtJ-c&t=6m35s) and this talk from [OPA Summit 2019](https://www.youtube.com/watch?v=vkvWZuqSk5M). @@ -64,6 +75,11 @@ production users that have added themselves (in alphabetical order): OPA-based admission controllers, covering single-tenant environments and hard multi-tenancy configurations. +* [Digraph](https://www.getdigraph.com) is a developer-first cloud compliance platform + that uses OPA to let security teams detect and resolve non-compliant infrastructure + changes before they're deployed to production, and produce audit trails to eliminate + manual work and accelerate audit processes like SOC and ISO. + * [Fugue](https://fugue.co) is a cloud security SaaS that uses OPA to classify compliance violations and security risks in AWS and Azure accounts and generate compliance reports and notifications. @@ -73,6 +89,19 @@ production users that have added themselves (in alphabetical order): RBAC, PV, and Quota resources that are central to the security and operation of these clusters. For more information see this talk from [KubeCon US 2019](https://www.youtube.com/watch?v=lYHr_UaHsYQ). 
+* [Google Cloud](https://cloud.google.com/) uses OPA to validate Google Cloud + product configurations in several products and tools, including + [Anthos Config Management](https://cloud.google.com/anthos/config-management), + [GKE Policy Automation](https://github.com/google/gke-policy-automation) and + [Config Validator](https://github.com/GoogleCloudPlatform/policy-library). See the + [Creating policy-compliant Google Cloud resources](https://cloud.google.com/architecture/policy-compliant-resources) article + for example use cases. + +* [Infracost](https://www.infracost.io/) shows cloud cost estimates for Terraform. + It uses OPA to enable users to create cost policies and set up guardrails such + as "this change puts the monthly costs above $10K, which is the budget for this + product. Consider asking the team lead to review it". See [the docs](https://www.infracost.io/docs/features/cost_policies/) for details. + +* [Intuit](https://www.intuit.com/company/) uses OPA as a validating and mutating admission controller to implement various security, multi-tenancy, and risk management policies across approximately 50 @@ -93,9 +122,9 @@ production users that have added themselves (in alphabetical order): AWS resources to generate the final report. * [Mercari](https://www.mercari.com/) uses OPA to enforce admission control - policies in their multi-tenant Kubernetes clusters. It helps maintain - the governance of the cluster, checking that developers are following - the best practices in the admission controller. They also use [confest](https://github.com/open-policy-agent/conftest) to + policies in their multi-tenant Kubernetes clusters. It helps maintain + the governance of the cluster, checking that developers are following + the best practices in the admission controller. They also use [conftest](https://github.com/open-policy-agent/conftest) to enforce policies in their CI/CD pipeline. * [Netflix](https://www.netflix.com) uses OPA as a method of enforcing @@ -134,6 +163,8 @@ production users that have added themselves (in alphabetical order): etc. SAP/Infrabox is used in production within SAP and has several external users. +* [Terminus Software](https://terminus.com/) uses OPA for microservice authorization. + * [T-Mobile](https://www.t-mobile.com) uses OPA as a core component for their [MagTape](https://github.com/tmobile/magtape/) project that enforces best practices and secure configurations across their fleet of Kubernetes @@ -167,6 +198,18 @@ production users that have added themselves (in alphabetical order): [part 1](https://blog.verygoodsecurity.com/posts/building-a-fine-grained-permission-system-in-a-distributed-environment), [part 2](https://blog.verygoodsecurity.com/posts/building-a-fine-grained-permissions-system-in-a-distributed-environment). + +* [VNG Cloud](https://www.vngcloud.vn/en/home) [Identity and Access Management (IAM)](https://iam.vngcloud.vn/) + uses OPA as a policy-based decision engine for authorization. IAM provides administrators with fine-grained + access control to VNG Cloud resources and helps centralize and manage permissions to access resources. + Specifically, OPA is integrated to evaluate policies and decide whether to allow or deny incoming requests. + +* [Wiz](https://www.wiz.io/) helps every organization rapidly remove the most critical + risks in their cloud estate. It simply connects in minutes, requires zero agents, and + automatically correlates the entire security stack to uncover the most pressing issues. 
+ Wiz policies leverage Open Policy Agent (OPA) for a unified framework across the + cloud-native stack. Whether for configurations, compliance, IaC, and more, OPA enables + teams to move faster in the cloud. For more information on how Wiz uses OPA, [contact Wiz](https://www.wiz.io/contact/). + * [Xenit AB](https://www.xenit.se/) uses OPA to implement fine-grained control over resource formulation in its managed Kubernetes service as well as several customer-specific implementations. For more information, see the Kubernetes Terraform library [OPA Gatekeeper module](https://github.com/XenitAB/terraform-modules/tree/main/modules/kubernetes/opa-gatekeeper) and @@ -175,7 +218,7 @@ * [Yelp](https://www.yelp.com/) use OPA and Envoy to enforce authorization policies across a fleet of microservices that evolved out of a monolithic architecture. For more information see this talk from [KubeCon US 2019](https://www.youtube.com/watch?v=Z6aN3Smt-9M). - + In addition, there are several production adopters that prefer to remain anonymous. @@ -187,13 +230,13 @@ remain anonymous. This is a list of adopters in early stages of production or pre-production (in alphabetical order): -* [Aserto](https://www.aserto.com/) is a venture-backed developer API company - that helps developers easily build permissions and roles into their SaaS - applications. Aserto uses OPA as its core engine, and has contributed projects - such as [Open Policy Registry](https://openpolicyregistry.io) and - [OPA Runtime](https://github.com/aserto-dev/runtime) that make it easier for +* [Aserto](https://www.aserto.com/) is a venture-backed developer API company + that helps developers easily build permissions and roles into their SaaS + applications. Aserto uses OPA as its core engine, and has contributed projects + such as [Open Policy Registry](https://openpolicyregistry.io) and + [OPA Runtime](https://github.com/aserto-dev/runtime) that make it easier for developers to incorporate OPA policies and the OPA engine into their applications. - + * [Cyral](https://www.cyral.com/) is a venture-funded data security company. Still in stealth mode but using OPA to manage and enforce fine-grained authorization policies. @@ -210,6 +253,14 @@ pre-production (in alphabetical order): December 2018, ~850 ORY Keto instances were running in a mix of pre-production and production environments. + +* [Permit.io](https://permit.io) uses a combination of OPA and OPAL + to power fine-grained authorization policies at the core of the Permit.io platform. + Permit.io leverages the power of OPA's Rego language, + generating new Rego code on the fly from its UI policy editor. + The team behind Permit.io contributes to the OPA ecosystem, creating open-source projects like + [OPAL - making OPA event-driven](https://github.com/permitio/opal) + and [OPToggles - sync Frontend with open-policy](https://github.com/permitio/OPToggles). + +* [Scalr](https://scalr.com/) is a remote operations backend for Terraform that helps + users scale their Terraform usage through automation and collaboration. [Scalr uses OPA](https://docs.scalr.com/en/latest/opa.html) to validate Terraform @@ -221,6 +272,10 @@ pre-production (in alphabetical order): automated code review, defining access levels or blocking execution of unwanted code. 
+* [Wealthsimple](https://www.wealthsimple.com/) is using OPA to power all authorization checks in their microservice ecosystem, leveraging their existing authorization library to make the transition to OPA as simple as possible for development teams. + +* [Magda](https://github.com/magda-io/magda) is a federated, Kubernetes-based, open-source data catalog system. As Magda's central authorisation policy engine, OPA handles more than API endpoint authorisation: Magda also uses OPA's partial evaluation feature to translate dataset authorisation decisions into database-specific DSLs (e.g. SQL or Elasticsearch DSL) and uses them to enforce dataset authorisation in different databases. + + Other adopters that have gone into production or various stages of testing include: @@ -230,4 +285,5 @@ testing include: * [State Street Corporation](http://www.statestreet.com/) If you have adopted OPA and would like to be included in this list, -feel free to submit a PR. +feel free to submit a PR updating this file or +[open an issue](https://github.com/open-policy-agent/opa/issues/new?assignees=&labels=adopt-opa&template=adopt-opa.yaml&title=organization_name+has+adopted+OPA). diff --git a/CHANGELOG.md b/CHANGELOG.md index ec4b5f3504..75515bda5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,2324 @@ project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +## 0.52.0 + +This release contains some enhancements, bugfixes, and a new built-in function. + +### Allow Adding Labels via Discovery + +Previously, OPA did not allow any updates to the labels provided in the boot configuration via the discovered (i.e. service) +config. This was done to avoid breaking the discovery configuration. But there are use cases where labels can serve as a convenient +way to pass information that could be used in policies, status updates or decision logs. This change allows +additional labels to be configured in the service config, which are then made available at runtime. + +See [the Discovery documentation](https://www.openpolicyagent.org/docs/v0.52.0/management-discovery/#limitations) +for more details. + +Authored by @mjungsbluth. + +### New Built-In Function: crypto.hmac.equal + +`crypto.hmac.equal` provides a convenient way to compare hashes generated by the MD5, SHA-1, SHA-256 and SHA-512 hashing algorithms. + +Below is a real-world example of how this built-in function can be utilized. Imagine our server is registered as a +GitHub webhook that subscribes to certain events on GitHub.com. Now we want to limit requests to those coming from GitHub. +One of the ways to do that is to first set up a secret token and validate the information. Once we create the token on GitHub, +we'll set up an environment variable that stores this token and makes it available to OPA via the `opa.runtime` built-in. +In the case of GitHub webhooks, the validation is done by comparing the hash signature received in the `X-Hub-Signature-256` +header with a hash calculated from the secret token and payload body. The `check_signature` rule implements this logic. 
+ +```rego +package example + +import input.attributes.request.http as http_request + +allow { + http_request.method == "POST" + input.parsed_path = ["workflows", "github", "webhooks"] + check_signature +} + +check_signature { + secret_key := opa.runtime().env.GITHUB_SECRET_KEY + hash_body := crypto.hmac.sha256(http_request.raw_body, secret_key) + expected_signature := concat("", ["sha256=", hash_body]) + header_signature = http_request.headers["X-Hub-Signature-256"] + crypto.hmac.equal(header_signature, expected_signature) +} +``` + +See [the documentation on the new built-in](https://www.openpolicyagent.org/docs/v0.52.0/policy-reference/#builtin-crypto-cryptohmacequal) +for all the details. + +Authored by @sandokandias. + +### Extend Authentication Methods Supported by OCI Downloader + +Previously, the OCI Downloader had support for only three types of authentication methods, namely `Client TLS Certificates`, +`Basic Authentication` and `Bearer Token`. This change adds support for other authentication methods such as [AWS Signature](https://www.openpolicyagent.org/docs/v0.52.0/configuration/#aws-signature) and +[GCP Metadata Token](https://www.openpolicyagent.org/docs/v0.52.0/configuration/#gcp-metadata-token). See [the documentation](https://www.openpolicyagent.org/docs/v0.52.0/configuration/#using-private-image-from-oci-repositories) +for more details. + +Authored by @DerGut. + +### Update Profiler Output With Number of Generated Expressions + +The EVAL/REDO counts in the profile result are sometimes difficult to understand. This is mainly due to the +fact that the compiler rewrites expressions and assigns the same location to each generated expression, while the profiler +keys the counters by location. To provide more clarity, the profile output now includes the number of generated +expressions for each given expression, thereby helping users better understand the result and how the evaluation works. + +Here is an example of the updated profiler output with the new `NUM GEN EXPR` column: + +```ruby ++----------+----------+----------+--------------+-------------+ +| TIME | NUM EVAL | NUM REDO | NUM GEN EXPR | LOCATION | ++----------+----------+----------+--------------+-------------+ +| 20.291µs | 3 | 3 | 3 | test.rego:7 | +| 1µs | 1 | 1 | 1 | test.rego:6 | +| 2.333µs | 1 | 1 | 1 | test.rego:5 | +| 6.333µs | 1 | 1 | 1 | test.rego:4 | +| 84.75µs | 1 | 1 | 1 | data | ++----------+----------+----------+--------------+-------------+ +``` + +See [the Profiling documentation](https://www.openpolicyagent.org/docs/v0.52.0/policy-performance/#profiling) +for more details. + +Authored by @ashutosh-narkar. 
+ +### Runtime, Tooling, SDK + +- bundle: Add ability to load bundles from an arbitrary filesystem ([#5833](https://github.com/open-policy-agent/opa/issues/5833)) authored by @kjothen +- server: Add a note to explicitly point out if OPA binds to the 0.0.0.0 interface on server initialization ([#5090](https://github.com/open-policy-agent/opa/issues/5090)) authored by @Parsifal-M +- Include trace and span identifier in decision logs to help with correlating logs and trace data ([#5230](https://github.com/open-policy-agent/opa/issues/5230)) authored by @ashutosh-narkar + +### Topdown and Rego + +- ast: Disallow partial object rules to have other partial object rules within their immediate extent ([#5855](https://github.com/open-policy-agent/opa/issues/5855)) authored by @johanfylling +- ast: Disallow multi-value rules to have other rules in their extent ([#5813](https://github.com/open-policy-agent/opa/issues/5813)) authored by @johanfylling +- ast: Set result of groundness check on indexer's AllRules func so that rule evaluation for complete rules is not skipped ([#5857](https://github.com/open-policy-agent/opa/issues/5857)) authored by @ashutosh-narkar +- rego: Fix duplicate text in error message during module parsing ([#5837](https://github.com/open-policy-agent/opa/pull/5837)) authored by @TzlilSwimmer123 +- planner: Fix bugs that have an impact on IR ([#5829](https://github.com/open-policy-agent/opa/pull/5829)) and Wasm usage ([#5839](https://github.com/open-policy-agent/opa/pull/5839)) authored by @srenatus +- ast: Include information about the location of rule value and reference in the AST's JSON representation based on the provided custom parsing options ([#5790](https://github.com/open-policy-agent/opa/issues/5790)) authored by @Trolloldem +- ast: Fix issue with unset annotation data when custom parsing options provided ([#5826](https://github.com/open-policy-agent/opa/issues/5826)) authored by @charlieegan3 + +### Docs + +- docs/rest-api: Update Compile API docs to include some use-cases ([#5858](https://github.com/open-policy-agent/opa/pull/5858)) authored by @charlieegan3 +- docs/extensions: Add Nondeterministic field to the Rego object initialization in the code example for the Custom Built-in Function section ([#5861](https://github.com/open-policy-agent/opa/pull/5861)) (authored by @RmStorm) + + +### Website + Ecosystem + +- Ecosystem: + - Reposaur ([#5854](https://github.com/open-policy-agent/opa/pull/5854)) authored by @charlieegan3 + - Update logo for Torque integration ([#5810](https://github.com/open-policy-agent/opa/pull/5810)) authored by @shirabendor-quali + +- Website: + - Reorganize the `MISCELLANEOUS` section to improve content navigation ([#4614](https://github.com/open-policy-agent/opa/issues/4614)) authored by @lakhanjindam + +### Miscellaneous + +- Dependency bumps, notably: + - golang from 1.20.2 to 1.20.3 + - golang.org/x/net from 0.8.0 to 0.9.0 + - github.com/prometheus/client_golang from 1.14.0 to 1.15.0 + + +## 0.51.0 + +This release contains improvements to monitoring and an assortment of fixes and enhancements. + +### Monitoring + +#### Surface unauthorized request count from OPA HTTP API authz handler via Status API + +Currently, when OPA's HTTP server rejects requests per +the [authz policy](https://www.openpolicyagent.org/docs/latest/security/#authentication-and-authorization), +this is not accounted for via the management APIs. +This change adds that count to the metric registry that is +part of the Status API for more visibility. 
+ +([#3378](https://github.com/open-policy-agent/opa/issues/3378)) authored by @ashutosh-narkar. + +#### Surface more decision log errors via Status API + +Previously in [5732](https://github.com/open-policy-agent/opa/pull/5732), +we updated the decision log plugin to +surface errors via the Status API. However, in that change +certain events like encoder errors and log drops due to +buffer size limits had no metrics associated with them. +This change adds more metrics for these events so that they +can be surfaced via the Status API. + +([#5637](https://github.com/open-policy-agent/opa/issues/5637)) authored by @ashutosh-narkar. + +#### Include truncated HTTP response in logs + +This change updates the client debug log to include +the (truncated) HTTP response in case of non-200 status codes. +Recording the response in the logs can help provide +more information when debugging error scenarios. + +([#2961](https://github.com/open-policy-agent/opa/issues/2961)) authored by @ashutosh-narkar reported by @gshively11. + +### Topdown and Rego + +- Wasm: Add native support for the `object.union_n` built-in function (authored by @Azanul) + +### Fixes + +- ast: Properly set the reported location of unused variables in strict-mode errors. ([#5662](https://github.com/open-policy-agent/opa/issues/5662)) authored by @boranx +- fmt: Fix wrongly reported arity for built-in functions. ([#5646](https://github.com/open-policy-agent/opa/issues/5646)) authored by @Trolloldem +- topdown: http.send(): Ensuring intra-query caching consistency. ([#5736](https://github.com/open-policy-agent/opa/issues/5736)) authored by @johanfylling +- Performance improvements to decision logging. + Specifically, by removing superfluous json encoding roundtrip and double work in AST conversion of to-be-logged events. (authored by @srenatus) + +### Docs, Website, and Ecosystem + +- Fix typo in documentation (authored by @eternaltyro) +- Update TLS authentication docs (authored by @charlieegan3) +- Clarification in docs about checksums of Windows executables (authored by @Ronnie-personal) +- docs: Small fix to context placement in integration (authored by @craigpastro) +- docs/website: Fix floating navbar anchor issue ([5774](https://github.com/open-policy-agent/opa/issues/5774)) authored by @charlieegan3 reported by @kristiansvalland + +### Miscellaneous + +- Update -debug images to use Chainguard images ([5544](https://github.com/open-policy-agent/opa/issues/5544)) (authored by @charlieegan3) +- Various third-party dependencies were updated. + +## 0.50.2 + +This is a bug fix release that addresses a regression in 0.50.1. +This regression impacts policies with rules that, as their else value, assign a comprehension containing variables. +Such rules would cause the compilation of the policy to fail with a `rego_unsafe_var_error` error. + +E.g. the following policy would fail to compile with a `policy.rego:5: rego_unsafe_var_error: var x is unsafe` error: +```rego +package example + +p { + false +} else := [x | x := 1] +``` + +### Fixes + +- ast: Fixing bug where comprehensions in rule else-heads weren't rewritten correctly ([#5771](https://github.com/open-policy-agent/opa/issues/5771)) authored by @johanfylling reported by @davidmdm + +## 0.50.1 + +This is a bug fix release addressing the following issues: + +### Fixes + +- ast/compile: Guard recursive module equality check. ([#5756](https://github.com/open-policy-agent/opa/issues/5756)) authored by @philipaconrad. + Resolves a performance regression when using large bundles. 
+- ast: Relaxing strict-mode check for unused args in else-branching functions ([#5758](https://github.com/open-policy-agent/opa/issues/5758)) authored by @johanfylling reported by @ethanjli. + +### Miscellaneous + +- Use normalized policy paths as compiler module keys and store IDs (authored by @ashutosh-narkar). + Resolves an issue with bundle loading on Windows. + +## 0.50.0 + +This release contains a mix of new features, bugfixes, security fixes, optimizations and build updates related to +OPA's published images. + +### New Built-in Functions: JSON Schema Verification and Validation + +These new built-in functions add functionality to verify and validate JSON Schema ([#5486](https://github.com/open-policy-agent/opa/pull/5486)) (co-authored by @jkulvich and @johanfylling). + +- `json.verify_schema`: Checks that the input is a valid JSON schema object +- `json.match_schema`: Checks that the document matches the JSON schema + +See the [documentation](https://www.openpolicyagent.org/docs/v0.50.0/policy-reference/#object) for all details. A short usage sketch appears below, after the `json.patch` performance notes. + +### Annotations scoped to `package` carry across modules + +`package` scoped schema annotations are now applied across modules instead of only within the module where +they are declared ([#5251](https://github.com/open-policy-agent/opa/issues/5251)) (authored by @johanfylling). This change may cause compile-time errors and behavioural changes to +type checking when the `schemas` annotation is used, and to rules calling the `rego.metadata.chain()` built-in function: + + - Existing projects with the same package declared in multiple files will trigger a `rego_type_error: package annotation redeclared` +error _if_ two or more of these are annotated with the `package` scope. + - If using the `package` scope, the `schemas` annotation will be applied to type checking also for rules declared in +another file than the annotation declaration, as long as the package is the same. + - The chain of metadata returned by the `rego.metadata.chain()` built-in function will now contain an entry for the +package even if the annotations are declared in another file, if the scope is `package`. + +### Remote bundle URL shorthand for `run` command + +To load a remote bundle using `opa run`, the `set` directive can be provided multiple times as shown below: +``` + $ opa run -s --set "services.default.url=https://example.com" \ + --set "bundles.example.service=default" \ + --set "bundles.example.resource=/bundles/bundle.tar.gz" \ + --set "bundles.example.persist=true" +``` + +The following command can be used as a shorthand to easily start OPA with a remote bundle ([#5674](https://github.com/open-policy-agent/opa/issues/5674)) (authored by @anderseknert): +``` +$ opa run -s https://example.com/bundles/bundle.tar.gz +``` + +### Performance Improvements for `json.patch` Built-in Function + +Performance improvements in `json.patch` were achieved with the introduction of a new `EditTree` data structure, +which is built for applying in-place modifications to an `ast.Term`, and can render the final result of all edits efficiently +by applying all patches in a JSON-Patch sequence rapidly, and then collapsing all edits at the end with minimal wasted `ast.Term` copying (authored by @philipaconrad). +For more details and benchmarks refer to [#5494](https://github.com/open-policy-agent/opa/pull/5494) and [#5390](https://github.com/open-policy-agent/opa/pull/5390). 
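+
+To make the affected operation concrete, here is a small illustrative use of `json.patch`; the document and patch values below are hypothetical, and the linked PRs contain the actual benchmarks:
+
+```rego
+package example
+
+# Apply a JSON-Patch sequence to a document (values made up for illustration).
+patched := json.patch({"a": {"b": 1}}, [
+	{"op": "add", "path": "/a/c", "value": 2},
+	{"op": "remove", "path": "/a/b"}
+])
+
+# patched == {"a": {"c": 2}}
+```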
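+
+And returning to the JSON Schema built-ins introduced at the top of these release notes, here is a minimal sketch of how they might be used together; the schema and document are hypothetical, and the exact `[result, errors]` return shapes are described in the linked built-in documentation:
+
+```rego
+package example
+
+schema := {
+	"type": "object",
+	"properties": {"name": {"type": "string"}},
+	"required": ["name"]
+}
+
+# First check that the schema itself is a valid JSON Schema...
+schema_valid {
+	[valid, _] := json.verify_schema(schema)
+	valid
+}
+
+# ...then check that a document matches it.
+document_matches {
+	[match, _] := json.match_schema({"name": "alice"}, schema)
+	match
+}
+```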
+ +### Surface decision log errors via Status API + +Errors encountered during decision log uploads will now be surfaced via the Status API in addition to being logged. This +functionality should give users greater visibility into any issues OPA may face while processing and uploading logs ([#5637](https://github.com/open-policy-agent/opa/issues/5637)) (authored by @ashutosh-narkar). + +See the [documentation](https://www.openpolicyagent.org/docs/v0.50.0/management-status/#status-service-api) for more details. + +### OPA Published Images Update + +All published OPA images now run with a non-root uid/gid. The `uid:gid` is set to `1000:1000` for all images. As a result +there is no longer a need for the `-rootless` image variant and hence it will not be published as part of future releases. +This change is in line with container security best practices. OPA can still be run with root privileges by explicitly setting the user, +either with the `--user` argument for `docker run`, or by specifying the `securityContext` in the Kubernetes Pod specification. + + +### Runtime, Tooling, SDK + +- server: Support compression of response payloads if the HTTP client supports it ([#5310](https://github.com/open-policy-agent/opa/issues/5310)) authored by @AdrianArnautu +- bundle: Ensure the bundle resulting from merging a set of bundles does not contain `nil` data ([#5703](https://github.com/open-policy-agent/opa/issues/5703)) authored by @anderseknert +- repl: Use lowercase for repl commands only and keep any provided arguments as-is ([#5229](https://github.com/open-policy-agent/opa/issues/5229)) authored by @Trolloldem +- metrics: New endpoint `/metrics/alloc_bytes` to show OPA's memory utilization ([#5715](https://github.com/open-policy-agent/opa/pull/5715)) authored by @anderseknert +- server: When using OPA TLS authorization, authz policy authors will now have access to the client certificates +presented as part of the TLS connection. This new data will be available under the key `client_certificates` ([#5538](https://github.com/open-policy-agent/opa/issues/5538)) authored by @charlieegan3 +- server: Use streaming implementation of json.Decode rather than using an intermediate buffer for the incoming request ([#5661](https://github.com/open-policy-agent/opa/pull/5661)) authored by @anderseknert + +### Topdown and Rego + +- ast: Extend compiler `strict` mode check to include unused arguments ([#5602](https://github.com/open-policy-agent/opa/issues/5602)) authored by @boranx. This change may cause +compile-time errors for policies that have unused arguments in the scope when the `strict` mode is enabled. These +variables could be replaced with `_` (wildcard) or get cleaned up if they are not intended to be used in the body of the functions. +- ast: Respect inlined `schemas` annotations even if `--schema` flag isn't used ([#5506](https://github.com/open-policy-agent/opa/issues/5506)) authored by @johanfylling +- ast: Force type-checker to respect `allow_net` capability when fetching remote schemas ([#5670](https://github.com/open-policy-agent/opa/issues/5670)) authored by @johanfylling +- ast/parse: Provide custom parsing options that allow location information of AST nodes to be included in their JSON +representation. 
This location information can be used by tools that work with the OPA AST ([#3143](https://github.com/open-policy-agent/opa/issues/3143)) authored by @charlieegan3 + +### Docs + +- docs/policy-reference: Fix typo in policy reference doc ([#5654](https://github.com/open-policy-agent/opa/pull/5654)) authored by @alvarogomez93 +- docs/extensions: Fix sample code provided in the custom built-in implementation example ([#5666](https://github.com/open-policy-agent/opa/pull/5666)) authored by @Ronnie-personal +- docs/bundles: Clarify delta bundle behavior when it contains an empty list of patch operations ([#5629](https://github.com/open-policy-agent/opa/issues/5629)) authored by @charlieegan3 +- docs/http-api-authz: Update the HTTP API authz tutorial with steps related to proper bundle creation ([#5682](https://github.com/open-policy-agent/opa/pull/5682)) authored by @lamoboos223 +- Fix broken 'future keywords' url link ([#5686](https://github.com/open-policy-agent/opa/pull/5686)) authored by @neelanjan00 + + +### Website + Ecosystem + +- Ecosystem: + - Styra Load ([#5659](https://github.com/open-policy-agent/opa/pull/5659)) authored by @charlieegan3 + +- Website: + - Update OPA documentation search to use Algolia v3 ([#5706](https://github.com/open-policy-agent/opa/pull/5706)) authored by @Parsifal-M + - Drop Google Universal Analytics (UA) code as part of Google Analytics 4 migration (authored by @chalin) + +### Miscellaneous + +- Dependency bumps, notably: + - golang from 1.20.1 to 1.20.2 + - github.com/containerd/containerd from 1.6.16 to 1.6.19 + - github.com/golang/protobuf from 1.5.2 to 1.5.3 + - golang.org/x/net from 0.5.0 to 0.8.0 + - google.golang.org/grpc from 1.52.3 to 1.53.0 + - OpenTelemetry-related dependencies (#5701) + + +## 0.49.2 + +This release migrates the [ORAS Go library](oras.land/oras-go/v2) from v1.2.2 to v2. +The earlier version of the library had a dependency on the [docker](github.com/docker/docker) +package. That version of the docker package had some reported vulnerabilities such as +CVE-2022-41716, CVE-2022-41720. The ORAS Go library v2 removes the dependency on the docker package. + +## 0.49.1 + +This is a bug fix release addressing the following Golang security issues: + +### Golang security fix CVE-2022-41723 + +> A maliciously crafted HTTP/2 stream could cause excessive CPU consumption in the HPACK decoder, sufficient to cause a +> denial of service from a small number of small requests. + +### Golang security fix CVE-2022-41724 + +> Large handshake records may cause panics in crypto/tls. Both clients and servers may send large TLS handshake records +> which cause servers and clients, respectively, to panic when attempting to construct responses. + +### Golang security fix CVE-2022-41722 + +> A path traversal vulnerability exists in filepath.Clean on Windows. On Windows, the filepath.Clean function could +> transform an invalid path such as "a/../c:/b" into the valid path "c:\b". This transformation of a relative +> (if invalid) path into an absolute path could enable a directory traversal attack. +> After fix, the filepath.Clean function transforms this path into the relative (but still invalid) path ".\c:\b". + +## 0.49.0 + +This release focuses on bugfixes and documentation improvements, as well as a few small performance improvements. 
+
+### Runtime, Tooling, SDK
+
+- runtime: Update rule index's trie node scalar handling so that numerics compare correctly ([#5585](https://github.com/open-policy-agent/opa/issues/5585)) authored by @ashutosh-narkar reported by @alvarogomez93
+- ast: Improve error information when metadata YAML fails to compile ([#4475](https://github.com/open-policy-agent/opa/issues/4475)) authored and reported by @johanfylling
+- bundle: Retain metadata annotations for Wasm entrypoints during inspection ([#5588](https://github.com/open-policy-agent/opa/issues/5588)) authored and reported by @johanfylling
+- compile: Allow object generating rules to be annotated as entrypoints ([#5577](https://github.com/open-policy-agent/opa/issues/5577)) authored and reported by @johanfylling
+- plugins/discovery: Support for persisting and loading the discovery bundle from disk ([#2886](https://github.com/open-policy-agent/opa/issues/2886)) authored by @ashutosh-narkar reported by @anderseknert
+- perf: Use `json.Encode` to avoid extra allocation (authored by @anderseknert)
+- `opa inspect`: Fix prefix error when inspecting bundle from root ([#5503](https://github.com/open-policy-agent/opa/issues/5503)) authored by @harikannan512 reported by @HarshPathakhp
+- topdown: `http.send` to cache responses based on status code ([#5617](https://github.com/open-policy-agent/opa/issues/5617)) authored by @ashutosh-narkar
+- types: Add GoDoc about named types (authored by @wata727)
+- deps: Remove `github.com/pkg/errors` dependency (authored by @Iceber)
+
+
+### Docs
+
+- Update entrypoint documentation ([#5565](https://github.com/open-policy-agent/opa/issues/5565)) authored by @johanfylling reported by @robertgartman
+- Add missing folder argument in bundle build example (authored by @charlieegan3)
+- Clarify `crypto.x509.parse_certificates` docs (authored by @charlieegan3)
+- Added AWS S3 Web Identity Credentials info to tutorial (authored by @vishrana)
+- docs/graphql: non-nullable id argument and typo fix (authored by @philipaconrad)
+
+### Website + Ecosystem
+
+- Ecosystem:
+  - ccbr (authored by @niuzhi)
+
+- Website:
+  - Show prominent warning when viewing old docs (authored by @charlieegan3)
+  - Prevent navbar clipping on narrow screens + sticky nav (authored by @charlieegan3)
+
+### Miscellaneous
+
+Dependency bumps:
+- build: bump golang 1.19.4 -> 1.19.5 (authored by @yanggangtony)
+- ci: aquasecurity/trivy-action from 0.8.0 to 0.9.0
+- github.com/containerd/containerd from 1.6.15 to 1.6.16
+- google.golang.org/grpc from 1.51.0 to 1.52.3
+
+## 0.48.0
+
+This release rolls in security fixes from recent patch releases, along with
+a number of bugfixes and a new built-in function.
+
+### Improved error reporting available in `opa eval`
+
+A common frustration when writing policies in OPA is an error that causes a
+rule to unexpectedly return `undefined`. Using
+`--strict-builtin-errors` surfaces the first error encountered
+during evaluation, but terminates execution immediately.
+
+To improve the debugging experience, it is now possible to display *all* of
+the errors encountered during normal evaluation of a policy, via the new
+`--show-builtin-errors` option.
+
+Consider the following error-filled policy, `multi-error.rego`:
+
+```rego
+package play
+
+this_errors(number) := result {
+    result := number / 0
+}
+
+this_errors_too(number) := result {
+    result := number / 0
+}
+
+res1 := this_errors(1)
+
+res2 := this_errors_too(1)
+```
+
+Using `--strict-builtin-errors`, we would only see the first divide by zero
+error:
+
+    opa eval --strict-builtin-errors -d multi-error.rego data.play
+
+```
+1 error occurred: multi-error.rego:4: eval_builtin_error: div: divide by zero
+```
+
+Using `--show-builtin-errors` shows both divide by zero issues though:
+
+    opa eval --show-builtin-errors -d multi-error.rego data.play -f pretty
+
+```
+2 errors occurred:
+multi-error.rego:4: eval_builtin_error: div: divide by zero
+multi-error.rego:8: eval_builtin_error: div: divide by zero
+```
+
+By showing more errors up front, we hope this will improve the overall
+policy writing experience.
+
+### New Built-in Function: `time.format`
+
+It is now possible to format a nanosecond timestamp value into a
+timestamp string via a built-in function. The built-in accepts three argument
+formats, each allowing for different options:
+
+ 1. A number representing the nanoseconds since the epoch (UTC).
+ 2. A two-element array of the nanoseconds and a timezone string.
+ 3. A three-element array of nanoseconds, timezone string, and a layout
+    string (same format as for `time.parse_ns`).
+
+See [the documentation](https://www.openpolicyagent.org/docs/v0.48.0/policy-reference/#builtin-time-timeformat)
+for all details.
+
+Implemented by @burnerlee.
+
+### Optimization in rule indexing
+
+Previously, every time the evaluator looked up a rule in the index, OPA
+performed checks for grounded refs over the entire index *before* looking
+up the rule.
+
+Now, OPA performs all groundedness checks once at index construction time,
+which keeps index lookup times much more consistent as the number of
+indexed rules scales up.
+
+Policies with large numbers of index-ready rules can expect a small
+performance lift, proportional to the number of indexed rules.
+
+### Bundle fetching with AWS Signing Version 4A
+
+AWS has recently developed an extension to SigV4 called Signature Version
+4A (SigV4A), which enables signatures that are valid in more than one AWS
+Region. This new signature method is required for signing multi-region API
+requests, such as Amazon S3 Multi-Region Access Points (MRAP).
+
+OPA now supports this new request signing method for bundle fetching, which
+means that you can use an S3 MRAP as a bundle source. This is configured
+via the new `services[].credentials.s3_signing.signature_version`
+field.
+
+See [the documentation](https://www.openpolicyagent.org/docs/v0.48.0/configuration/#aws-signature)
+for more details.
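+
+As an illustrative sketch, a bundle service pointed at an S3 MRAP might be configured like this (the URL is a
+placeholder, and the credential provider can be any of the existing `s3_signing` options):
+
+```yaml
+services:
+  s3-mrap:
+    url: https://my-mrap-alias.mrap.accesspoint.s3-global.amazonaws.com
+    credentials:
+      s3_signing:
+        signature_version: sigv4a
+        environment_credentials: {}
+
+bundles:
+  authz:
+    service: s3-mrap
+    resource: bundle.tar.gz
+```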
+
+Implemented by @jwineinger
+
+### Runtime
+
+- rego: Check store modules before skipping parsing (authored by @charlieegan3)
+- topdown/rego: Add BuiltinErrorList support to rego package, add to eval command (authored by @charlieegan3)
+- topdown: Fix evaluator's re-wrapping of `NDBCache` errors (authored by @srenatus)
+- Fix potential memory leak from `http.send` in interquery cache (authored by @asleire)
+- ast/parser: Detect function rule head + `contains` keyword ([#5525](https://github.com/open-policy-agent/opa/issues/5525)) authored and reported by @philipaconrad
+- ast/visit: Add `SomeDecl` to visitor walks ([#5480](https://github.com/open-policy-agent/opa/issues/5480)) authored by @srenatus
+- ast/visit: Include `LazyObject` in visitor walks ([#5479](https://github.com/open-policy-agent/opa/issues/5479)) authored by @srenatus reported by @benweint
+
+### Tooling, SDK
+
+- topdown: cache undefined rule evaluations ([#593](https://github.com/open-policy-agent/opa/issues/593)) authored by @edpaget reported by @tsdandall
+- topdown: Specify host verification policy for http redirects ([#5388](https://github.com/open-policy-agent/opa/issues/5388)) authored and reported by @ashutosh-narkar
+- providers/aws: Refactor + Fix 2x Authorization header append issue ([#5472](https://github.com/open-policy-agent/opa/issues/5472)) authored by @philipaconrad reported by @Hiieu
+- Add support to enable ND builtin cache via discovery ([#5457](https://github.com/open-policy-agent/opa/issues/5457)) authored by @ashutosh-narkar reported by @asadali
+- format: Only use ref heads for all rule heads if necessary ([#5449](https://github.com/open-policy-agent/opa/issues/5449)) authored and reported by @srenatus
+- `opa inspect`: Fix path of data namespaces on Windows (authored by @shm12)
+- ast+cmd: Only enforce `schemas` annotations if `--schema` flag is used (authored by @johanfylling)
+- sdk: Allow use of a query tracer (authored by @charlieegan3)
+- sdk: Allow use of metrics, profilers, and instrumentation (authored by @charlieegan3)
+- sdk: Return provenance information in Result types (authored by @charlieegan3)
+- sdk: Allow use of StrictBuiltinErrors (authored by @charlieegan3)
+- Allow print calls in IR (authored by @anderseknert)
+- tester/runner: Fix panicking case in utility function ([#5496](https://github.com/open-policy-agent/opa/issues/5496)) authored and reported by @philipaconrad
+
+### Docs
+
+- Community page updates (authored by @anderseknert)
+- Update Hugo version, update deprecated Page fields (authored by @charlieegan3)
+- docs: Update TLS-based Authentication Example ([#5521](https://github.com/open-policy-agent/opa/issues/5521)) authored by @charlieegan3 reported by @jjthom87
+- docs: Update opa eval flags to link to bundle docs (authored by @charlieegan3)
+- docs: Make SDK first option for Go integration (authored by @anderseknert)
+- docs: Fix typo on Policy Language page.
(authored by @mcdonagj)
+- docs/integrations: Update kubescape repo links (authored by @dwertent)
+- docs/oci: Corrected config section (authored by @ogazitt)
+- website/frontpage: Update Learn More links (authored by @pauly4it)
+
+- integrations.yaml: Ensure inventors listed in organizations (authored by @anderseknert)
+- integrations: Fix malformed inventors item (authored by @anderseknert)
+- Add Digraph to ADOPTERS.md (authored by @jamesphlewis)
+
+### Miscellaneous
+
+- Remove changelog maintainer mention filter (authored by @anderseknert)
+- Chore: Fix len check in the `ast/visit_test` error message (authored by @boranx)
+- `opa inspect`: Fix wrong Windows bundle tar files path separator (authored by @shm12)
+- Add CHANGELOG.md to website build triggers (authored by @srenatus)
+
+Dependency bumps:
+- Golang 1.19.3 -> 1.19.4
+- github.com/containerd/containerd from 1.6.10 -> 1.6.15
+- github.com/dgraph-io/badger/v3
+- golang.org/x/net to 0.5.0
+- json5 and postcss-modules
+- oras.land/oras-go from 1.2.1 -> 1.2.2
+
+CI/Distribution fixes:
+- Update base images for non-debug builds (authored by @charlieegan3)
+- Remove deprecated linters in golangci config (authored by @yanggangtony)
+
+## 0.47.4
+
+This is a bug fix release addressing a panic in `opa test`.
+
+ - tester/runner: Fix panicking case in utility function. ([#5496](https://github.com/open-policy-agent/opa/issues/5496)) authored by @philipaconrad
+
+## 0.47.3
+
+This is a bug fix release addressing an issue that prevented OPA from fetching bundles stored in S3 buckets.
+
+ - providers/aws: Refactor + fix 2x Authorization header append issue. ([#5472](https://github.com/open-policy-agent/opa/issues/5472)) authored by @philipaconrad, reported by @Hiieu
+
+## 0.47.2 and 0.46.3
+
+This is a second security fix to address CVE-2022-41717/GO-2022-1144.
+
+We previously believed that upgrading the Golang version and its stdlib would be sufficient
+to address the problem. It turns out we also need to bump the x/net dependency to v0.4.0,
+a version that didn't exist yet when v0.46.2 was released.
+
+This release bumps the golang.org/x/net dependency to v0.4.0, and contains no other
+changes over v0.46.2.
+
+Note that the affected code is OPA's HTTP server. So if you're using OPA as a Golang library,
+or if you're confident that your OPA's HTTP interface is protected by other means (as it should
+be -- not exposed to the public internet), you're OK.
+
+## 0.47.1 and 0.46.2
+
+This is a bug fix release addressing two issues: one security issue, and one bug
+related to formatting backwards-compatibility.
+
+### Golang security fix CVE-2022-41717
+
+> An attacker can cause excessive memory growth in a Go server accepting HTTP/2 requests.
+
+Since we advise against running an OPA service exposed to the public internet,
+potential attackers would be limited to people that are already capable of
+sending direct requests to the OPA service.
+
+### `opa fmt` and backwards compatibility ([#5449](https://github.com/open-policy-agent/opa/issues/5449))
+
+In v0.46.1, it was possible that `opa fmt` would format a rule in such a way that:
+
+1. Before formatting, it was working fine with older OPA versions, and
+2. after formatting, it would only work with OPA version >= 0.46.1.
+
+This backwards incompatibility wasn't intended, and has now been fixed.
+
+## 0.47.0
+
+This release contains a mix of bugfixes, optimizations, and new features.
+
+### New Built-in Function: `object.keys`
+
+It is now possible to conveniently retrieve an object's keys via a built-in function.
+
+Before, you had to resort to constructs like
+
+```rego
+import future.keywords
+
+keys[k] {
+    _ = input[k]
+}
+
+allow if "my_key" in keys
+```
+
+Now, you can simply do
+
+```rego
+import future.keywords
+
+allow if "my_key" in object.keys(input)
+```
+
+See [the documentation](https://www.openpolicyagent.org/docs/v0.47.0/policy-reference/#builtin-object-objectkeys)
+for all details.
+
+Implemented by @kevinswiber.
+
+### New Built-in Function: AWS Signature v4 Request Signing
+
+It is now possible to use a built-in function to prepare a request with a signature, so that
+it can be used with AWS endpoints that use request signing for authentication.
+
+See this example:
+
+```rego
+req := {"method": "get", "url": "https://examplebucket.s3.amazonaws.com/data"}
+aws_config := {
+    "aws_access_key": "MYAWSACCESSKEYGOESHERE",
+    "aws_secret_access_key": "MYAWSSECRETACCESSKEYGOESHERE",
+    "aws_service": "s3",
+    "aws_region": "us-east-1",
+}
+example_verify_resource {
+    resp := http.send(providers.aws.sign_req(req, aws_config, time.now_ns()))
+    # process response from AWS ...
+}
+```
+
+See [the documentation on the new built-in](https://www.openpolicyagent.org/docs/v0.47.0/policy-reference/#providers.aws)
+for all details.
+
+Reported by @jicowan and implemented by @philipaconrad.
+
+### Performance improvements for `object.get` and `in` operator
+
+Before, using `object.get` and `in` came with a performance penalty that one would not
+expect from the look of the calls alone: since they are implemented as
+built-in functions (obvious for `object.get`, not obvious for `"admin" in input.user.roles`),
+all of their operands had to be read from the store (if applicable) and converted into
+AST types.
+
+Now, we use shallow references ("lazy objects") for store reads in the evaluator.
+In these two cases, this can bring huge performance improvements, when the object
+argument of these two calls is a ref into the base document (like `data.users`):
+
+```rego
+object.get(data.roles, input.role, [])
+{ "id": 12 } in data.users
+```
+
+### Tooling, SDK, and Runtime
+
+- `opa eval`: Added `--strict` to enable strict code checking in evaluation ([#5182](https://github.com/open-policy-agent/opa/issues/5182)) authored by @Parsifal-M (see the usage sketch after this list)
+- `opa fmt`: Remove `{ true }` block following `else` head
+- `opa fmt`: Generate new wildcards for else and chained function heads in the parser ([#5347](https://github.com/open-policy-agent/opa/issues/5347)). This fixes superfluous
+  introductions of `_1` instead of `_` when formatting functions that use wildcard arguments, like `f(_) := true`.
+- `opa fmt`: Fix assignment rewrite in else formatting ([#5348](https://github.com/open-policy-agent/opa/issues/5348))
+- OCI Download: Set auth credentials only if needed ([#5212](https://github.com/open-policy-agent/opa/issues/5212)) authored by @carabasdaniel
+- Server: Differentiate between "missing" and "undefined doc" in default decision ([#5344](https://github.com/open-policy-agent/opa/issues/5344))
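+
+As a usage sketch for the new `--strict` flag (file and query names are illustrative):
+
+```bash
+# fail evaluation if the policy violates strict-mode checks,
+# e.g. unused local variables or deprecated constructs
+opa eval --strict --data policy.rego 'data.example.allow'
+```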
+
+### Topdown and Rego
+
+- `http.send`: Fix interquery cache size calculation with concurrent requests ([#5359](https://github.com/open-policy-agent/opa/issues/5359)) reported and authored by @asleire
+- `http.send`: Remove socket query param for Unix sockets ([#5313](https://github.com/open-policy-agent/opa/issues/5313)) reported and authored by @michivi
+- Annotations: Add type coercion guards to avoid panics ([#5368](https://github.com/open-policy-agent/opa/issues/5368))
+- Compiler: Provide more accurate error locations for `some` with unused vars ([#4238](https://github.com/open-policy-agent/opa/issues/4238))
+- Optimization: Read lazy objects from the store ([#5325](https://github.com/open-policy-agent/opa/issues/5325)). This improves the performance of `x in data.foo` and `object.get(data.bar, ...)` calls significantly.
+- Partial Evaluation: Skip comprehensions when checking eqs in copy propagation ([#5367](https://github.com/open-policy-agent/opa/issues/5367)). This fixes a bug where optimization on bundles would change the outcome of the subsequent evaluation.
+- Parser: Fix else error handling with ref heads -- errors had occurred at a later stage than desired, because an edge case slipped through the earlier check.
+- Planner/IR: Fix ref heads processing -- the CallDynamic optimization wasn't planned properly; a bug introduced with ref heads.
+
+### Documentation
+
+- Builtins: Mention base64 URL encoding specifically ([#5406](https://github.com/open-policy-agent/opa/issues/5406)) reported by @phi1010
+- Builtins: Include behavior with sets in `json.patch` ([#5328](https://github.com/open-policy-agent/opa/issues/5328))
+- Comparison: small fix to table to match sample code and other tables (authored by @anlandu)
+- Builtins: Document reference timestamp behavior for `time.parse_ns`
+- Typo fixes, authored by @deining
+- Golang integration: update example code, move SDK above low-level packages
+
+### Website + Ecosystem
+
+- Ecosystem:
+  - Add Easegress (authored by @localvar)
+  - Add Terraform Cloud
+- Website: Updated Footer Color ([#5254](https://github.com/open-policy-agent/opa/issues/5254)), reported and authored by @UtkarshMishra12
+- Website: Add "canonical" link to latest to help with SEO and ancient pages being returned by search engines.
+- Website: Add experimental "OPA version" badge. (Still needs to be tested more thoroughly before advertising it.)
+
+### Miscellaneous
+
+- Dependency bumps: Notably, we're now using wasmtime-go v3
+- CI fixes:
+  - Move performance tests to nightly tests
+  - CLI: add simple bundle build tests
+  - Nightly: Revamp how we're doing fuzz testing
+
+## 0.46.1
+
+This is a bugfix release to resolve an issue in the release pipeline. Everything else is
+the same as 0.46.0.
+
+## 0.46.0
+
+This release contains a mix of bugfixes, optimizations, and new features.
+
+### New language feature: refs in rule heads
+
+With this version of OPA, we can use a shorthand for defining deeply-nested structures
+in Rego.
+
+Before, we had to use multiple packages, and hence multiple files, to define a structure
+like this:
+```json
+{
+    "method": {
+        "get": {
+            "allowed": true
+        },
+        "post": {
+            "allowed": true
+        }
+    }
+}
+```
+
+```rego
+package method.get
+default allowed := false
+allowed { ... }
+```
+
+
+```rego
+package method.post
+default allowed := false
+allowed { ... }
+```
+
+Now, we can define those rules in a single package (and file):
+
+```rego
+package method
+import future.keywords.if
+default get.allowed := false
+get.allowed if { ... }
+
+default post.allowed := false
+post.allowed if { ... }
+```
+
+Note that in this example, the use of the future keyword `if` is mandatory
+for backwards-compatibility: without it, `get.allowed` would be interpreted
+as `get["allowed"]`, a definition of a partial set rule.
+
+Currently, variables may only appear in the last part of the rule head:
+
+```rego
+package method
+import future.keywords.if
+
+endpoints[ep].allowed if ep := "/v1/data" # invalid
+repos.get.endpoint[x] if x := "/v1/data" # valid
+```
+
+The valid rule defines this structure:
+```json
+{
+    "method": {
+        "repos": {
+            "get": {
+                "endpoint": {
+                    "/v1/data": true
+                }
+            }
+        }
+    }
+}
+```
+
+To define a nested key-value pair, we would use
+
+```rego
+package method
+import future.keywords.if
+
+repos.get.endpoint[x] = y if {
+    x := "/v1/data"
+    y := "example"
+}
+```
+
+Multi-value rules (previously referred to as "partial set rules") that are
+nested like this need to use the `contains` future keyword, to differentiate them
+from the "last part is a variable" case mentioned just above:
+
+```rego
+package method
+import future.keywords
+
+repos.get.endpoint contains x if x := "/v1/data"
+```
+
+This rule defines the same structure, but with multiple values instead of a key:
+```json
+{
+    "method": {
+        "repos": {
+            "get": {
+                "endpoint": ["/v1/data"]
+            }
+        }
+    }
+}
+```
+
+To ensure that it's safe to build OPA policies for older OPA versions, a new
+capabilities field was introduced: "features". It's a free-form string array:
+
+```json
+{
+    "features": [
+        "rule_head_ref_string_prefixes"
+    ]
+}
+```
+
+If this key is not present, the compiler will reject ref-heads. This could be
+the case when building bundles for older OPA versions using their capabilities.
+
+
+### Entrypoint annotations in rule metadata
+
+It is now possible to annotate a rule with `entrypoint: true`, and it will
+automatically be picked up by the tooling that expected `--entrypoint` (`-e`)
+parameters before.
+
+For example, to build this rego policy into a wasm module, you had to pass
+an entrypoint:
+
+```rego
+package test
+allow {
+    input.x
+}
+```
+- `opa build --target wasm --entrypoint test/allow policy.rego`
+
+With the annotation:
+```rego
+package test
+
+# METADATA
+# entrypoint: true
+allow {
+    input.x
+}
+```
+- `opa build --target wasm policy.rego`
+
+The places where entrypoints are taken from metadata are:
+
+1. Building optimized bundles
+2. Building Wasm bundles
+3. Building Plan bundles
+4. Using optimization with `opa eval`
+
+Knowing a module's entrypoints can also help in different analysis tasks.
+
+### New Built-in Function: `graphql.schema_is_valid`
+
+The new built-in allows checking schemas:
+
+```rego
+schema := `
+  extend type User {
+      id: ID!
+  }
+  extend type Product {
+      upc: String!
+  }
+  union _Entity = Product | User
+  extend type Query {
+      entity: _Entity
+  }
+`
+valid_schema_example {
+    graphql.schema_is_valid(schema)
+}
+```
+
+Requested by @olegroom.
+
+### New Built-in Function: `net.cidr_is_valid`
+
+The new built-in function allows checking if a string is a valid CIDR.
+
+```rego
+valid_cidr_example {
+    net.cidr_is_valid("192.168.0.0/24")
+}
+```
+
+Authored by @ricardomaraschini.
+
+### Tooling, SDK, and Runtime
+
+- `opa build`: exit with failure on empty signing key ([#4972](https://github.com/open-policy-agent/opa/issues/4972)) authored by @Joffref reported by @caldwecr
+- `opa exec`: add `--fail` and `--fail-defined` flags ([#5007](https://github.com/open-policy-agent/opa/issues/5007)) authored by @byronic reported by @phantlantis
+- `opa exec`: convert slashes of explicit bundles (Windows) ([#5134](https://github.com/open-policy-agent/opa/issues/5134)) reported by @peterchenadded
+- `opa test`: check coverage limit range `[0, 100]` ([#5284](https://github.com/open-policy-agent/opa/issues/5284)) authored by @hzliangbin reported by @aholmis
+- `opa build`+`opa check`: respect capabilities for parsing, i.e. future keywords ([#5323](https://github.com/open-policy-agent/opa/issues/5323)) reported by @TheLunaticScripter
+- `opa bench --e2e`: support providing OPA config ([#4899](https://github.com/open-policy-agent/opa/issues/4899))
+- `opa eval`: new explain mode, `--explain=debug`, that includes unifications in traces (authored by @jaspervdj)
+
+- Decision logs: Allow rule-based dropping of decision log entries ([#3945](https://github.com/open-policy-agent/opa/issues/3945)) authored by @mariusblarsen and @iamatwork
+- Decision Logs: Include the `req_id` attribute in the decision logs ([#5006](https://github.com/open-policy-agent/opa/issues/5006)) reported and authored by @humbertoc-silva
+- Plugins: export OpenTelemetry TracerProvider for use in plugins (authored by @vinhph0906)
+
+
+### Compiler + Topdown
+
+- `graph.reachable_path`: fix issue with missing subpaths ([#4666](https://github.com/open-policy-agent/opa/issues/4666)) authored by @fredallen-wk
+- `http.send`: Ensure `force_cache` attribute ignores `Date` header ([#4960](https://github.com/open-policy-agent/opa/issues/4960)) reported by @bartandacc
+- `with`: Allow replacing functions with rules ([#5299](https://github.com/open-policy-agent/opa/issues/5299))
+- Evaluation: Skip default functions in full extent ([#5202](https://github.com/open-policy-agent/opa/issues/5202)) reported by @ericjkao
+- Evaluation: capture more cases of conflicts in function evaluation ([#5272](https://github.com/open-policy-agent/opa/issues/5272))
+- Rule Indexing: fix incorrect results from indexing `glob.match` even if output is captured ([#5283](https://github.com/open-policy-agent/opa/issues/5283))
+
+- Planner: various correctness fixes: [#5271](https://github.com/open-policy-agent/opa/issues/5271), [#5265](https://github.com/open-policy-agent/opa/issues/5265), [#5252](https://github.com/open-policy-agent/opa/issues/5252)
+
+- Builtins: Refactor registration functions and signatures (authored by @philipaconrad)
+- Compiler: Speed up typechecker when working with Refs (authored by @philipaconrad)
+- Trace: add `UnifyOp` to tracer events (authored by @jaspervdj)
+
+### Documentation
+
+- Envoy Tutorial: use latest proxy_init (v8)
+- Envoy Plugin: Add note about new config param to skip body parsing
+- Policy Reference: Add `semver` examples
+- Contributing Code: Provide some tips for style fixes
+
+### Website + Ecosystem
+
+- Website: Make "outdated version" banner red if the version being viewed is ancient
+- Ecosystem: Add CircleCI and Topaz
+
+### Miscellaneous
+
+- Code Cleanup:
+  - Don't use the deprecated `ioutil` functions
+  - Use `t.Setenv` in tests
+  - Use `t.TempDir` to create temporary test directory (authored by @Juneezee)
+  - Linters: add `unconvert` and `tenv`
+- internal/strvals: port helm strvals fix (CLI --set arguments), reported by @pjbgf, helm fix authored by @mattfarina
+- Wasm: Update README
+
+- Dependency bumps, notably:
+  - Golang: 1.19.2 -> 1.19.3
+  - golang.org/x/text 0.3.7 -> 0.4.0
+  - oras.land/oras-go 1.2.0 -> 1.2.1
+
+## 0.45.0
+
+This release contains a mix of bugfixes, optimizations, and new features.
+
+### Improved Decision Logging with `nd_builtin_cache`
+
+OPA has several non-deterministic built-ins, such as `rand.intn` and
+`http.send`, that can make debugging policies from decision log results
+a surprisingly tricky and involved process. To improve the situation
+around debugging policies that use those built-ins, OPA now provides
+an opt-in system for caching the inputs and outputs of these built-ins
+during policy evaluation, and can include this information in decision
+log entries.
+
+A new top-level config key is used to enable the non-deterministic
+builtin caching feature, as shown below:
+
+    nd_builtin_cache: true
+
+This data is exposed to OPA's [decision log masking system](https://www.openpolicyagent.org/docs/v0.45.0/management-decision-logs/#masking-sensitive-data)
+under the `/nd_builtin_cache` path, which allows masking or dropping
+sensitive values from decision logs selectively. This can be useful
+in situations where only some information about a non-deterministic
+built-in was needed, or the arguments to the built-in involved
+sensitive data.
+
+To prevent unexpected decision log size growth from non-deterministic
+built-ins like `http.send`, the new cache information is included in
+decision logs on a best-effort basis. If a decision log event exceeds
+the `decision_logs.reporting.upload_size_limit_bytes` limit for an OPA
+instance, OPA will reattempt uploading it, after dropping the
+non-deterministic builtin cache information from the event. This behavior
+will trigger a log error when it happens, and will increment the
+`decision_logs_nd_builtin_cache_dropped` metrics counter, so that it
+will be possible to debug cases where the cache information is unexpectedly
+missing from a decision log entry.
+
+#### Decision Logging Example
+
+To observe the change in decision logging we can run OPA in server mode
+with `nd_builtin_cache` enabled:
+
+```bash
+opa run -s --set=decision_logs.console=true,nd_builtin_cache=true
+```
+
+After sending it the query `x := rand.intn("a", 15)` we should see
+something like the following in the decision logs:
+
+```
+{..., "msg":"Decision Log", "nd_builtin_cache":{"rand.intn":{"[\"a\",15]":3}}, "query":"assign(x, rand.intn(\"a\", 15))", ..., "result":[{"x":3}], ..., "type":"openpolicyagent.org/decision_logs"}
+```
+
+The new information is included under the optional `nd_builtin_cache`
+JSON key, and shows what arguments were provided for each unique
+invocation of `rand.intn`, as well as what the output of that builtin
+call was (in this case, `3`).
+
+If we sent the query `x := rand.intn("a", 15); y := rand.intn("b", 150)`
+we can see how unique input arguments get recorded in the cache:
+
+```
+{..., "msg":"Decision Log", "nd_builtin_cache":{"rand.intn":{"[\"a\",15]":12,"[\"b\",150]":149}}, "query":"assign(x, rand.intn(\"a\", 15)); assign(y, rand.intn(\"b\", 150))", ..., "result":[{"x":12,"y":149}], ..., "type":"openpolicyagent.org/decision_logs"}
+```
+
+With this information, it's now easier to debug exactly why a particular
+rule is used or why a rule fails when non-deterministic builtins are used in
+a policy.
+
+### New Built-in Function: `regex.replace`
+
+This release introduces a new builtin for regex-based search/replace on
+strings: `regex.replace`.
+
+See [the built-in functions docs for all the details](https://www.openpolicyagent.org/docs/v0.45.0/policy-reference/#builtin-regex-regexreplace)
+
+This implementation fixes [#5162](https://github.com/open-policy-agent/opa/issues/5162) and was authored by @boranx.
+
+### `object.union_n` Optimization
+
+The `object.union_n` builtin allows easily merging together an array of Objects.
+
+Unfortunately, as noted in [#4985](https://github.com/open-policy-agent/opa/issues/4985)
+its implementation generated unnecessary intermediate copies from doing
+pairwise, recursive Object merges. These pairwise merges resulted in poor
+performance for large inputs; in many cases worse than writing the
+equivalent operation in pure Rego.
+
+This release changes the `object.union_n` builtin's implementation to use
+a more efficient merge algorithm that respects the original implementation's
+sequential, left-to-right merging semantics. The `object.union_n` builtin
+now provides a 2-3x improvement in speed and memory efficiency over the pure
+Rego equivalent.
+
+### Tooling, SDK, and Runtime
+
+- cli: Fix doubled CLI hints/errors. ([#5115](https://github.com/open-policy-agent/opa/issues/5115)) authored by @ivanphdz
+- cli/test: Add capabilities flag to test command. (authored by @ivanphdz)
+- fmt: Fix blank lines after multiline expressions. (authored by @jaspervdj)
+- internal/report: Include heap usage in the telemetry report.
+- plugins/logs: Improve error message when decision log chunk size is greater than the upload limit. ([#5155](https://github.com/open-policy-agent/opa/issues/5155))
+- ir: Make the `internal/ir` package public as `ir`.
+
+### Rego
+
+- ast/parser+formatter: Allow 'if' in rule 'else' statements.
+- ast/schema: Add support for recursive json schema elements. ([#5166](https://github.com/open-policy-agent/opa/issues/5166)) authored and reported by @liamg
+- ast/schema: Fix race condition in parsing with reused references. (authored by @liamg)
+- internal/gojsonschema: Fix race condition in `SetAllowNet`. ([#5187](https://github.com/open-policy-agent/opa/issues/5187)) authored and reported by @liamg
+- ast/compiler: Rewrite declared variables in function calls and recursively rewrite local variables in `with` clauses. ([#5148](https://github.com/open-policy-agent/opa/issues/5148)) authored and reported by @liu-du
+- ast: Skip rules when parsing a body (or query) to help improve ambiguous parsing cases.
+
+### Topdown
+
+- topdown/object: Rework `object.union_n` to use in-place merge algorithm. (reported by @charlesdaniels)
+- topdown/jwt_decode_verify: Ensure `exp` and `nbf` fields are numbers when present.
([#5165](https://github.com/open-policy-agent/opa/issues/5165)) authored and reported by @charlieflowers +- topdown: Fix `InterQueryCache` only dropping one entry when over the size limit. (authored by @vinhph0906) +- topdown+builtins: Block all ND builtins from partial evaluation. +- topdown/builtins: Add Rego Object support for GraphQL builtins to improve composability. +- topdown/json: Fix panic in `json.filter` on empty JSON paths. +- topdown/sets_bench_test: Add `intersection` builtin tests. +- topdown/tokens: Protect against nistec panics. ([#5128](https://github.com/open-policy-agent/opa/issues/5218)) + +### Documentation + +- Add IR to integration docs. +- Added Gloo Edge Tutorial with examples. (authored by @Parsifal-M) +- Updated examples for CLI commands. +- Updated section on performance metrics (authored by @hutchins) +- docs/annotations: Add policy example and a link to the policy reference. ([#4937](https://github.com/open-policy-agent/opa/issues/4937)) authored by @Parsifal-M +- docs/policy-language: Be more explicit about future keywords. +- docs/security: Fix token authz example. (authored by @pigletfly) +- docs: Update generated CLI docs. (authored by @charlieflowers) +- docs: Update mentions of `#development` to `#contributors`. (authored by @charlieflowers) + +### Website + Ecosystem + +- website/security: Style improvements. (authored by @orweis) + +### Miscellaneous + +- ci: Add `prealloc` linter check and linter fixes. +- ci: Add govulncheck to Nightly CI. +- build/wasm: Use golang1.16 `go:embed` mechanism. +- util/backoff: Seed from math/rand source. +- version: Use `runtime/debug.BuildInfo`. + +- Dependency bumps, notably: + - build: bump golang 1.19.1 -> 1.19.2 + - build(deps): bump golang.org/x/net + - build(deps): bump internal/gqlparser to v2.5.1 + - build(deps): bump tj-actions/changed-files from 29.0.3 -> 32.0.0 + - deps(build): bump wasmtime-go 0.36.0 -> 1.0.0 (authored by @Parsifal-M) + +## 0.44.0 + +This release contains a number of fixes, two new builtins, a few new features, +and several performance improvements. + +### Security Fixes + +This release includes the security fixes present in the recent v0.43.1 release, +which mitigate CVE-2022-36085 in OPA itself, and CVE-2022-27664 and +CVE-2022-32190 in our Go build tooling. + +See the Release Notes for v0.43.1 for more details. + +### Set Element Addition Optimization + +Rego Set element addition operations did not scale linearly ([#4999](https://github.com/open-policy-agent/opa/pull/4999)) +in the past, and like the Object type before v0.43.0, experienced noticeable +reallocation/memory movement overheads once the Set grew past 120k-150k elements +in size. + +This release introduces different handling of Set internals during element +addition operations to avoid pathological reallocation behavior, and allows +linear performance scaling up into the 500k key range and beyond. + +### Set `union` Built-in Optimization + +The Set `union` builtin allows applying the union operation to a set of sets. + +However, as discovered in [#4979](https://github.com/open-policy-agent/opa/issues/4979), +its implementation generated unnecessary intermediate copies, which resulted in +poor performance; in many cases, worse than writing the equivalent operation in +pure Rego. + +This release improves the `union` builtin's implementation, such that only the +final result set is ever modified, reducing memory allocations and GC pressure. 
+The `union` builtin is now about 15-30% faster than the equivalent operation in
+pure Rego.
+
+### New Built-in Functions: `strings.any_prefix_match` and `strings.any_suffix_match`
+
+This release introduces two new builtins, optimized for bulk matching of string
+prefixes and suffixes: `strings.any_prefix_match` and
+`strings.any_suffix_match`.
+They work with sets and arrays of strings, allowing efficient matching of
+collections of prefixes or suffixes against a target string.
+
+See [the built-in functions docs for all the details](https://www.openpolicyagent.org/docs/v0.42.0/policy-reference/#builtin-strings-stringsany_prefix_match)
+
+This implementation fixes [#4994](https://github.com/open-policy-agent/opa/issues/4994) and was authored by @cube2222.
+
+### Tooling, SDK, and Runtime
+
+- Logger: Allow configuration of the timestamp format ([#2413](https://github.com/open-policy-agent/opa/issues/2413))
+- loader: Add support for fs.FS (authored by @ear7h)
+
+#### Bundles
+
+This release includes several bugfixes and improvements around bundle building:
+
+- cmd: Add optimize flag to OPA eval command to allow building optimized bundles
+- cmd/build+compile: Allow opt-out of dependents gathering to allow compilation of more bundles into WASM ([#5035](https://github.com/open-policy-agent/opa/issues/5035))
+- opa build -t wasm|plan: Fail on unmatched entrypoints ([#3957](https://github.com/open-policy-agent/opa/issues/3957))
+- opa build: Fix bundle mode to work with ignore flag
+- bundle/status: Include bundle size in status information
+- bundle: Remove raw bytes check for lazy bundle loading mode
+
+#### Storage Fixes
+
+This release has performance improvements and bugfixes for the disk storage system:
+
+- storage/disk: Improve handling of in-flight transactions during truncate operations ([#4900](https://github.com/open-policy-agent/opa/issues/4900))
+- storage/inmem: Allow disabling `util.Roundtrip` on Write for improved performance ([#4708](https://github.com/open-policy-agent/opa/issues/4708))
+- storage: Improve how multi-bundle data with overlapping roots is handled ([#4998](https://github.com/open-policy-agent/opa/issues/4998)) reported by @sirpi
+- storage: Fix issue with policyID in Truncate calls ([#4958](https://github.com/open-policy-agent/opa/issues/4958)) authored by @martinjoha reported by @martinjoha
+
+#### Rego
+
+- eval+rego: Support caching output of non-deterministic builtins. ([#1514](https://github.com/open-policy-agent/opa/issues/1514))
+
+#### AST and Topdown
+
+The AST and Topdown module received a number of important bugfixes in this release:
+
+- ast/term: Fix multiple-reader race condition for Sets/Objects
+- ast/compile: Respect unsafeBuiltinMap for 'with' replacements
+- ast: Add capacity to array initialization when size is known (authored by @mstrYoda)
+- topdown/object: Fix unchecked error case in `object.union_n` builtin ([#5073](https://github.com/open-policy-agent/opa/issues/5073))
+- topdown/reachable: Fix missing operand type checks. ([#4951](https://github.com/open-policy-agent/opa/issues/4951))
+- topdown/units_parse: Avoid extra decimal places for integers
+- topdown/type+wasm: Fix inconsistent `is_type` return values.
([#4943](https://github.com/open-policy-agent/opa/issues/4943)) +- builtins: Fix inconsistent error messages in `units.parse*` +- Add query parameter in canonical request of AWS Sigv4 signature to avoid 403 errors from AWS (authored by @sinhaaks) + +#### Test Suite + +- Add error type to `units.*` builtin test assertions +- test/e2e/certrefresh: Add `file.Sync()` to eliminate test failures due to slow disk writes +- topdown/exported_tests: Remove Golang 1.16 x509 exception +- cmd/bench: Fix port collision in utility function used for E2E testing + +### Documentation + +- SECURITY: Migrate policy to web site, update content ([#4272](https://github.com/open-policy-agent/opa/issues/4272)) reported by @adoliver +- Add deprecated flag to all deprecated builtins ([#5072](https://github.com/open-policy-agent/opa/issues/5072)) +- builtins: Update description of `format_int` to say it rounds down +- docs/policy-reference: Update Rego EBNF grammar (authored by @shaded-enmity) +- docs/builtins: Fix typo in `semver.compare` ([#5012](https://github.com/open-policy-agent/opa/issues/5012)) reported by @tetsuya28 +- docs: Fix AWS Signature section in Configuration (authored by @pauly4it) +- docs: Update port and bundle folder for GraphQL tutorial +- docs: Document that function overloading is unsupported +- docs: Fixing related_resources annotations example ([#4982](https://github.com/open-policy-agent/opa/issues/4982)) reported by @humbertoc-silva +- docs: Fixing typo in metadata ([#5018](https://github.com/open-policy-agent/opa/issues/5018)) authored by @cimin0 reported by @cimin0 + +### Website + Ecosystem + +- Update links to opa-kafka-plugin +- Add OCI documentation (authored by @carabasdaniel) +- Add article on using OPA for data filtering in Kafka +- Ecosystem: Add some links to Rönd (authored by @ugho16) +- Add community integration for Fiber (authored by @mstrYoda) +- Add Spacelift Integration (authored by @theseanodell) +- Fix broken link for Minio OPA integration (authored by @unautre) + +- Ecosystem Additions: + - cosign (#5040) (authored by @Dentrax) + +### Miscellaneous + +- Dockerfile: Append root "/" to $PATH ([#5003](https://github.com/open-policy-agent/opa/issues/5003)) authored by @matusf reported by @matusf +- Add VNG Cloud to adopters (authored by @vinhph0906) + +- Dependency bumps, notably: + - build: bump golang: 1.19 -> 1.19.1 + - build: use go 1.19, drop go 1.16 + - build(deps): bump aquasecurity/trivy-action from 0.6.1 -> 0.7.1 + - build(deps): bump github.com/agnivade/levenshtein from 1.0.1 -> 1.1.1 + - build(deps): bump github.com/containerd/containerd from 1.6.6 -> 1.6.8 + - build(deps): bump github.com/go-ini/ini from 1.66.6 -> 1.67.0 + - build(deps): bump github.com/prometheus/client_golang + - build(deps): bump google.golang.org/grpc from 1.48.0 -> 1.49.0 + - build(deps): bump tj-actions/changed-files from 28.0.0 -> 29.0.3 + +- Dependency removals: + - internal: Vendor gqlparser library ([#5065](https://github.com/open-policy-agent/opa/issues/5065)) reported by @vikstrous2 + +## 0.43.1 + +This is a security release fixing the following vulnerabilities: + +- CVE-2022-36085: Respect unsafeBuiltinMap for 'with' replacements in the compiler + + See https://github.com/open-policy-agent/opa/security/advisories/GHSA-f524-rf33-2jjr for all details. + +- CVE-2022-27664 and CVE-2022-32190. + + Fixed by updating the Go version used in our builds to 1.18.6, + see https://groups.google.com/g/golang-announce/c/x49AQzIVX-s. 
+
+  Note that CVE-2022-32190 is most likely not relevant for OPA's usage of net/url.
+  But since these CVEs tend to come up in security assessment tooling regardless,
+  it's better to get it out of the way.
+
+## 0.43.0
+
+This release contains a number of fixes, enhancements, and performance improvements.
+
+### Object Insertion Optimization
+
+Rego Object insertion operations did not scale linearly ([#4625](https://github.com/open-policy-agent/opa/issues/4625))
+in the past, and experienced noticeable reallocation/memory movement
+overheads once the Object grew past 120k-150k keys in size.
+
+This release introduces different handling of Object internals during insert
+operations to avoid pathological reallocation behavior, and allows linear
+performance scaling up into the 500k key range and beyond.
+
+### Tooling, SDK, and Runtime
+
+- Add lines covered/not covered counts to test coverage report (authored by @FarisR99)
+- Plugins: Status and logs plugins now accept any HTTP 2xx status code (authored by @lvisterin)
+- Runtime: Generalize OS check for MacOS to other Unix-likes (authored by @iamleot)
+
+#### Bundles Fixes
+
+The Bundles system received several bugfixes and performance improvements in this release:
+
+ - Bundle: `opa bundle` command now supports `.yml` files ([#4859](https://github.com/open-policy-agent/opa/issues/4859)) authored by @Joffref reported by @rdrgmnzsakt
+ - Plugins/Bundle: Use unique temporary files for persisting activated bundles to disk ([#4782](https://github.com/open-policy-agent/opa/issues/4782)) authored by @FredrikAppelros reported by @FredrikAppelros
+ - Server: Old policy path is now checked for bundle ownership before update ([#4846](https://github.com/open-policy-agent/opa/issues/4846))
+ - Storage+Bundle: Old bundle data is now cleaned before new bundle activation ([#4940](https://github.com/open-policy-agent/opa/issues/4940))
+ - Bundle: Paths are now normalized before bundle root check occurs to ensure checks are OS-independent
+
+#### Storage Fixes
+
+The Storage system received mostly bugfixes, with a notable performance improvement for large bundles in this release:
+
+ - storage/inmem: Speed up bundle activation by avoiding unnecessary read operations ([#4898](https://github.com/open-policy-agent/opa/issues/4898))
+ - storage/inmem: Paths are now created during truncate operations if they did not exist before
+ - storage/disk: Symlinks work with relative paths now ([#4869](https://github.com/open-policy-agent/opa/issues/4869))
+
+### Rego and Topdown
+
+The Rego compiler and runtime environment received a number of bugfixes, and a few new features this release, as well as a notable performance improvement for large Objects
+(covered above).
+ +- AST/Compiler: New method for obtaining parsed, but otherwise unprocessed modules is now available ([#4910](https://github.com/open-policy-agent/opa/issues/4910)) +- `object.subset`: Support array + set combination ([#4858](https://github.com/open-policy-agent/opa/issues/4858)) authored by @x-color +- Compiler: Prevent erasure of `print()` statements in the compiler via a `WithEnablePrintStatements` option to `compiler.Compiler` and `compiler.optimizer` (authored by @kevinstyra) +- Topdown fixes: + - AST/Builtins: `type_name` builtin now has more precise type metadata and improved docs + - Topdown/copypropagation: Ref-based tautologies like `input.a == input.a` are no longer eliminated during the copy-propagation pass ([#4848](https://github.com/open-policy-agent/opa/issues/4848)) reported by @johanneskra + - Topdown/parse_units: Use big.Rat for units parsing to avoid floating-point rounding issues on fractional units. ([#4856](https://github.com/open-policy-agent/opa/issues/4856)) reported by @tmos22 + - Topdown: `is_valid` builtins no longer error, and should always return booleans ([#4760](https://github.com/open-policy-agent/opa/issues/4760)) + - Topdown: `glob.match` now can be used without delimiters ([#4923](https://github.com/open-policy-agent/opa/issues/4923)) authored by @vinhph0906 reported by @vinhph0906 + +### Documentation + + - Docs: Add GraphQL API authorization tutorial + - Docs/bundles: Add bundle CLI command documentation ([#3831](https://github.com/open-policy-agent/opa/issues/3831)) authored by @Joffref + - Docs/policy-reference: Remove extra quote in Grammar to fix formatting ([#4915](https://github.com/open-policy-agent/opa/issues/4915)) authored by @friedrichsenm reported by @friedrichsenm + - Docs/policy-testing: Add missing future.keywords imports ([#4849](https://github.com/open-policy-agent/opa/issues/4849)) reported by @robert-elles + - Docs: Add note about counter_server_query_cache_hit metric ([#4389](https://github.com/open-policy-agent/opa/issues/4389)) + - Docs: Kube tutorial includes updated cert install procedure ([#4902](https://github.com/open-policy-agent/opa/issues/4902)) reported by @Imp + - Docs: GraphQL builtins section now includes a note about framework-specific `@directive` definitions in GraphQL schemas + - Docs: Add warning about name collisions in older policies from importing 'future.keywords' + +### Website + Ecosystem + +- Website: Show navbar on smaller devices ([#3353](https://github.com/open-policy-agent/opa/issues/3353)) authored by @Parsifal-M reported by @OBrienCommaJosh +- Website/frontpage: Update front page examples to use the future.keywords imports +- Website/live-blocks: Only pass 'import future.keywords' when needed and supported +- Website/live-blocks: Update codemirror-rego to 1.3.0 +- Website: Fix community page layout/scrolling issues (authored by @mstade) + +- Ecosystem Additions: + - Rond (authored by @ugho16) + - walt.id + +### Miscellaneous + +- Dependency bumps, notably: + - aquasecurity/trivy-action from 0.5.1 to 0.6.1 + - github.com/sirupsen/logrus from 1.8.1 to 1.9.0 + - github.com/vektah/gqlparser/v2 from 2.4.5 to 2.4.6 + - google.golang.org/grpc from 1.47.0 to 1.48.0 + - terser in /docs/website/scripts/live-blocks + - glob-parent in /docs/website/scripts/live-blocks +- Added GKE Policy Automation to ADOPTERS.md (authored by @mikouaj) +- Fix minor code unreachability error (authored by @Abirdcfly) + +## 0.42.2 + +This is a bug fix release that addresses the following: + +- storage/disk: make symlinks work 
with relative paths ([#4869](https://github.com/open-policy-agent/opa/issues/4869))
+- bundle: Normalize paths before bundle root check
+
+## 0.42.1
+
+This is a bug fix release that addresses the following:
+
+1. An issue while writing data to the in-memory store at a non-root nonexistent path ([#4855](https://github.com/open-policy-agent/opa/issues/4855)), reported by @wermerb and others.
+2. Policies owned by a bundle could be replaced via the REST API because of a missing bundle scope check ([#4846](https://github.com/open-policy-agent/opa/issues/4846)).
+3. Adds missing `future.keywords` import for the examples in the policy testing section of the docs ([#4849](https://github.com/open-policy-agent/opa/issues/4849)), reported by @robert-elles.
+
+## 0.42.0
+
+This release contains a number of fixes and enhancements.
+
+### New built-in function: `object.subset`
+
+This function checks if a collection is a subset of another collection.
+It works on objects, sets, and arrays.
+
+If both arguments are objects, then the operation is recursive, e.g. `{"c": {"x": {10, 15, 20}}}`
+is considered a subset of `{"a": "b", "c": {"x": {10, 15, 20, 25}, "y": "z"}}`.
+
+See [the built-in functions docs for all the details](https://www.openpolicyagent.org/docs/v0.42.0/policy-reference/#builtin-object-objectsubset)
+
+This implementation fixes [#4358](https://github.com/open-policy-agent/opa/issues/4358) and was authored by @charlesdaniels.
+
+### New keywords: "contains" and "if"
+
+These new keywords let you increase the expressiveness of your policy code:
+
+Before
+
+```rego
+package authz
+allow { not denied } # `denied` left out for presentation purposes
+
+deny[msg] {
+    count(violations) > 0
+    msg := sprintf("there are %d violations", [count(violations)])
+}
+```
+
+After
+
+```rego
+package authz
+import future.keywords
+
+allow if not denied # one expression only => no { ... } needed!
+
+deny contains msg if {
+    count(violations) > 0
+    msg := sprintf("there are %d violations", [count(violations)])
+}
+```
+
+Note that rule bodies containing only one expression can be abbreviated when using `if`.
+
+To use the new keywords, use `import future.keywords.contains` and `import future.keywords.if`; or
+import all of them at once via `import future.keywords`. When these future imports are present, the
+pretty printer (`opa fmt`) will introduce `contains` and `if` where applicable.
+
+`if` is allowed in all places to separate the rule head from the body, like
+```rego
+response[key] = value if { key := "open"; value := "sesame" }
+```
+_but_ not for partial set rules, unless also using `contains`:
+```rego
+deny[msg] if msg := "forbidden" # INVALID
+deny contains msg if msg := "forbidden" # VALID
+```
+
+### Tooling, SDK, and Runtime
+
+- Plugins:
+  - S3 Plugin: Allow multiple AWS credential providers at once, chained together ([#4791](https://github.com/open-policy-agent/opa/issues/4791)), reported and authored by @abhisek
+  - Discovery Plugin: Check for empty key config ([#4656](https://github.com/open-policy-agent/opa/issues/4656)) reported by @humbertoc-silva
+  - Logs Plugin: Update mechanism to escape field paths ([#4717](https://github.com/open-policy-agent/opa/issues/4717)) reported by @pauly4it
+  - Status Plugin: fix `bundle_failed_load_counter` metric for bundles without revisions ([#4822](https://github.com/open-policy-agent/opa/issues/4822)) reported and authored by @jkbschmid
+- Server: The `system.authz` policy now properly supports the interquery caching of `http.send` calls ([#4829](https://github.com/open-policy-agent/opa/issues/4829)), reported by @HarshPathakhp
+- `opa bench`: Passing `--e2e` makes the benchmark measure the performance of a query including the server's HTTP handlers and their processing. See the usage sketch after this list.
+- `opa fmt`: Output list _and_ diff changes with `--fail` flag (#4710) (authored by @davidkuridza)
+- Disk Storage: Bundles are now streamed into the disk store, and not extracted completely in memory ([#4539](https://github.com/open-policy-agent/opa/issues/4539))
+- Golang package `repl`: Add a `WithCapabilities` function (authored by @jaspervdj)
+- SDK: Allow configurable ID (authored by @rakshasa-1729)
+- Windows: User lookups in various code paths have been avoided. They served no purpose but are costly, and removing them should increase
+  the performance of any CLI calls (even `opa version`) on Windows. Fixes [#4646](https://github.com/open-policy-agent/opa/issues/4646).
+- Server: Open read storage transaction in Query API handler (not write)
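+
+As a usage sketch for the new end-to-end benchmark mode (file and query names are illustrative):
+
+```bash
+# start a temporary OPA server behind the scenes and benchmark the query,
+# including the HTTP handler overhead
+opa bench --e2e --data policy.rego 'data.authz.allow'
+```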
+
+### Rego and Topdown
+
+- Runtime Errors: Fix type error message in `count`, `object.filter`, and `object.remove` built-in functions ([#4767](https://github.com/open-policy-agent/opa/issues/4767))
+- Parser: Remove early MHS return in infix parsing, fixing confusing error messages ([#4672](https://github.com/open-policy-agent/opa/issues/4672)) authored by @philipaconrad
+- AST: Disallow shadowing of called functions in comprehension heads ([#4762](https://github.com/open-policy-agent/opa/issues/4762))
+- Planner/IR: shadow rule funcs if mocking functions ([#4746](https://github.com/open-policy-agent/opa/issues/4746))
+- Compiler: Fix "every" handling in partial eval: by reordering body for safety differently, and correctly plugging its terms on safe ([#4801](https://github.com/open-policy-agent/opa/pull/4801)), reported by @jguenther-va
+- Compiler: fix util.HashMap eq comparison ([#4759](https://github.com/open-policy-agent/opa/pull/4759))
+- Built-ins: use strings.Builder in glob.match() (authored by @charlesdaniels)
+
+### Documentation
+
+- Builtins: Fix documentation of `startswith` and `endswith` (authored by @whme)
+- Kubernetes Tutorial: Remove unused assignment in example ([#4778](https://github.com/open-policy-agent/opa/issues/4778)) authored by @Joffref
+- OCI: Update configuration docs for private images in OCI registries (authored by @carabasdaniel)
+- AWS S3 Signing: Fix profile_credentials docs (authored by @wangli1030)
+
+### Website + Ecosystem
+
+- Add "Edit on GitHub" button to docs ([#3784](https://github.com/open-policy-agent/opa/issues/3784)) authored by @avinashdesireddy
+- Wasm: fix function table markup ([#4664](https://github.com/open-policy-agent/opa/issues/4664))
+- Ecosystem: use location.hash to track open modal ([#4667](https://github.com/open-policy-agent/opa/issues/4667))
+
+Note that website changes like these become effective immediately and are not tied to a release.
+We still use our release notes to record the nice fixes contributed by our community.
+
+- Ecosystem Additions:
+  - Alfred, the self-hosted playground (authored by @dolevf)
+  - Java Spring tutorial (authored by @psevestre)
+  - Pulumi
+
+### Miscellaneous
+
+- Add Terminus to ADOPTERS.md (#4734) ([#4713](https://github.com/open-policy-agent/opa/issues/4713)) reported by @charlieflowers
+- Remove any data attributes not used in the "YAML tests" ([#4813](https://github.com/open-policy-agent/opa/issues/4813))
+- Dependency bumps, notably:
+  - github.com/prometheus/client_golang 1.12.2 ([#4697](https://github.com/open-policy-agent/opa/issues/4697))
+  - github.com/vektah/gqlparser/v2 2.4.5
+- Build process and CI:
+  - Use Trivy for vulnerability scans in code and container images (authored by @JAORMX)
+  - Bump golangci-lint to v1.46.2, fix some issues ([#4765](https://github.com/open-policy-agent/opa/issues/4765))
+  - Remove npm-opa-wasm test
+  - Skip flaky darwin tests on PR runs
+  - Fix flaky oci e2e test ([#4748](https://github.com/open-policy-agent/opa/issues/4748)) authored by @carabasdaniel
+  - Integrate builtin_metadata.json handling in release process ([#4754](https://github.com/open-policy-agent/opa/issues/4754))
+
+
+## 0.41.0
+
+This release contains a number of fixes and enhancements.
+
+### GraphQL Built-in Functions
+
+A new set of built-in functions is now available to validate, parse, and verify GraphQL queries and schemas! The following are
+the new built-ins:
+
+    graphql.is_valid: Checks that a GraphQL query is valid against a given schema
+    graphql.parse: Returns AST objects for a given GraphQL query and schema
+    graphql.parse_and_verify: Returns a boolean indicating success or failure alongside the parsed ASTs for a given GraphQL query and schema
+    graphql.parse_query: Returns an AST object for a GraphQL query
+    graphql.parse_schema: Returns an AST object for a GraphQL schema
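+
+As a minimal sketch of how these compose (the schema and query strings are illustrative):
+
+```rego
+package example
+
+schema := `
+type Query {
+    hello: String
+}
+`
+
+query_is_valid {
+    graphql.is_valid(`{ hello }`, schema)
+}
+```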
+
+### Built-in Function Metadata
+
+Built-in function declarations now support additional metadata to specify name and description for function arguments
+and return values. The metadata can be programmatically consumed by external tools such as IDE plugins. The built-in
+function documentation is created using the new built-in function metadata.
+Check out the new look of the [Built-In Reference](https://www.openpolicyagent.org/docs/latest/policy-reference/#built-in-functions)
+page!
+
+Under the hood, a new file called `builtins_metadata.json` is generated via `make generate`, which can be consumed by
+external tools.
+
+### Tooling, SDK, and Runtime
+
+- OCI Downloader: Add logic to skip bundle reloading based on the digest of the OCI artifact ([#4637](https://github.com/open-policy-agent/opa/issues/4637)) authored by @carabasdaniel
+- Bundles: Exclude empty manifest from bundle signature ([#4712](https://github.com/open-policy-agent/opa/issues/4712)) authored by @friedrichsenm reported by @friedrichsenm
+
+### Rego and Topdown
+
+- units.parse: New built-in for parsing standard metric decimal and binary SI units (e.g., K, Ki, M, Mi, G, Gi)
+- format: Fix `opa fmt` location for non-key rules (#4695) (authored by @jaspervdj)
+- token: Ignore keys of unknown alg when verifying JWTs with JWKS ([#4699](https://github.com/open-policy-agent/opa/issues/4699)) reported by @lenalebt
+
+### Documentation
+
+- Adding Built-in Functions: Add note about `capabilities.json` while creating a new built-in function
+- Policy Reference: Add example for `rego.metadata.rule()` built-in function
+- Policy Reference: Fix grammar for `import` keyword ([#4689](https://github.com/open-policy-agent/opa/issues/4689)) authored by @mmzeeman reported by @mmzeeman
+- Security: Fix command line flag name for file containing the TLS certificate ([#4678](https://github.com/open-policy-agent/opa/issues/4678)) authored by @pramodak reported by @pramodak
+
+### Website + Ecosystem
+
+- Update Kubernetes policy examples on the website to use the latest Kubernetes schema (`apiVersion`: `admission.k8s.io/v1`) (authored by @vicmarbev)
+- Ecosystem:
+  - Add Sansshell (authored by @sfc-gh-jchacon)
+  - Add Nginx
+
+### Miscellaneous
+
+- Various dependency bumps, notably:
+  - OpenTelemetry-go: 1.6.3 -> 1.7.0
+  - go.uber.org/automaxprocs: 1.4.0 -> 1.5.1
+  - github.com/containerd/containerd: 1.6.2 -> 1.6.4
+  - google.golang.org/grpc: 1.46.0 -> 1.47.0
+  - github.com/bytecodealliance/wasmtime-go: 0.35.0 -> 0.36.0
+  - github.com/vektah/gqlparser/v2: 2.4.3 -> 2.4.4
+- `make test`: Fix "too many open files" issue on Mac OS
+- Remove usage of github.com/pkg/errors package (authored by @imjasonh)
+
+## 0.40.0
+
+This release contains a number of fixes and enhancements.
+
+### Metadata introspection
+
+The _rich metadata_ added in the v0.38.0 release can now be introspected
+from the policies themselves!
+
+    package example
+
+    # METADATA
+    # title: Edits by owner only
+    # description: |
+    #   Only the owner is allowed to edit their data.
+    deny[{"allowed": false, "message": rego.metadata.rule().description}] {
+        input.user != input.owner
+    }
+
+This snippet will evaluate to
+
+    [{
+      "allowed": false,
+      "message": "Only the owner is allowed to edit their data.\n"
+    }]
+
+A rule's own metadata can be accessed via `rego.metadata.rule()`, and the entire chain of
+metadata attached to the rule through the various scopes that different metadata annotations
+can have via `rego.metadata.chain()`.
+
+All the details can be found in the documentation of [these new built-in functions](https://www.openpolicyagent.org/docs/v0.40.0/policy-reference/#rego).
+
+### Function mocking
+
+It is now possible to **mock functions** in tests! Both built-in and non-built-in
+functions can be mocked:
+
+    package authz
+    import data.jwks.cert
+    import data.helpers.extract_token
+
+    allow {
+        [true, _, _] = io.jwt.decode_verify(extract_token(input.headers), {"cert": cert, "iss": "corp.issuer.com"})
+    }
+
+    test_allow {
+        allow
+            with input.headers as []
+            with data.jwks.cert as "mock-cert"
+            with io.jwt.decode_verify as [true, {}, {}] # mocked built-in
+            with extract_token as "my-jwt" # mocked non-built-in
+    }
+
+For further information about policy testing with data and function mocking, see [the Policy Testing docs](https://www.openpolicyagent.org/docs/v0.40.0/policy-testing/#data-and-function-mocking).
+All details about `with` can be found in its [Policy Language section](https://www.openpolicyagent.org/docs/v0.40.0/policy-language/#with-keyword).
+
+### Assignments with `:=`
+
+Remaining restrictions around the use of `:=` in rules and functions have been lifted ([#4555](https://github.com/open-policy-agent/opa/issues/4555)).
+These constructs are now valid:
+
+    check_images(imgs) := x { # function
+      # ...
+    }
+
+    allow := x { # rule
+      # ...
+    }
+
+    response[key] := object { # partial object rule
+      # ...
+    }
+
+In the wake of this, rules may now be "redeclared", i.e.
you can use `:=` for more than one rule body:
+
+    deny := x {
+      # body 1
+    }
+    deny := x {
+      # body 2
+    }
+
+This was forbidden before, but didn't serve a real purpose: it would catch trivial-to-catch errors
+like
+
+    p := 1
+    p := 2 # redeclared
+
+But it would do no good in more difficult-to-debug "multiple assignment" problems like
+
+    p := x {
+        some x in [1, 2, 3]
+    }
+
+### Tooling, SDK, and Runtime
+
+- Status Plugin: Remove activeRevision label on all but one Prometheus metric ([#4584](https://github.com/open-policy-agent/opa/issues/4584)) reported and authored by @costimuraru
+- Status: Include bundle type ("snapshot" or "delta") in status information
+- `opa capabilities`: Expose capabilities through CLI, and allow using versions when passing `--capabilities v0.39.0` to the various commands ([#4236](https://github.com/open-policy-agent/opa/issues/4236)) authored by @IoannisMatzaris
+- Logging: Log warnings at WARN level not ERROR, authored by @damienjburks
+- Runtime: Persist activated bundle Etag to store ([#4544](https://github.com/open-policy-agent/opa/issues/4544))
+- `opa eval`: Don't use source locations when formatting partially evaluated output ([#4609](https://github.com/open-policy-agent/opa/issues/4609))
+- `opa inspect`: Fix an issue where some errors encountered by the inspect command weren't properly reported
+- `opa fmt`: Fix a bug with missing whitespace when formatting multiple `with` statements on one indented line ([#4634](https://github.com/open-policy-agent/opa/issues/4634))
+
+#### Experimental OCI support
+
+When configured to do so, OPA's bundle and discovery plugins will retrieve bundles from **any OCI registry**.
+Please see [the Services Configuration section](https://www.openpolicyagent.org/docs/v0.40.0/configuration/#services)
+for details.
+
+Note that at this point, it's best considered a "feature preview". Be aware of the following:
+- Bundles are not cached, but re-retrieved and activated periodically.
+- The persistence directory used for storing retrieved OCI artifacts is not yet managed by OPA,
+  so its content may accumulate. By default, the OCI downloader will use a temporary file location.
+- The documentation on how to push bundles to an OCI repository currently only exists in the development
+  docs, see [OCI.md](https://github.com/open-policy-agent/opa/blob/v0.40.0/docs/devel/OCI.md).
+
+Thanks to @carabasdaniel for starting the work on this!
+
+### Rego and Topdown
+
+- Builtins: Require prefix length for IPv6 in `net.cidr_merge` ([#4596](https://github.com/open-policy-agent/opa/issues/4596)), reported by @alexhu20
+- Builtins: `http.send` can now parse and cache YAML responses, analogous to JSON responses
+- Parser: Guard against invalid domains for "some" and "every", reported by @doyensec
+- Formatting: Don't add 'in' keyword import when 'every' is there ([#4606](https://github.com/open-policy-agent/opa/issues/4606))
+
+### Documentation
+
+- Policy Language: Reorder Universal Quantification content, stress `every` over other constructions ([#4603](https://github.com/open-policy-agent/opa/issues/4603))
+- Language pages: Use assignment operator where it's allowed.
+- SSH Tutorial: Use bundle API
+- Annotations: Update "Custom" annotation section
+- CloudFormation: Fix markup and add warning related to booleans
+- Blogs: mention OAuth2 and OIDC blog posts
+
+### Website + Ecosystem
+
+- Redirect previous patch releases to latest patch release ([#4225](https://github.com/open-policy-agent/opa/issues/4225))
+- Add playground button to navbar
+- Add SRI to static HTML files
+- Remove right margin on sidebar (#4529) (authored by @orweis)
+- Show yellow banner for old version (#4533)
+- Remove unused variables to avoid error in strict mode (#4534) (authored by @panpan0000)
+- Ecosystem:
+  - Add AWS CloudFormation Hook
+  - Add GKE policy automation
+  - Add permit.io (authored by @ozradi)
+  - Add Magda (authored by @t83714)
+
+### Miscellaneous
+
+- Workflow: no content permissions for GitHub action 'post-release', authored by @naveensrinivasan
+- Various dependency bumps, notably:
+  - OpenTelemetry-go: 1.6.1 -> 1.6.3
+  - go.uber.org/automaxprocs: 1.4.0 -> 1.5.1
+- Binaries and Docker images are now built using Go 1.18.1.
+- Dockerfile: add source annotation (#4626)
+
+## 0.39.0
+
+This release contains a number of fixes and enhancements.
+
+### Disk Storage
+
+The on-disk storage backend has been fully integrated with the OPA server, and
+can now be enabled via configuration:
+
+```yaml
+storage:
+  disk:
+    directory: /var/opa   # put data here
+    auto_create: true     # create directory if it doesn't exist
+    partitions:           # partitioning is important for data storage,
+    - /users/*            # please see the documentation
+```
+
+It is intended to enable the use of OPA in scenarios where the data needed for
+policy evaluation exceeds the available memory.
+
+The on-disk contents will persist across restarts, but should not be used as a
+single source of truth: there are no backup mechanisms, and certain data partitioning
+changes will require a start-over. These are things that may get improved in the
+future.
+
+For all the details, please refer to the [configuration](https://www.openpolicyagent.org/docs/v0.39.0/configuration/#disk-storage)
+and [detailed Disk Storage section](https://www.openpolicyagent.org/docs/v0.39.0/misc-disk/)
+of the documentation.
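+
+To make the partitioning example above concrete: with a partition of `/users/*`, each child of the
+`data.users` document is stored and updated separately, so writing one user's document does not
+rewrite the others. A hypothetical data layout matching that partition could look like this (the
+linked documentation is the authoritative guide to choosing partitions):
+
+```
+{
+  "users": {
+    "alice": {"roles": ["admin"]},
+    "bob": {"roles": ["viewer"]}
+  }
+}
+```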
+ +### Tooling, SDK, and Runtime + +- Server: Add warning when `input` attribute is missing in `POST /v1/data` API ([#4386](https://github.com/open-policy-agent/opa/issues/4386)) authored by @aflmp +- SDK: Support partial evaluation ([#4240](https://github.com/open-policy-agent/opa/pull/4240)), authored by @kroekle; with a fix to avoid using different state (authored by @Iceber) +- Runtime: Suppress payloads in debug logs for handlers that compress responses (`/metrics` and `/debug/pprof`) (authored by @christian1607) +- `opa test`: Add file path to failing tests to make debugging failing tests easier ([#4457](https://github.com/open-policy-agent/opa/issues/4457)), authored by @liamg +- `opa fmt`: avoid whitespace mixed with tabs on `with` statements ([#4376](https://github.com/open-policy-agent/opa/issues/4376)) reported by @tiwood +- Coverage reporting: Remove duplicates from coverage report ([#4393](https://github.com/open-policy-agent/opa/issues/4393)) reported by @gianna7wu +- Plugins: Fix broken retry logic in decision logs plugin ([#4486](https://github.com/open-policy-agent/opa/issues/4486)) reported by @iamatwork +- Plugins: Update regular polling fallback mechanism for downloader +- Plugins: Support for adding custom parameters and headers for OAuth2 Client Credentials Token request (authored by @srlk) +- Plugins: Log message on unexpected bundle content type ([#4278](https://github.com/open-policy-agent/opa/issues/4278)) +- Plugins: Mask Authorization header value in debug logs ([#4495](https://github.com/open-policy-agent/opa/issues/4495)) +- Docker images: Use GID 1000 in `-rootless` images ([#4380](https://github.com/open-policy-agent/opa/issues/4380)); also warn when using UID/GID 0. +- Runtime: change processed file event log level to info + +### Rego and Topdown + +- Type checker: Skip pattern JSON Schema attribute compilation ([#4426](https://github.com/open-policy-agent/opa/issues/4426)): These are not supported, but could have caused the parsing of a JSON Schema document to fail. +- Topdown: Copy without modifying expr, fixing a bug that could occur when running multiple partial evaluation requests concurrently. 
+- Compiler strict mode: Raise error on unused imports ([#4354](https://github.com/open-policy-agent/opa/issues/4354)) authored by @damienjburks
+- AST: Fix print call rewriting in else rules ([#4489](https://github.com/open-policy-agent/opa/issues/4489))
+- Compiler: Improve error message on missing `with` target ([#4431](https://github.com/open-policy-agent/opa/issues/4431)) reported by @gabrielfern
+- Parser: hint about 'every' future keyword import
+
+### Documentation and Website
+
+- AWS CloudFormation Hook: New tutorial
+- Community: Stretch background so it covers on larger screens ([#4402](https://github.com/open-policy-agent/opa/issues/4402)) authored by @msorens
+- Build: Make local dev and PR preview not build everything ([#4379](https://github.com/open-policy-agent/opa/issues/4379))
+- Philosophy: Grammar fixes (authored by @ajonesiii)
+- README: Add note about Hugo version mismatch errors (authored by @ogazitt)
+- Integrations: Add GraphQL-Graphene (authored by @dolevf), Emissary-Ingress (authored by @tayyabjamadar), and rekor-sidekick
+- Integrations CI: ensure referenced software is listed, and logo file names match; allow SVG logos
+- Envoy: Update policy primer with new control headers
+- Envoy: Update bob_token and alice_token in tutorial (authored by @rokkiter)
+- Envoy: Include new configurable gRPC msg sizes (authored by @emaincourt)
+- Annotations: add missing title to index (authored by @itaysk)
+
+### Miscellaneous
+
+- Various dependency bumps, notably:
+  - OpenTelemetry-go: 1.4.1 -> 1.6.1
+  - Wasmtime-go: 0.34.0 -> 0.35.0
+- Binaries and Docker images are now built using Go 1.18; CI runs build/test for Ubuntu and macOS with Go 1.16 and 1.17.
+- CI: remove go-fuzz, use native Go 1.18 fuzzer
+
+## 0.38.1
+
+This is a bug fix release that addresses one issue when using `opa test` with the
+`--bundle` (`-b`) flag, and a policy that uses the `every` keyword.
+
+There are no other code changes in this release.
+
+### Fixes
+
+- Compiler: don't raise an error with unused declared+generated vars
+  (every) ([#4420](https://github.com/open-policy-agent/opa/issues/4420)),
+  reported by @kristiansvalland
+
+## 0.38.0
+
+This release contains a number of fixes and enhancements.
+
+It contains one **backwards-incompatible change** to the JSON representation
+of metrics in **Status API** payloads; please see the section below.
+
+### Rich Metadata
+
+It is now possible to annotate Rego policies in a way that can be
+processed programmatically, using _Rich Metadata_.
+
+    # METADATA
+    # title: My rule
+    # description: A rule that determines if x is allowed.
+    # authors:
+    # - Jane Austin
+    allow {
+      ...
+    }
+
+The available keys are:
+
+- title
+- description
+- authors
+- organizations
+- related_resources
+- schemas
+- scope
+- custom
+
+Custom annotations can be used to annotate rules, packages, and
+documents with whatever you specifically need, beyond the generic
+keywords.
+
+Annotations can be retrieved using the [Golang library](https://www.openpolicyagent.org/docs/v0.38.0/annotations/#go-api)
+or via the CLI, `opa inspect -a`.
+
+All the details can be found in the documentation on [Annotations](https://www.openpolicyagent.org/docs/v0.38.0/annotations/).
+
+### Every Keyword
+
+A new keyword for explicit iteration is added to Rego: `every`.
+
+It comes in two forms, iterating values, or keys and values, of a
+collection, and asserting that the body evaluates successfully for
+each binding of key and value to the collection's elements:
+
+    every k, v in {"foo": "FOO", "bar": "BAR" } {
+        upper(k) == v
+    }
+
+To use it, `import future.keywords.every` or `future.keywords`.
+
+For further information, please refer to the [Every Keyword docs](https://www.openpolicyagent.org/docs/v0.38.0/policy-language/#every-keyword)
+and the new section on [_FOR SOME and FOR ALL_ in the Intro docs](https://www.openpolicyagent.org/docs/v0.38.0/#for-some-and-for-all).
+
+### Tooling, SDK, and Runtime
+
+- Compile API: add `disableInlining` option ([#4357](https://github.com/open-policy-agent/opa/issues/4357)) reported and fixed by @srlk
+- Status API: add `http_code` to response ([#4259](https://github.com/open-policy-agent/opa/issues/4259)) reported and fixed by @jkbschmid
+- Status plugin: publish experimental bundle-related metrics via Prometheus endpoint (authored by @rafaelreinert) -- See [Status Metrics](https://www.openpolicyagent.org/docs/v0.38.0/monitoring/#status-metrics) for details.
+- SDK: don't panic without config ([#4303](https://github.com/open-policy-agent/opa/issues/4303)) authored by @damienjburks
+- Storage: Support index for array appends (for JSON Patch compatibility)
+- `opa deps`: Fix pretty printed output to show virtual documents ([#4342](https://github.com/open-policy-agent/opa/issues/4342))
+
+### Rego and Topdown
+
+- Parser: parse 'with' on 'some x in xs' expression ([#4226](https://github.com/open-policy-agent/opa/issues/4226))
+- AST: hash containers on insert/update ([#4345](https://github.com/open-policy-agent/opa/issues/4345)), fixing a data race reported by @skillcoder
+- Planner: Fix bug related to undefined results in dynamic lookups
+
+### Documentation and Website
+
+- Policy Reference: update EBNF to include "every" and "some x in ..." ([#4216](https://github.com/open-policy-agent/opa/issues/4216))
+- REST API: Update docs on 400 response
+- README: Include Google Analytic Instructions
+- Envoy primer: use variables instead of objects
+- Istio tutorial: expose application to outside traffic
+- New "Community" Webpage (authored by @msorens)
+
+### WebAssembly
+
+- OPA now uses Wasmtime 0.34.0 to evaluate its Wasm modules.
+
+### Miscellaneous
+
+- Build: `make build` now builds without errors (by disabling Wasm) on darwin/arm64 (M1)
+- Various dependency bumps.
+  - OpenTelemetry SDK: 1.4.1
+  - github.com/prometheus/client_golang: 1.12.1
+
+### Backwards incompatible changes
+
+The JSON representation of the Status API's payloads -- both for `GET /v1/status`
+responses and the metrics sent to a remote Status API endpoint -- has changed:
+
+Previously, they had been serialized into JSON using the standard library "encoding/json"
+methods. However, the metrics coming from the Prometheus integration are only available
+in Golang structs generated from Protobuf definitions. For serializing these into JSON,
+the standard library functions are unsuited:
+
+- enums would be converted into numbers,
+- field names would be `snake_case`, not `camelCase`,
+- and NaNs would cause the encoder to panic.
+
+Now, we're using the protobuf ecosystem's `jsonpb` package to serialize the Prometheus
+metrics into JSON in a way that is compliant with the Protobuf specification.
+ +Concretely, what would before be +``` + "metrics": { + "prometheus": { + "go_gc_duration_seconds": { + "help": "A summary of the GC invocation durations.", + "metric": [ + { + "summary": { + "quantile": [ + { + "quantile": 0, + "value": 0.000011799 + }, + { + "quantile": 0.25, + "value": 0.000011905 + }, + { + "quantile": 0.5, + "value": 0.000040002 + }, + { + "quantile": 0.75, + "value": 0.000065238 + }, + { + "quantile": 1, + "value": 0.000104897 + } + ], + "sample_count": 7, + "sample_sum": 0.000309117 + } + } + ], + "name": "go_gc_duration_seconds", + "type": 2 + }, +``` + +is *now*: +``` + "metrics": { + "prometheus": { + "go_gc_duration_seconds": { + "name": "go_gc_duration_seconds", + "help": "A summary of the pause duration of garbage collection cycles.", + "type": "SUMMARY", + "metric": [ + { + "summary": { + "sampleCount": "1", + "sampleSum": 4.1765e-05, + "quantile": [ + { + "quantile": 0, + "value": 4.1765e-05 + }, + { + "quantile": 0.25, + "value": 4.1765e-05 + }, + { + "quantile": 0.5, + "value": 4.1765e-05 + }, + { + "quantile": 0.75, + "value": 4.1765e-05 + }, + { + "quantile": 1, + "value": 4.1765e-05 + } + ] + } + } + ] + }, +``` + +Note that `sample_count` is now `sampleCount`, and the `type` is using the enum's +string representation, `"SUMMARY"`, not `2`. + +Note: For compatibility reasons (the Prometheus golang client doesn't use the V2 +protobuf API), this change uses `jsonpb` and not `protojson`. + +## 0.37.2 + +This is a bugfix release addressing two bugs: + +1. A regression introduced in the formatter fix for CVE-2022-23628. +2. Support indices for appending to an array, conforming to JSON Patch (RFC6902) + for patch bundles. + +### Miscellaneous + +- format: generated vars may have a proper location +- storage: Support index for array appends + +## 0.37.1 + +This is a bug fix release that reverts the github.com/prometheus/client_golang +upgrade in v0.37.0. The upgrade exposed an issue in the serialization of Go +runtime metrics in the Status API +([#4319](https://github.com/open-policy-agent/opa/issues/4319)). + +### Miscellaneous + +- Revert "build(deps): bump github.com/prometheus/client_golang (#4307)" + +## 0.37.0 + +This release contains a number of fixes and enhancements. + +This is the first release that includes a binary and a docker image for +`linux/arm64`, `opa_linux_arm64_static` and `openpolicyagent/opa:0.37.0-static`. +Thanks to @ngraef for contributing the build changes necessary. + +### Strict Mode + +There have been numerous possible checks in the compiler that fall into this category: + +1. They would help avoid common mistakes; **but** +2. Introducing them would potentially break some uncommon, but legitimate use. + +We've thus far refrained from introducing them. **Now**, a new "strict mode" +allows you to opt-in to these checks, and we encourage you to do so! + +With *OPA 1.0*, they will become the new default behaviour. + +For more details, [see the docs on _Compiler Strict Mode_](https://www.openpolicyagent.org/docs/v0.37.0/strict/). + +### Delta Bundles + +Delta bundles provide a more efficient way to make data changes by containing +*patches to data* instead of snapshots. +Using them together with [HTTP Long Polling](https://www.openpolicyagent.org/docs/v0.37.0/management-bundles/#http-long-polling), +you can propagate small changes to bundles without waiting for polling delays. + +See [the documentation](https://www.openpolicyagent.org/docs/v0.37.0/management-bundles/#delta-bundles) +for more details. 
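+
+As a rough illustration of the idea (the delta bundle documentation linked above is the
+authoritative reference for the format), a delta bundle ships a patch file describing operations
+on the data document instead of a full snapshot, along these lines:
+
+```
+{
+  "data": [
+    {
+      "op": "upsert",
+      "path": "/servers/ports",
+      "value": ["p1", "p2"]
+    }
+  ]
+}
+```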
+
+
+### Tooling and Runtime
+
+- Bundles bug fix: Roundtrip manifest before hashing to allow changing the manifest
+  and still using signature verification of bundles ([#4233](https://github.com/open-policy-agent/opa/issues/4233)),
+  reported by @CristianJena
+
+- The test runner now also supports custom builtins when invoked through the Golang
+  interface (authored by @MIA-Deltat1995)
+
+- The compile package and the `opa build` command support a new output format: "plan".
+  It represents a _query plan_: the steps needed to evaluate a query (with policies).
+  The plan format is a JSON encoding of the intermediate representation (IR) used for
+  compiling queries and policies into Wasm.
+
+  When calling `opa build -t plan ...`, the plan can be found in `plan.json` at the
+  top-level directory of the resulting bundle.tar.gz.
+  [See the documentation for details](https://www.openpolicyagent.org/docs/v0.37.0/ir/).
+
+- Compiler+Bundles: Metadata to be added to a bundle's manifest can now be provided via `WithMetadata`
+  ([#4289](https://github.com/open-policy-agent/opa/issues/4289)), authored by @marensws, reported by @johanneslarsson
+- Plugins: failures in auth plugin resolution are now surfaced instead of causing a panic, authored by @jcchavezs
+- Plugins: Fix error when initializing empty decision logging or status plugin ([#4291](https://github.com/open-policy-agent/opa/issues/4291))
+- Bundles: Persisted bundle activation failures are treated like failures with
+  non-persisted bundles ([#3840](https://github.com/open-policy-agent/opa/issues/3840)), reported by @dsoguet
+- Server: `http.send` caching now works in system policy `system.authz` ([#3946](https://github.com/open-policy-agent/opa/issues/3946)),
+  reported by @amrap030.
+- Runtime: Apply credentials masking on `opa.runtime().config` ([#4159](https://github.com/open-policy-agent/opa/issues/4159))
+- `opa test`: remove deprecated code for `--show-failure-line` (`-l`), authored by @damienjburks
+- `opa eval`: add description to all output formats
+- `opa inspect`: unhide command for [bundle inspection](https://www.openpolicyagent.org/docs/v0.37.0/cli/#opa-inspect)
+
+### Rego and Topdown
+
+Built-in function enhancements and fixes:
+
+- `object.union_n`: New built-in for creating the union of more than two objects ([#4012](https://github.com/open-policy-agent/opa/issues/4012)),
+  reported by @eliw00d
+- `graph.reachable_paths`: New built-in to calculate the set of reachable paths in a graph (authored by @justinlindh-wf)
+- `indexof_n`: New built-in function to get all the indexes of a specific substring (or character) from a string (authored by @shuheiktgw)
+- `indexof`: Improved performance (authored by @shuheiktgw)
+- `object.get`: Support nested key array for deeper lookups with default (authored by @charlieegan3)
+- `json.is_valid`: Use Golang's `json.Valid` to avoid unnecessary allocations (authored by @kristiansvalland)
+
+Strict-mode features:
+
+- Add _duplicate imports_ check ([#2698](https://github.com/open-policy-agent/opa/issues/2698)) reported by @mikol
+- _Deprecate_ `any()` and `all()` built-in functions ([#2437](https://github.com/open-policy-agent/opa/issues/2437))
+- Make `input` and `data` reserved keywords ([#2600](https://github.com/open-policy-agent/opa/issues/2600)) reported by @jpeach
+- Add _unused local assignment_ check ([#2514](https://github.com/open-policy-agent/opa/issues/2514))
+
+
+Miscellaneous fixes and enhancements:
+
+- `format`: don't group iterable when one has defaulted location
+- `topdown`: ability to retrieve input and plug bindings in the `Event`, authored by @istalker2
+- `print()` built-in: fix bug when used with `with` modifier and a function call value ([#4227](https://github.com/open-policy-agent/opa/issues/4227))
+- `ast`: don't error when future keyword import is redundant during parsing
+
+### Documentation
+
+- A [new "CLI" docs section](https://www.openpolicyagent.org/docs/v0.37.0/cli/) describes the various
+  OPA CLI commands and their arguments ([#3915](https://github.com/open-policy-agent/opa/issues/3915))
+- Policy Testing: Add reference to rule indexing in the context of test code coverage
+  ([#4170](https://github.com/open-policy-agent/opa/issues/4170)), reported by @ekcs
+- Management: Add hint that S3 regional endpoint should be used with bundles (authored by @danoliver1)
+- Many broken links were fixed, thanks to @phelewski
+- Fix rendering of details: add detail-tab for collapsible markdown (authored by @bugg123)
+
+### WebAssembly
+
+- Add native support for `json.is_valid` built-in function
+  ([#4140](https://github.com/open-policy-agent/opa/issues/4140)), authored by @kristiansvalland
+- Dependencies: bump wasmtime-go from 0.32.0 to 0.33.1
+
+### Miscellaneous
+
+- Publish multi-arch image manifest lists including linux/arm64 ([#2233](https://github.com/open-policy-agent/opa/issues/2233)),
+  authored by @ngraef, reported by @povilasv
+- `logging`: Remove logger `GetFields` function ([#4114](https://github.com/open-policy-agent/opa/issues/4114)),
+  authored by @viovanov
+- Website: add versioned docs for latest version, so when 0.37.0 is released, both
+  https://www.openpolicyagent.org/docs/v0.37.0/ and https://www.openpolicyagent.org/docs/latest
+  contain docs, and 0.37.0 can already be used for stable links to versioned docs pages.
+- Community: Initial draft of the community badges program
+- `make test`: fix "too many open files" issue on macOS
+- Various dependency bumps
+
+## 0.36.1
+
+This release includes a number of documentation fixes.
+It also includes the experimental binary for darwin/arm64.
+
+There are no code changes.
+
+### Documentation
+
+- OpenTelemetry: fix configuration example, authored by @rvalkenaers
+- Configuration: fix typo for `tls-cert-refresh-period`, authored by @mattmahn
+- SSH and Sudo authorization: Add missing filename
+- Integration: fix example policy
+
+### Release
+
+- Build darwin/arm64 in post tag workflow
+
+## 0.36.0
+
+This release contains a number of fixes and enhancements.
+
+### OpenTelemetry and opa exec
+
+This release adds OpenTelemetry support to OPA. This makes it possible to emit spans to an OpenTelemetry collector via
+gRPC on both incoming and outgoing (i.e. `http.send`) calls in the server. See the updated docs on
+[monitoring](https://www.openpolicyagent.org/docs/latest/monitoring/) for more information and configuration options
+([#1469](https://github.com/open-policy-agent/opa/issues/1469)) authored by @[rvalkenaers](https://github.com/rvalkenaers)
+
+This release also adds a new `opa exec` command for doing one-off evaluations of policy against input similar to
+`opa eval`, but using the full capabilities of the server (config file, plugins, etc). This is particularly useful in
+contexts such as CI/CD or when enforcing policy for infrastructure as code, where one might want to run OPA with remote
+bundles and decision logs but without having a running server. See the updated docs on
+[Terraform](https://www.openpolicyagent.org/docs/latest/terraform/) for an example use case.
+([#3525](https://github.com/open-policy-agent/opa/issues/3525))
+
+### Built-in Functions
+
+- Four new functions for working with HMAC (`crypto.hmac.md5`, `crypto.hmac.sha1`, `crypto.hmac.sha256`, and `crypto.hmac.sha512`) were added ([#1740](https://github.com/open-policy-agent/opa/issues/1740)) reported by @[jshaw86](https://github.com/jshaw86)
+- `array.reverse(array)` and `strings.reverse(string)` were added for reversing arrays and strings ([#3736](https://github.com/open-policy-agent/opa/issues/3736)) authored by @[kristiansvalland](https://github.com/kristiansvalland) and @[olamiko](https://github.com/olamiko)
+- The `http.send` built-in function now uses a metric for counting inter-query cache hits ([#4023](https://github.com/open-policy-agent/opa/issues/4023)) authored by @[mirayadav](https://github.com/mirayadav)
+- An overflow issue with dates very far in the future has been fixed in the `time.*` built-in functions ([#4098](https://github.com/open-policy-agent/opa/issues/4098)) reported by @[morgante](https://github.com/morgante)
+
+### Tooling
+
+- A problem with future keyword import of `in` was fixed for `opa fmt` ([#4111](https://github.com/open-policy-agent/opa/issues/4111)) reported by @[keshavprasadms](https://github.com/keshavprasadms)
+- An issue with `opa fmt` when refs contained operators was fixed (authored by @[jaspervdj-luminal](https://github.com/jaspervdj-luminal))
+- Fix file renaming check in optimization using `opa build` (authored by @[davidmarne-wf](https://github.com/davidmarne-wf))
+- The `allow_net` capability was added, allowing setting limits on what hosts can be reached in built-ins like `http.send` and `net.lookup_ip_addr` ([#3665](https://github.com/open-policy-agent/opa/issues/3665))
+
+### Server
+
+- A new credential provider for AWS credential files was added ([#2786](https://github.com/open-policy-agent/opa/issues/2786)) reported by @[rgueldem](https://github.com/rgueldem)
+- The new `--tls-cert-refresh-period` flag can now be provided to `opa run`. If used with a positive duration, such as "5m" (5 minutes),
+  "24h", etc., the server will track the certificate and key files' contents.
When their content changes, the certificates will be
+  reloaded ([#2500](https://github.com/open-policy-agent/opa/issues/2500)) reported by @[patoarvizu](https://github.com/patoarvizu)
+- A new `v1/status` endpoint was added, providing the same data as the status plugin would send to a remote endpoint ([#4089](https://github.com/open-policy-agent/opa/issues/4089))
+- The HTTP router of OPA is now exposed to the plugin manager ([#2777](https://github.com/open-policy-agent/opa/issues/2777)) authored by @[bhoriuchi](https://github.com/bhoriuchi) reported by @[mneil](https://github.com/mneil)
+- Calling `print` now works in decision masking policies
+- An unintended switch between long/regular polling on 304 HTTP status was fixed ([#3923](https://github.com/open-policy-agent/opa/issues/3923)) authored by @[floriangasc](https://github.com/floriangasc)
+- The error message about prohibited config in the discovery plugin has been improved
+- The discovery plugin no longer panics in Trigger() if the downloader is nil
+- The bundle plugin now ignores service errors for file:// resources
+- The bundle plugin file loader was updated to support directories
+- A timer was added to the downloader's HTTP requests
+- The `requested_by` field in the logging plugin is now optional
+
+### Rego
+
+- The error message raised when using `-` with a number and a set is now more specific (as opposed to the correct usage with two sets, or two numbers) ([#1643](https://github.com/open-policy-agent/opa/issues/1643))
+- Fixed an edge case when using print and arrays in unification ([#4078](https://github.com/open-policy-agent/opa/issues/4078))
+- Improved performance of some array operations by caching an array's groundness bit ([#3679](https://github.com/open-policy-agent/opa/issues/3679))
+- ⚠️ Stricter check of arity in undefined function stage ([#4054](https://github.com/open-policy-agent/opa/issues/4054)).
+  This change will fail evaluation in some unusual cases where it previously would succeed, but these policies should be very uncommon.
+
+  An example policy that previously would succeed but no longer will (wrong arity):
+
+```rego
+package policy
+
+default p = false
+p {
+    x := is_blue()
+    input.bar[x]
+}
+
+is_blue(fruit) = y { # doesn't use fruit
+    y := input.foo
+}
+```
+
+### SDK
+
+- The `opa.runtime()` built-in is now made available to the SDK ([#4050](https://github.com/open-policy-agent/opa/issues/4050)) authored by @[oren-zohar](https://github.com/oren-zohar) and @[cmschuetz](https://github.com/cmschuetz)
+- Plugins are now exposed on the SDK object
+- The SDK now supports graceful shutdown ([#3980](https://github.com/open-policy-agent/opa/issues/3980)) reported by @[brianchhun-chime](https://github.com/brianchhun-chime)
+- `print` output is now sent to the configured logger
+
+### Website and Documentation
+
+- All pages in the docs now have a feedback button ([#3664](https://github.com/open-policy-agent/opa/issues/3664)) authored by @[alan-ma](https://github.com/alan-ma)
+- The Kafka docs have been updated to use the new Kafka plugin, and to use the OPA management APIs
+- The Terraform tutorial was updated to use `opa exec` ([#3965](https://github.com/open-policy-agent/opa/issues/3965))
+- The docs on Contributing as well as the Vendor Guidelines have been updated
+- The term "whitelist" has been replaced by "allowlist" across the docs
+- A simple destructuring assignment example was added to the docs
+- The docs have been reviewed on the use of assignment, equality and comparison operators, to make sure they follow best practice
+
+### CI
+
+- SHA256 checksums of CI builds now published to release directory ([#3448](https://github.com/open-policy-agent/opa/issues/3448)) authored by @[johanneslarsson](https://github.com/johanneslarsson) reported by @[raesene](https://github.com/raesene)
+- golangci-lint upgraded to v1.43.0 (authored by @[shuheiktgw](https://github.com/shuheiktgw))
+- The build now creates an executable for darwin/arm64. This should work as expected, but is currently not tested in the CI pipeline like the other binaries
+- PRs targeting the [ecosystem](https://www.openpolicyagent.org/docs/latest/ecosystem/) page are now checked for mistakes using Rego policies
+
 ## 0.35.0
 
 This release contains a number of fixes and enhancements.
@@ -1370,7 +3688,7 @@ more information see https://openpolicyagent.org/docs/latest/privacy/.
 
 #### New `opa build` command
 
 The `opa build` command can now be used to package OPA policy and data files
-into [bundles](https://www.openpolicyagent.org/docs/latest/management/#bundles)
+into [bundles](https://www.openpolicyagent.org/docs/latest/management-bundles)
 that can be easily distributed via HTTP. See `opa build --help` for details.
 This change is backwards incompatible. If you were previously relying on `opa
 build` to compile policies to wasm, you can still do so:
@@ -2305,7 +4623,7 @@ pass `"force_json_decode": true` as in the `http.send` parameters.
 
 * This release adds support for scoping bundles to specific roots
   under `data`. This allows bundles to be used in conjunction with
   sidecars like `kube-mgmt` that load local data and policy into
-  OPA. See the [Bundles](https://www.openpolicyagent.org/docs/bundles.html)
+  OPA. See the [Bundles](https://www.openpolicyagent.org/docs/latest/management-bundles)
   page for more details.
* This release includes a small but backwards incompatible change to
diff --git a/COMMUNITY_BADGES.md b/COMMUNITY_BADGES.md
new file mode 100644
index 0000000000..9ef99bd74e
--- /dev/null
+++ b/COMMUNITY_BADGES.md
@@ -0,0 +1,67 @@
+# OPA Community Badges
+
+Hello OPA Community Members! We thank you for taking the time to be a part of the community and for all of the work you’ve put in. As thanks, the OPA community now offers badges to recognize your contributions to the community.
+
+As of right now, these badges are simple tokens of appreciation that we will proudly display in the OPA GitHub repository. As the program develops, we will provide digital badges and certificates for members to display on their LinkedIn and Twitter accounts.
+
+## Contributor Badge
+
+We appreciate contributions to the OPA project, no matter how big or small. So if you’ve contributed a dozen lines of code or corrected a typo in the docs, let us know!
+
+### Requirements
+
+- 1 Merged PR
+
+## Spokesperson Badge
+
+For community members that like to create OPA content, share your work to receive the OPA Spokesperson Badge.
+
+### Requirements
+
+One or more pieces of publicly posted content such as any of the following (available for non-vendor-related content only):
+
+- Conference Talk
+- Meetup presentation
+- Online Video Recording
+- Technical blog or OPA tutorial
+
+## Super User Badge
+
+This badge is for members that know the value of running OPA in production. Add your organization to the ADOPTERS.md file, then share a quick snippet about how OPA is used in your production environment to earn this badge.
+
+### Requirements
+
+- Deploy or Support OPA in a production environment
+- Add your organization to the ADOPTERS.md file
+
+### Benefits
+
+- Access to a private OPA Super Users Slack Channel
+- Support writing an OPA CFP to grow your portfolio
+
+## Community Champion Badge
+
+Our OPA Community Champion badges are for active members that want to create a measurable impact in the community. Apply for this badge if you believe in the OPA mission to unify policy authorization and wish to help the OPA community grow and expand.
+
+### Requirements
+
+Actively help and support the OPA community. Some places we will look for contributions are:
+
+- Slack
+- GitHub Discussions
+- Stack Overflow
+- Meetups
+
+Being a champion in the OPA community can take many shapes and forms; as such, we do not want to limit this award to a static number of posts or comments. Instead, we are looking for members who take an active role in helping to build the community, or who want to take on a more active role.
+
+### Benefits
+
+- Access to a private OPA Champions Slack Channel
+- Help from the OPA team to assist with community issues
+- Increased support for your OPA projects
+
+## Application Process
+
+To apply for any of the OPA badges, fill out the form at: https://airtable.com/shrrd45fRi7yO2k6i
+
+For any questions, please reach out to peteroneilljr@styra.com or @peteroneilljr in [Slack](https://slack.openpolicyagent.org/).
diff --git a/Dockerfile b/Dockerfile
index 097f95c752..667b7286f2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,20 +6,30 @@ ARG BASE
 
 FROM ${BASE}
 
-# Any non-zero number will do, and unfortunately a named user will not, as k8s
-# pod securityContext runAsNonRoot can't resolve the user ID:
-# https://github.com/kubernetes/kubernetes/issues/40958. Make root (uid 0) when
-# not specified.
-ARG USER=0
+LABEL org.opencontainers.image.authors="Torin Sandall <torinsandall@gmail.com>"
+LABEL org.opencontainers.image.source="https://github.com/open-policy-agent/opa"
+
-MAINTAINER Torin Sandall <torinsandall@gmail.com>
+# Temporarily allow us to identify whether running from within an official
+# Docker image with a "rootless" tag, so that we may print a warning that this image tag
+# will not be published after 0.50.0. Remove after 0.50.0 release.
+ARG OPA_DOCKER_IMAGE_TAG
+ENV OPA_DOCKER_IMAGE_TAG=${OPA_DOCKER_IMAGE_TAG}
 
-# Hack.. https://github.com/moby/moby/issues/37965
-# _Something_ needs to be between the two COPY steps.
+# Any non-zero number will do, and unfortunately a named user will not, as k8s
+# pod securityContext runAsNonRoot can't resolve the user ID:
+# https://github.com/kubernetes/kubernetes/issues/40958.
+ARG USER=1000:1000
 USER ${USER}
 
-ARG BIN=./opa_linux_amd64
-COPY ${BIN} /opa
+# TARGETOS and TARGETARCH are automatic platform args injected by BuildKit
+# https://docs.docker.com/engine/reference/builder/#automatic-platform-args-in-the-global-scope
+ARG TARGETOS
+ARG TARGETARCH
+ARG BIN_DIR=.
+ARG BIN_SUFFIX=
+COPY ${BIN_DIR}/opa_${TARGETOS}_${TARGETARCH}${BIN_SUFFIX} /opa
+ENV PATH=${PATH}:/
 
 ENTRYPOINT ["/opa"]
 CMD ["run"]
diff --git a/MAINTAINERS.md b/MAINTAINERS.md
index 3e7c5ef2e6..d773a2d044 100644
--- a/MAINTAINERS.md
+++ b/MAINTAINERS.md
@@ -4,15 +4,20 @@ The following table lists OPA project maintainers and areas of expertise in alph
 
 | Name | GitHub | Email | Organization | Repositories/Area of Expertise | Added/Renewed On |
 | --- | --- | --- | --- | --- | --- |
-| Ash Narkar | @ashutosh-narkar | anarkar4387@gmail.com | Styra | opa, opa-envoy-plugin | 2021-03-29 |
-| Craig Tabita | @ctab | ctab@google.com | Google | gatekeeper, gatekeeper-library, cert-controller | 2021-03-29 |
-| Max Smythe | @maxsmythe | smythe@google.com | Google | frameworks/constraints, gatekeeper, gatekeeper-library, cert-controller | 2021-03-29 |
+| Andrew Peabody | @apeabody | andrewpeabody@google.com | Google | gatekeeper-library | 2023-03-31 |
+| Ash Narkar | @ashutosh-narkar | anarkar4387@gmail.com | Styra | opa, opa-envoy-plugin | 2023-03-31 |
+| Max Smythe | @maxsmythe | smythe@google.com | Google | frameworks/constraints, gatekeeper, gatekeeper-library, cert-controller | 2023-03-31 |
+| Nilekh Chaudhari | @nilekhc | nilekhc@gmail.com | Microsoft | gatekeeper-library | 2023-03-31 |
 | Oren Shomron | @shomron | shomron@gmail.com | VMware | frameworks/constraints, gatekeeper, gatekeeper-library, cert-controller | 2020-11-13 |
-| Rita Zhang | @ritazh | rita.z.zhang@gmail.com | Microsoft | frameworks/constraints, gatekeeper, gatekeeper-library, cert-controller | 2021-03-29 |
-| Sertaç Özercan | @sozercan | sozercan@gmail.com | Microsoft | gatekeeper, gatekeeper-library, cert-controller | 2021-03-29 |
-| Tim Hinrichs | @timothyhinrichs | timothy.l.hinrichs@gmail.com | Styra | all repositories | 2021-03-29 |
-| Torin Sandall | @tsandall | torinsandall@gmail.com | Styra | all repositories | 2021-03-29 |
+| Rita Zhang | @ritazh | rita.z.zhang@gmail.com | Microsoft | frameworks/constraints, gatekeeper, gatekeeper-library, cert-controller | 2023-03-31 |
+| Sertaç Özercan | @sozercan | sozercan@gmail.com | Microsoft | gatekeeper, gatekeeper-library, cert-controller, gatekeeper-external-data-provider | 2023-03-31 |
+| Stephan Renatus | @srenatus | stephan@styra.com | Styra | opa | 2023-03-31 |
+| Tim Hinrichs | @timothyhinrichs | timothy.l.hinrichs@gmail.com | Styra | all repositories | 2023-03-31 |
+| Torin Sandall | @tsandall |
torinsandall@gmail.com | Styra | all repositories | 2023-03-31 | ## Emeritus +* [Craig Tabita](https://github.com/ctab) +* [Ernest Wong](https://github.com/chewong) * [Patrick East](https://github.com/patrick-east) +* [Will Beason](https://github.com/willbeason) diff --git a/Makefile b/Makefile index 7cb3e774b1..843a413601 100644 --- a/Makefile +++ b/Makefile @@ -14,18 +14,22 @@ WASM_ENABLED ?= 1 GO := CGO_ENABLED=$(CGO_ENABLED) GOFLAGS="-buildmode=exe" go GO_TEST_TIMEOUT := -timeout 30m +GOVERSION ?= $(shell cat ./.go-version) +GOARCH := $(shell go env GOARCH) +GOOS := $(shell go env GOOS) + +ifeq ($(GOOS)/$(GOARCH),darwin/arm64) +WASM_ENABLED=0 +endif + GO_TAGS := -tags= ifeq ($(WASM_ENABLED),1) GO_TAGS = -tags=opa_wasm endif -GOVERSION ?= $(shell cat ./.go-version) -GOARCH := $(shell go env GOARCH) -GOOS := $(shell go env GOOS) - -GOLANGCI_LINT_VERSION := v1.40.1 +GOLANGCI_LINT_VERSION := v1.51.0 -DOCKER_RUNNING := $(shell docker ps >/dev/null 2>&1 && echo 1 || echo 0) +DOCKER_RUNNING ?= $(shell docker ps >/dev/null 2>&1 && echo 1 || echo 0) # We use root because the windows build, invoked through the ci-go-build-windows # target, installs the gcc mingw32 cross-compiler. @@ -41,16 +45,21 @@ endif DOCKER := docker +# BuildKit is required for automatic platform arg injection (see Dockerfile) +export DOCKER_BUILDKIT := 1 + +# Supported platforms to include in image manifest lists +DOCKER_PLATFORMS := linux/amd64 +DOCKER_PLATFORMS_STATIC := linux/amd64,linux/arm64 + BIN := opa_$(GOOS)_$(GOARCH) # Optional external configuration useful for forks of OPA DOCKER_IMAGE ?= openpolicyagent/opa S3_RELEASE_BUCKET ?= opa-releases -FUZZ_TIME ?= 3600 # 1hr +FUZZ_TIME ?= 1h TELEMETRY_URL ?= #Default empty -BUILD_COMMIT := $(shell ./build/get-build-commit.sh) -BUILD_TIMESTAMP := $(shell ./build/get-build-timestamp.sh) BUILD_HOSTNAME := $(shell ./build/get-build-hostname.sh) RELEASE_BUILD_IMAGE := golang:$(GOVERSION) @@ -62,9 +71,6 @@ TELEMETRY_FLAG := -X github.com/open-policy-agent/opa/internal/report.ExternalSe endif LDFLAGS := "$(TELEMETRY_FLAG) \ - -X github.com/open-policy-agent/opa/version.Version=$(VERSION) \ - -X github.com/open-policy-agent/opa/version.Vcs=$(BUILD_COMMIT) \ - -X github.com/open-policy-agent/opa/version.Timestamp=$(BUILD_TIMESTAMP) \ -X github.com/open-policy-agent/opa/version.Hostname=$(BUILD_HOSTNAME)" @@ -155,7 +161,7 @@ clean: wasm-lib-clean .PHONY: fuzz fuzz: - $(MAKE) -C ./build/fuzzer all + go test ./ast -fuzz FuzzParseStatementsAndCompileModules -fuzztime ${FUZZ_TIME} -v -run '^$$' ###################################################### # @@ -238,20 +244,20 @@ CI_GOLANG_DOCKER_MAKE := $(DOCKER) run \ -v $(PWD):/src \ -w /src \ -e GOCACHE=/src/.go/cache \ + -e GOARCH=$(GOARCH) \ -e CGO_ENABLED=$(CGO_ENABLED) \ -e WASM_ENABLED=$(WASM_ENABLED) \ -e FUZZ_TIME=$(FUZZ_TIME) \ -e TELEMETRY_URL=$(TELEMETRY_URL) \ - golang:$(GOVERSION) \ - make + golang:$(GOVERSION) .PHONY: ci-go-% ci-go-%: generate - $(CI_GOLANG_DOCKER_MAKE) $* + $(CI_GOLANG_DOCKER_MAKE) /bin/bash -c "git config --system --add safe.directory /src && make $*" .PHONY: ci-release-test ci-release-test: generate - $(CI_GOLANG_DOCKER_MAKE) test perf wasm-sdk-e2e-test check + $(CI_GOLANG_DOCKER_MAKE) make test perf wasm-sdk-e2e-test check .PHONY: ci-check-working-copy ci-check-working-copy: generate @@ -261,7 +267,7 @@ ci-check-working-copy: generate ci-wasm: wasm-test .PHONY: ci-build-linux -ci-build-linux: ensure-release-dir +ci-build-linux: ensure-release-dir ensure-linux-toolchain @$(MAKE) build GOOS=linux chmod +x 
opa_linux_$(GOARCH) mv opa_linux_$(GOARCH) $(RELEASE_DIR)/ @@ -301,84 +307,134 @@ ci-build-windows: ensure-release-dir ensure-release-dir: mkdir -p $(RELEASE_DIR) +.PHONY: ensure-executable-bin +ensure-executable-bin: + find $(RELEASE_DIR) -type f ! -name "*.sha256" | xargs chmod +x + +.PHONY: ensure-linux-toolchain +ensure-linux-toolchain: +ifeq ($(CGO_ENABLED),1) + $(eval export CC = $(shell GOARCH=$(GOARCH) build/ensure-linux-toolchain.sh)) +else + @echo "CGO_ENABLED=$(CGO_ENABLED). No need to check gcc toolchain." +endif + .PHONY: build-all-platforms build-all-platforms: ci-build-linux ci-build-linux-static ci-build-darwin ci-build-darwin-arm64-static ci-build-windows .PHONY: image-quick -image-quick: - chmod +x $(RELEASE_DIR)/opa_linux_$(GOARCH)* +image-quick: image-quick-$(GOARCH) + +# % = arch +.PHONY: image-quick-% +image-quick-%: ensure-executable-bin +ifneq ($(GOARCH),arm64) # build only static images for arm64 $(DOCKER) build \ -t $(DOCKER_IMAGE):$(VERSION) \ - --build-arg BASE=gcr.io/distroless/cc \ - --build-arg BIN=$(RELEASE_DIR)/opa_linux_$(GOARCH) \ + --build-arg BASE=cgr.dev/chainguard/cc-dynamic \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --platform linux/$* \ . $(DOCKER) build \ -t $(DOCKER_IMAGE):$(VERSION)-debug \ - --build-arg BASE=gcr.io/distroless/cc:debug \ - --build-arg BIN=$(RELEASE_DIR)/opa_linux_$(GOARCH) \ + --build-arg BASE=cgr.dev/chainguard/cc-dynamic:latest-dev \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --platform linux/$* \ . $(DOCKER) build \ -t $(DOCKER_IMAGE):$(VERSION)-rootless \ - --build-arg USER=1000 \ - --build-arg BASE=gcr.io/distroless/cc \ - --build-arg BIN=$(RELEASE_DIR)/opa_linux_$(GOARCH) \ + --build-arg OPA_DOCKER_IMAGE_TAG=rootless \ + --build-arg BASE=cgr.dev/chainguard/cc-dynamic:latest \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --platform linux/$* \ . +endif $(DOCKER) build \ -t $(DOCKER_IMAGE):$(VERSION)-static \ - --build-arg BASE=gcr.io/distroless/static \ - --build-arg BIN=$(RELEASE_DIR)/opa_linux_$(GOARCH)_static \ + --build-arg BASE=cgr.dev/chainguard/static:latest \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --build-arg BIN_SUFFIX=_static \ + --platform linux/$* \ + . + + $(DOCKER) build \ + -t $(DOCKER_IMAGE):$(VERSION)-static-debug \ + --build-arg BASE=cgr.dev/chainguard/busybox:latest-glibc \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --build-arg BIN_SUFFIX=_static \ + --platform linux/$* \ + . + +# % = base tag +.PHONY: push-manifest-list-% +push-manifest-list-%: ensure-executable-bin + $(DOCKER) buildx build \ + --tag $(DOCKER_IMAGE):$* \ + --build-arg BASE=cgr.dev/chainguard/cc-dynamic:latest \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --platform $(DOCKER_PLATFORMS) \ + --push \ + . + # TODO: update busybox shell debug images to image without openssl + $(DOCKER) buildx build \ + --tag $(DOCKER_IMAGE):$*-debug \ + --build-arg BASE=cgr.dev/chainguard/cc-dynamic:latest-dev \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --platform $(DOCKER_PLATFORMS) \ + --push \ + . + $(DOCKER) buildx build \ + --tag $(DOCKER_IMAGE):$*-rootless \ + --build-arg OPA_DOCKER_IMAGE_TAG=rootless \ + --build-arg BASE=cgr.dev/chainguard/cc-dynamic:latest \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --platform $(DOCKER_PLATFORMS) \ + --push \ + . + + $(DOCKER) buildx build \ + --tag $(DOCKER_IMAGE):$*-static \ + --build-arg BASE=cgr.dev/chainguard/static:latest \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --build-arg BIN_SUFFIX=_static \ + --platform $(DOCKER_PLATFORMS_STATIC) \ + --push \ + . 
+ + $(DOCKER) buildx build \ + --tag $(DOCKER_IMAGE):$*-static-debug \ + --build-arg BASE=cgr.dev/chainguard/busybox:latest-glibc \ + --build-arg BIN_DIR=$(RELEASE_DIR) \ + --build-arg BIN_SUFFIX=_static \ + --platform $(DOCKER_PLATFORMS_STATIC) \ + --push \ . .PHONY: ci-image-smoke-test -ci-image-smoke-test: image-quick - $(DOCKER) run $(DOCKER_IMAGE):$(VERSION) version - $(DOCKER) run $(DOCKER_IMAGE):$(VERSION)-debug version - $(DOCKER) run $(DOCKER_IMAGE):$(VERSION)-rootless version - $(DOCKER) run $(DOCKER_IMAGE):$(VERSION)-static version +ci-image-smoke-test: ci-image-smoke-test-$(GOARCH) + +# % = arch +.PHONY: ci-image-smoke-test-% +ci-image-smoke-test-%: image-quick-% +ifneq ($(GOARCH),arm64) # we build only static images for arm64 + $(DOCKER) run --platform linux/$* $(DOCKER_IMAGE):$(VERSION) version + $(DOCKER) run --platform linux/$* $(DOCKER_IMAGE):$(VERSION)-debug version + + $(DOCKER) image inspect $(DOCKER_IMAGE):$(VERSION) |\ + $(DOCKER) run --interactive --platform linux/$* $(DOCKER_IMAGE):$(VERSION) \ + eval --fail --format raw --stdin-input 'input[0].Config.User = "1000:1000"' +endif + $(DOCKER) run --platform linux/$* $(DOCKER_IMAGE):$(VERSION)-static version +# % = rego/wasm .PHONY: ci-binary-smoke-test-% ci-binary-smoke-test-%: chmod +x "$(RELEASE_DIR)/$(BINARY)" - "$(RELEASE_DIR)/$(BINARY)" eval -t "$*" 'time.now_ns()' - -.PHONY: push -push: - $(DOCKER) push $(DOCKER_IMAGE):$(VERSION) - $(DOCKER) push $(DOCKER_IMAGE):$(VERSION)-debug - $(DOCKER) push $(DOCKER_IMAGE):$(VERSION)-rootless - $(DOCKER) push $(DOCKER_IMAGE):$(VERSION)-static - -.PHONY: tag-latest -tag-latest: - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION) $(DOCKER_IMAGE):latest - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION)-debug $(DOCKER_IMAGE):latest-debug - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION)-rootless $(DOCKER_IMAGE):latest-rootless - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION)-static $(DOCKER_IMAGE):latest-static - -.PHONY: push-latest -push-latest: - $(DOCKER) push $(DOCKER_IMAGE):latest - $(DOCKER) push $(DOCKER_IMAGE):latest-debug - $(DOCKER) push $(DOCKER_IMAGE):latest-rootless - $(DOCKER) push $(DOCKER_IMAGE):latest-static + ./build/binary-smoke-test.sh "$(RELEASE_DIR)/$(BINARY)" "$*" .PHONY: push-binary-edge push-binary-edge: - aws s3 sync $(RELEASE_DIR) s3://$(S3_RELEASE_BUCKET)/edge/ --delete - -.PHONY: tag-edge -tag-edge: - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION) $(DOCKER_IMAGE):edge - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION)-debug $(DOCKER_IMAGE):edge-debug - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION)-rootless $(DOCKER_IMAGE):edge-rootless - $(DOCKER) tag $(DOCKER_IMAGE):$(VERSION)-static $(DOCKER_IMAGE):edge-static - -.PHONY: push-edge -push-edge: - $(DOCKER) push $(DOCKER_IMAGE):edge - $(DOCKER) push $(DOCKER_IMAGE):edge-debug - $(DOCKER) push $(DOCKER_IMAGE):edge-rootless - $(DOCKER) push $(DOCKER_IMAGE):edge-static + aws s3 sync $(RELEASE_DIR) s3://$(S3_RELEASE_BUCKET)/edge/ --no-progress --region us-west-1 .PHONY: docker-login docker-login: @@ -386,14 +442,14 @@ docker-login: @echo ${DOCKER_PASSWORD} | $(DOCKER) login -u ${DOCKER_USER} --password-stdin .PHONY: push-image -push-image: docker-login image-quick push +push-image: docker-login push-manifest-list-$(VERSION) .PHONY: push-wasm-builder-image push-wasm-builder-image: docker-login $(MAKE) -C wasm push-builder .PHONY: deploy-ci -deploy-ci: push-image tag-edge push-edge push-binary-edge +deploy-ci: push-image push-manifest-list-edge push-binary-edge .PHONY: release-ci # Don't tag and push "latest" image tags if the version is a release 
candidate or a bugfix branch @@ -401,18 +457,18 @@ deploy-ci: push-image tag-edge push-edge push-binary-edge ifneq (,$(or $(findstring rc,$(VERSION)), $(findstring release-,$(shell git branch --contains HEAD)))) release-ci: push-image else -release-ci: push-image tag-latest push-latest +release-ci: push-image push-manifest-list-latest endif .PHONY: netlify-prod -netlify-prod: clean docs-clean build docs-generate docs-production-build +netlify-prod: clean docs-clean build docs-production-build .PHONY: netlify-preview -netlify-preview: clean docs-clean build docs-live-blocks-install-deps docs-live-blocks-test docs-generate docs-preview-build +netlify-preview: clean docs-clean build docs-live-blocks-install-deps docs-live-blocks-test docs-dev-generate docs-preview-build +# Kept for compatibility. Use `make fuzz` instead. .PHONY: check-fuzz -check-fuzz: - ./build/check-fuzz.sh $(FUZZ_TIME) +check-fuzz: fuzz # GOPRIVATE=* causes go to fetch all dependencies from their corresponding VCS # source, not through the golang-provided proxy services. We're cleaning out @@ -427,7 +483,7 @@ check-go-module: -e 'GOPRIVATE=*' \ --tmpfs /src/.go \ golang:$(GOVERSION) \ - go mod vendor -v + /bin/bash -c "git config --system --add safe.directory /src && go mod vendor -v" ###################################################### # @@ -444,14 +500,14 @@ endif -e GITHUB_TOKEN=$(GITHUB_TOKEN) \ -e LAST_VERSION=$(LAST_VERSION) \ -v $(PWD):/_src \ - python:2.7 \ + cmd.cat/make/git/go/python3/perl \ /_src/build/gen-release-patch.sh --version=$(VERSION) --source-url=/_src .PHONY: dev-patch dev-patch: @$(DOCKER) run $(DOCKER_FLAGS) \ -v $(PWD):/_src \ - python:2.7 \ + cmd.cat/make/git/go/python3/perl \ /_src/build/gen-dev-patch.sh --version=$(VERSION) --source-url=/_src # Deprecated targets. To be removed. diff --git a/README.md b/README.md index fadf0efc4e..8f2b788c87 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ![logo](./logo/logo-144x144.png) Open Policy Agent -[![Slack Status](http://slack.openpolicyagent.org/badge.svg)](https://slack.openpolicyagent.org) [![Build Status](https://github.com/open-policy-agent/opa/workflows/Post%20Merge/badge.svg?branch=main)](https://github.com/open-policy-agent/opa/actions) [![Go Report Card](https://goreportcard.com/badge/open-policy-agent/opa)](https://goreportcard.com/report/open-policy-agent/opa) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1768/badge)](https://bestpractices.coreinfrastructure.org/projects/1768) [![Netlify Status](https://api.netlify.com/api/v1/badges/4a0a092a-8741-4826-a28f-826d4a576cab/deploy-status)](https://app.netlify.com/sites/openpolicyagent/deploys) +[![Build Status](https://github.com/open-policy-agent/opa/workflows/Post%20Merge/badge.svg?branch=main)](https://github.com/open-policy-agent/opa/actions) [![Go Report Card](https://goreportcard.com/badge/open-policy-agent/opa)](https://goreportcard.com/report/open-policy-agent/opa) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1768/badge)](https://bestpractices.coreinfrastructure.org/projects/1768) [![Netlify Status](https://api.netlify.com/api/v1/badges/4a0a092a-8741-4826-a28f-826d4a576cab/deploy-status)](https://app.netlify.com/sites/openpolicyagent/deploys) Open Policy Agent (OPA) is an open source, general-purpose policy engine that enables unified, context-aware policy enforcement across the entire stack. 
@@ -9,7 +9,19 @@ OPA is proud to be a graduated project in the [Cloud Native Computing Foundation
 ## Want to connect with the community or get support for OPA?
 
 - Join the [OPA Slack](https://slack.openpolicyagent.org) for day-to-day conversations with the OPA community.
-- Need Support? Go to the [Discussions Board](https://github.com/open-policy-agent/feedback/discussions) to ask questions.
+- Need Support? Check out the [Community Discussions](https://github.com/orgs/open-policy-agent/discussions) to ask questions.
+
+## Join the OPA Office Hours
+
+**Every week at these times**
+ * 10:00 PT / 19:00 CET
+ * 01:30 PT / 10:30 CET
+
+These sessions are open format for community members. Come and ask about new features or the road map for the next release. You can use this time to get unblocked with your OPA deployments, learn more about the project, or to get more involved in the community.
+
+ * [Book a slot](https://calendly.com/styra-devrel/opa-office-hours)
+
+*Watch a replay of previous office hours on [YouTube](https://www.youtube.com/watch?v=TxFZPrbc9jk&list=PLW-2W4VHBA4RRhCHiiaKHZsDFx-FwH8VU)*
 
 ## What is this fork for?
 
@@ -26,32 +38,30 @@ git cherry-pick {gitsha of commit with your patches}
 
 - Go to [openpolicyagent.org](https://www.openpolicyagent.org) to get started with documentation and tutorials.
 - Browse [blog.openpolicyagent.org](https://blog.openpolicyagent.org) for news about OPA, community, policy and authorization.
+- Watch OPA's [YouTube](https://www.youtube.com/channel/UClDMRN5HlqD3di5MMf-SV4A) channel for past office hours and other content.
 - Try OPA with the [Rego Playground](https://play.openpolicyagent.org) to experiment with policies and share your work.
 - View the [OPA Roadmap](https://docs.google.com/presentation/d/16QV6gvLDOV3I0_guPC3_19g6jHkEg3X9xqMYgtoCKrs/edit?usp=sharing) to see a high-level snapshot of OPA features in-progress and planned.
 - Check out the [ADOPTERS.md](./ADOPTERS.md) file for a list of production adopters. Does your organization use OPA in production? Support the OPA project by submitting a PR to add your organization to the list with a short description of your OPA use cases!
 
-## Want to get OPA?
+## Want to download OPA?
 
 - [Docker Hub](https://hub.docker.com/r/openpolicyagent/opa/tags/) for Docker images.
 - [GitHub releases](https://github.com/open-policy-agent/opa/releases) for binary releases and changelogs.
 
 ## Want to integrate OPA?
 
-* See
+* See the high-level [Go SDK](https://www.openpolicyagent.org/docs/latest/integration/#integrating-with-the-go-sdk) or the low-level Go API
 [![GoDoc](https://godoc.org/github.com/open-policy-agent/opa?status.svg)](https://godoc.org/github.com/open-policy-agent/opa/rego)
 to integrate OPA with services written in Go.
 * See [REST API](https://www.openpolicyagent.org/docs/rest-api.html) to
 integrate OPA with services written in other languages.
+* See the [integration docs](https://www.openpolicyagent.org/docs/latest/integration/) for more options.
 
 ## Want to contribute to OPA?
 
 * Read the [Contributing Guide](https://www.openpolicyagent.org/docs/latest/contributing/) to learn how to make your first contribution.
-* Use [#development](https://openpolicyagent.slack.com/archives/C02L1TLPN59) in Slack to talk to the OPA maintainers and other contributors.
+* Use [#contributors](https://openpolicyagent.slack.com/archives/C02L1TLPN59) in Slack to talk to other contributors and OPA maintainers.
 * File a [GitHub Issue](https://github.com/open-policy-agent/opa/issues) to request features or report bugs.
-* Join the OPA bi-weekly meetings every other Tuesday at 10:00 (Pacific Timezone):
- * [Meeting Notes](https://docs.google.com/document/d/1v6l2gmkRKAn5UIg3V2QdeeCcXMElxsNzEzDkVlWDVg8/edit?usp=sharing)
- * [Zoom](https://zoom.us/j/97827947600)
- * [Calendar Invite](https://calendar.google.com/event?action=TEMPLATE&tmeid=MnRvb2M4amtldXBuZ2E1azY0MTJndjh0ODRfMjAxODA5MThUMTcwMDAwWiBzdHlyYS5jb21fY28zOXVzc3VobnE2amUzN2l2dHQyYmNiZGdAZw&tmsrc=styra.com_co39ussuhnq6je37ivtt2bcbdg%40group.calendar.google.com&scp=ALL)
## How does OPA work?
@@ -84,6 +94,8 @@ For concrete examples of how to integrate OPA with systems like [Kubernetes](htt
## Presentations
+- OPA maintainers talk @ KubeCon NA 2022: [video](https://www.youtube.com/watch?v=RMiovzGGCfI)
+- Open Policy Agent (OPA) Intro & Deep Dive @ KubeCon EU 2022: [video](https://www.youtube.com/watch?v=MhyQxIp1H58)
- Open Policy Agent Intro @ KubeCon EU 2021: [video](https://www.youtube.com/watch?v=2CgeiWkliaw)
- Using Open Policy Agent to Meet Evolving Policy Requirements @ KubeCon NA 2020: [video](https://www.youtube.com/watch?v=zVuM7F_BTyc)
- Applying Policy Throughout The Application Lifecycle with Open Policy Agent @ CloudNativeCon 2019: [video](https://www.youtube.com/watch?v=cXfsaE6RKfc)
diff --git a/SECURITY.md b/SECURITY.md
index a154a5e164..d90b7e962f 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -1,260 +1,5 @@
-# Security Release Process
+# Security Policy
-The Open Policy Agent (OPA) community has adopted this security disclosures and
-response policy to ensure we responsibly handle critical issues.
-
-## Product Security Team (PST)
-
-Security vulnerabilities should be handled quickly and sometimes privately. The primary goal of this
-process is to reduce the total time users are vulnerable to publicly known exploits.
-
-The Product Security Team (PST) is responsible for organizing the entire response including internal
-communication and external disclosure but will need help from relevant developers to successfully
-run this process.
-
-The initial Product Security Team will consist of all [maintainers](MAINTAINERS.md) in the private
-[open-policy-agent-security](https://groups.google.com/forum/#!forum/open-policy-agent-security) list. In the future we may
-decide to have a subset of maintainers work on security response given that this process is time
-consuming.
-
-## Disclosures
-
-### Private Disclosure Processes
-
-The OPA community asks that all suspected vulnerabilities be privately and responsibly disclosed
-via the [reporting policy](README.md#reporting-security-vulnerabilities).
-
-### Public Disclosure Processes
-
-If you know of a publicly disclosed security vulnerability please IMMEDIATELY email
-[open-policy-agent-security](https://groups.google.com/forum/#!forum/open-policy-agent-security) to inform the Product
-Security Team (PST) about the vulnerability so they may start the patch, release, and communication
-process.
-
-If possible the PST will ask the person making the public report if the issue can be handled via a
-private disclosure process (for example if the full exploit details have not yet been published). If
-the reporter denies the request for private disclosure, the PST will move swiftly with the fix and
-release process. In extreme cases GitHub can be asked to delete the issue but this generally isn't
-necessary and is unlikely to make a public disclosure less damaging.
- -## Patch, Release, and Public Communication - -For each vulnerability a member of the PST will volunteer to lead coordination with the "Fix Team" -and is responsible for sending disclosure emails to the rest of the community. This lead will be -referred to as the "Fix Lead." - -The role of Fix Lead should rotate round-robin across the PST. - -Note that given the current size of the OPA community it is likely that the PST is the same as -the "Fix team." (I.e., all maintainers). The PST may decide to bring in additional contributors -for added expertise depending on the area of the code that contains the vulnerability. - -All of the timelines below are suggestions and assume a private disclosure. The Fix Lead drives the -schedule using their best judgment based on severity and development time. If the Fix Lead is -dealing with a public disclosure all timelines become ASAP (assuming the vulnerability has a CVSS -score >= 4; see below). If the fix relies on another upstream project's disclosure timeline, that -will adjust the process as well. We will work with the upstream project to fit their timeline and -best protect our users. - -### Fix Team Organization - -These steps should be completed within the first 24 hours of disclosure. - -- The Fix Lead will work quickly to identify relevant engineers from the affected projects and - packages and CC those engineers into the disclosure thread. These selected developers are the Fix - Team. -- The Fix Lead will get the Fix Team access to private security repos to develop the fix. - -### Fix Development Process - -These steps should be completed within the 1-7 days of Disclosure. - -- The Fix Lead and the Fix Team will create a - [CVSS](https://www.first.org/cvss/specification-document) using the [CVSS - Calculator](https://www.first.org/cvss/calculator/3.0). The Fix Lead makes the final call on the - calculated CVSS; it is better to move quickly than make the CVSS perfect. -- The Fix Team will notify the Fix Lead that work on the fix branch is complete once there are LGTMs - on all commits in the private repo from one or more maintainers. - -If the CVSS score is under 4.0 ([a low severity -score](https://www.first.org/cvss/specification-document#i5)) the Fix Team can decide to slow the -release process down in the face of holidays, developer bandwidth, etc. These decisions must be -discussed on the open-policy-agent-security mailing list. - -### Fix Disclosure Process - -With the fix development underway, the Fix Lead needs to come up with an overall communication plan -for the wider community. This Disclosure process should begin after the Fix Team has developed a Fix -or mitigation so that a realistic timeline can be communicated to users. - -**Disclosure of Forthcoming Fix to Users** (Completed within 1-7 days of Disclosure) - -- The Fix Lead will email [open-policy-agent-announce@googlegroups.com](https://groups.google.com/forum/#!forum/open-policy-agent-announce) - informing users that a security vulnerability has been disclosed and that a fix will be made - available at YYYY-MM-DD HH:MM UTC in the future via this list. This time is the Release Date. -- The Fix Lead will include any mitigating steps users can take until a fix is available. - -The communication to users should be actionable. They should know when to block time to apply -patches, understand exact mitigation steps, etc. 
- -**Optional Fix Disclosure to Private Distributors List** (Completed within 1-14 days of Disclosure): - -- The Fix Lead will make a determination with the help of the Fix Team if an issue is critical enough - to require early disclosure to distributors. Generally this Private Distributor Disclosure process - should be reserved for remotely exploitable or privilege escalation issues. Otherwise, this - process can be skipped. -- The Fix Lead will email the patches to open-policy-agent-distributors-announce@googlegroups.com so - distributors can prepare builds to be available to users on the day of the issue's announcement. - Distributors should read about the [Private Distributors List](#private-distributors-list) to find - out the requirements for being added to this list. -- **What if a vendor breaks embargo?** The PST will assess the damage. The Fix Lead will make the - call to release earlier or continue with the plan. When in doubt push forward and go public ASAP. - -**Fix Release Day** (Completed within 1-21 days of Disclosure) - -- The maintainers will create a new patch release branch from the latest patch release tag + the fix - from the security branch. As a practical example if v1.5.3 is the latest patch release in opa.git - a new branch will be created called v1.5.4 which includes only patches required to fix the issue. -- The Fix Lead will cherry-pick the patches onto the main branch and all relevant release branches. - The Fix Team will LGTM and merge. Maintainers will merge these PRs as quickly as possible. Changes - shouldn't be made to the commits even for a typo in the CHANGELOG as this will change the git sha - of the commits leading to confusion and potentially conflicts as the fix is cherry-picked around - branches. -- The Fix Lead will request a CVE from [DWF](https://github.com/distributedweaknessfiling/DWF-Documentation) - and include the CVSS and release details. -- The Fix Lead will email open-policy-agent[-announce]@googlegroups.com now that everything is public - announcing the new releases, the CVE number, and the relevant merged PRs to get wide distribution - and user action. As much as possible this email should be actionable and include links on how to apply - the fix to user's environments; this can include links to external distributor documentation. -- The Fix Lead will remove the Fix Team from the private security repo. - -### Retrospective - -These steps should be completed 1-3 days after the Release Date. The retrospective process -[should be blameless](https://landing.google.com/sre/book/chapters/postmortem-culture.html). - -- The Fix Lead will send a retrospective of the process to open-policy-agent@googlegroups.com including - details on everyone involved, the timeline of the process, links to relevant PRs that introduced - the issue, if relevant, and any critiques of the response and release process. -- Maintainers and Fix Team are also encouraged to send their own feedback on the process to - open-policy-agent@googlegroups.com. Honest critique is the only way we are going to get good at this as a - community. - -## Private Distributors List - -This list is intended to be used primarily to provide actionable information to -multiple distribution vendors at once. This list is not intended for -individuals to find out about security issues. 
- -### Embargo Policy - -The information members receive on open-policy-agent-distributors-announce must not be made public, shared, nor -even hinted at anywhere beyond the need-to-know within your specific team except with the list's -explicit approval. This holds true until the public disclosure date/time that was agreed upon by the -list. Members of the list and others may not use the information for anything other than getting the -issue fixed for your respective distribution's users. - -Before any information from the list is shared with respective members of your team required to fix -said issue, they must agree to the same terms and only find out information on a need-to-know basis. - -In the unfortunate event you share the information beyond what is allowed by this policy, you _must_ -urgently inform the open-policy-agent-security@googlegroups.com mailing list of exactly what information leaked -and to whom. A retrospective will take place after the leak so we can assess how to not make the -same mistake in the future. - -If you continue to leak information and break the policy outlined here, you will be removed from the -list. - -### Contributing Back - -This is a team effort. As a member of the list you must carry some water. This -could be in the form of the following: - -**Technical** - -- Review and/or test the proposed patches and point out potential issues with - them (such as incomplete fixes for the originally reported issues, additional - issues you might notice, and newly introduced bugs), and inform the list of the - work done even if no issues were encountered. - -**Administrative** - -- Help draft emails to the public disclosure mailing list. -- Help with release notes. - -### Membership Criteria - -To be eligible for the open-policy-agent-distributors-announce mailing list, your -distribution should: - -1. Be an actively maintained distribution of OPA components OR offer OPA as a publicly - available service in which the product clearly states that it is built on top of OPA. E.g., - "SuperAwesomeLinuxDistro" which offers OPA pre-built packages OR - "SuperAwesomeCloudProvider's OPA as a Service (EaaS)". A cloud service that uses OPA for a - product but does not publicly say they are using OPA does not qualify. -2. Have a user base not limited to your own organization. -3. Have a publicly verifiable track record up to present day of fixing security - issues. -4. Not be a downstream or rebuild of another distribution. -5. Be a participant and active contributor in the community. -6. Accept the [Embargo Policy](#embargo-policy) that is outlined above. -7. Be willing to [contribute back](#contributing-back) as outlined above. -8. Have someone already on the list vouch for the person requesting membership - on behalf of your distribution. - -### Requesting to Join - -New membership requests are sent to open-policy-agent-security@googlegroups.com. - -In the body of your request please specify how you qualify and fulfill each -criterion listed in [Membership Criteria](#membership-criteria). - -Here is a pseudo example: - -``` -To: open-policy-agent-security@googlegroups.com -Subject: Seven-Corp Membership to open-policy-agent-distributors-announce - -Below are each criterion and why I think we, Seven-Corp, qualify. - -> 1. Be an actively maintained distribution of OPA components OR offer OPA as a publicly - available service in which the product clearly states that it is built on top of OPA. - -We distribute the "Seven" distribution of OPA [link]. 
We have been doing -this since 1999 before proxies were even cool. - -> 2. Have a user base not limited to your own organization. - -Our user base spans of the extensive "Seven" community. We have a slack and -GitHub repos and mailing lists where the community hangs out. [links] - -> 3. Have a publicly verifiable track record up to present day of fixing security - issues. - -We announce on our blog all upstream patches we apply to "Seven." [link to blog -posts] - -> 4. Not be a downstream or rebuild of another distribution. - -This does not apply, "Seven" is a unique snowflake distribution. - -> 5. Be a participant and active contributor in the community. - -Our members, Acidburn, Cereal, and ZeroCool are outstanding members and are well -known throughout the OPA community. Especially for their contributions -in hacking the Gibson. - -> 6. Accept the Embargo Policy that is outlined above. - -We accept. - -> 7. Be willing to contribute back as outlined above. - -We are definitely willing to help! - -> 8. Have someone already on the list vouch for the person requesting membership - on behalf of your distribution. - -CrashOverride will vouch for Acidburn joining the list on behalf of the "Seven" -distribution. -``` +Please refer to the [OPA Security Policy](https://openpolicyagent.org/security) +for details on how to report security issues, our disclosure policy, and how to +receive notifications about security issues. \ No newline at end of file diff --git a/ast/annotations.go b/ast/annotations.go new file mode 100644 index 0000000000..1134cc87c4 --- /dev/null +++ b/ast/annotations.go @@ -0,0 +1,941 @@ +// Copyright 2022 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "encoding/json" + "fmt" + "net/url" + "sort" + "strings" + + "github.com/open-policy-agent/opa/internal/deepcopy" + "github.com/open-policy-agent/opa/util" +) + +const ( + annotationScopePackage = "package" + annotationScopeImport = "import" + annotationScopeRule = "rule" + annotationScopeDocument = "document" + annotationScopeSubpackages = "subpackages" +) + +type ( + // Annotations represents metadata attached to other AST nodes such as rules. + Annotations struct { + Scope string `json:"scope"` + Title string `json:"title,omitempty"` + Entrypoint bool `json:"entrypoint,omitempty"` + Description string `json:"description,omitempty"` + Organizations []string `json:"organizations,omitempty"` + RelatedResources []*RelatedResourceAnnotation `json:"related_resources,omitempty"` + Authors []*AuthorAnnotation `json:"authors,omitempty"` + Schemas []*SchemaAnnotation `json:"schemas,omitempty"` + Custom map[string]interface{} `json:"custom,omitempty"` + Location *Location `json:"location,omitempty"` + + comments []*Comment + node Node + jsonOptions JSONOptions + } + + // SchemaAnnotation contains a schema declaration for the document identified by the path. 
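+	// Path names the document the schema applies to; Schema optionally
+	// references a schema document by path, while Definition optionally
+	// carries an inline schema.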
+ SchemaAnnotation struct { + Path Ref `json:"path"` + Schema Ref `json:"schema,omitempty"` + Definition *interface{} `json:"definition,omitempty"` + } + + AuthorAnnotation struct { + Name string `json:"name"` + Email string `json:"email,omitempty"` + } + + RelatedResourceAnnotation struct { + Ref url.URL `json:"ref"` + Description string `json:"description,omitempty"` + } + + AnnotationSet struct { + byRule map[*Rule][]*Annotations + byPackage map[int]*Annotations + byPath *annotationTreeNode + modules []*Module // Modules this set was constructed from + } + + annotationTreeNode struct { + Value *Annotations + Children map[Value]*annotationTreeNode // we assume key elements are hashable (vars and strings only!) + } + + AnnotationsRef struct { + Path Ref `json:"path"` // The path of the node the annotations are applied to + Annotations *Annotations `json:"annotations,omitempty"` + Location *Location `json:"location,omitempty"` // The location of the node the annotations are applied to + + jsonOptions JSONOptions + + node Node // The node the annotations are applied to + } + + AnnotationsRefSet []*AnnotationsRef + + FlatAnnotationsRefSet AnnotationsRefSet +) + +func (a *Annotations) String() string { + bs, _ := a.MarshalJSON() + return string(bs) +} + +// Loc returns the location of this annotation. +func (a *Annotations) Loc() *Location { + return a.Location +} + +// SetLoc updates the location of this annotation. +func (a *Annotations) SetLoc(l *Location) { + a.Location = l +} + +// EndLoc returns the location of this annotation's last comment line. +func (a *Annotations) EndLoc() *Location { + count := len(a.comments) + if count == 0 { + return a.Location + } + return a.comments[count-1].Location +} + +// Compare returns an integer indicating if a is less than, equal to, or greater +// than other. 
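+// Fields are compared in this order: scope, title, description, organizations,
+// related resources, authors, schemas, entrypoint, and finally custom data.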
+func (a *Annotations) Compare(other *Annotations) int { + + if a == nil && other == nil { + return 0 + } + + if a == nil { + return -1 + } + + if other == nil { + return 1 + } + + if cmp := scopeCompare(a.Scope, other.Scope); cmp != 0 { + return cmp + } + + if cmp := strings.Compare(a.Title, other.Title); cmp != 0 { + return cmp + } + + if cmp := strings.Compare(a.Description, other.Description); cmp != 0 { + return cmp + } + + if cmp := compareStringLists(a.Organizations, other.Organizations); cmp != 0 { + return cmp + } + + if cmp := compareRelatedResources(a.RelatedResources, other.RelatedResources); cmp != 0 { + return cmp + } + + if cmp := compareAuthors(a.Authors, other.Authors); cmp != 0 { + return cmp + } + + if cmp := compareSchemas(a.Schemas, other.Schemas); cmp != 0 { + return cmp + } + + if a.Entrypoint != other.Entrypoint { + if a.Entrypoint { + return 1 + } + return -1 + } + + if cmp := util.Compare(a.Custom, other.Custom); cmp != 0 { + return cmp + } + + return 0 +} + +// GetTargetPath returns the path of the node these Annotations are applied to (the target) +func (a *Annotations) GetTargetPath() Ref { + switch n := a.node.(type) { + case *Package: + return n.Path + case *Rule: + return n.Path() + default: + return nil + } +} + +func (a *Annotations) setJSONOptions(opts JSONOptions) { + a.jsonOptions = opts +} + +func (a *Annotations) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte(`{"scope":""}`), nil + } + + data := map[string]interface{}{ + "scope": a.Scope, + } + + if a.Title != "" { + data["title"] = a.Title + } + + if a.Description != "" { + data["description"] = a.Description + } + + if a.Entrypoint { + data["entrypoint"] = a.Entrypoint + } + + if len(a.Organizations) > 0 { + data["organizations"] = a.Organizations + } + + if len(a.RelatedResources) > 0 { + data["related_resources"] = a.RelatedResources + } + + if len(a.Authors) > 0 { + data["authors"] = a.Authors + } + + if len(a.Schemas) > 0 { + data["schemas"] = a.Schemas + } + + if len(a.Custom) > 0 { + data["custom"] = a.Custom + } + + if a.jsonOptions.MarshalOptions.IncludeLocation.Annotations { + if a.Location != nil { + data["location"] = a.Location + } + } + + return json.Marshal(data) +} + +func NewAnnotationsRef(a *Annotations) *AnnotationsRef { + var loc *Location + if a.node != nil { + loc = a.node.Loc() + } + + return &AnnotationsRef{ + Location: loc, + Path: a.GetTargetPath(), + Annotations: a, + node: a.node, + jsonOptions: a.jsonOptions, + } +} + +func (ar *AnnotationsRef) GetPackage() *Package { + switch n := ar.node.(type) { + case *Package: + return n + case *Rule: + return n.Module.Package + default: + return nil + } +} + +func (ar *AnnotationsRef) GetRule() *Rule { + switch n := ar.node.(type) { + case *Rule: + return n + default: + return nil + } +} + +func (ar *AnnotationsRef) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "path": ar.Path, + } + + if ar.Annotations != nil { + data["annotations"] = ar.Annotations + } + + if ar.jsonOptions.MarshalOptions.IncludeLocation.AnnotationsRef { + if ar.Location != nil { + data["location"] = ar.Location + } + } + + return json.Marshal(data) +} + +func scopeCompare(s1, s2 string) int { + + o1 := scopeOrder(s1) + o2 := scopeOrder(s2) + + if o2 < o1 { + return 1 + } else if o2 > o1 { + return -1 + } + + if s1 < s2 { + return -1 + } else if s2 < s1 { + return 1 + } + + return 0 +} + +func scopeOrder(s string) int { + switch s { + case annotationScopeRule: + return 1 + } + return 0 +} + +func compareAuthors(a, b 
[]*AuthorAnnotation) int {
+	if len(a) > len(b) {
+		return 1
+	} else if len(a) < len(b) {
+		return -1
+	}
+
+	for i := 0; i < len(a); i++ {
+		if cmp := a[i].Compare(b[i]); cmp != 0 {
+			return cmp
+		}
+	}
+
+	return 0
+}
+
+func compareRelatedResources(a, b []*RelatedResourceAnnotation) int {
+	if len(a) > len(b) {
+		return 1
+	} else if len(a) < len(b) {
+		return -1
+	}
+
+	for i := 0; i < len(a); i++ {
+		if cmp := strings.Compare(a[i].String(), b[i].String()); cmp != 0 {
+			return cmp
+		}
+	}
+
+	return 0
+}
+
+func compareSchemas(a, b []*SchemaAnnotation) int {
+	max := len(a)
+	if len(b) < max {
+		max = len(b)
+	}
+
+	for i := 0; i < max; i++ {
+		if cmp := a[i].Compare(b[i]); cmp != 0 {
+			return cmp
+		}
+	}
+
+	if len(a) > len(b) {
+		return 1
+	} else if len(a) < len(b) {
+		return -1
+	}
+
+	return 0
+}
+
+func compareStringLists(a, b []string) int {
+	if len(a) > len(b) {
+		return 1
+	} else if len(a) < len(b) {
+		return -1
+	}
+
+	for i := 0; i < len(a); i++ {
+		if cmp := strings.Compare(a[i], b[i]); cmp != 0 {
+			return cmp
+		}
+	}
+
+	return 0
+}
+
+// Copy returns a deep copy of a.
+func (a *Annotations) Copy(node Node) *Annotations {
+	cpy := *a
+
+	cpy.Organizations = make([]string, len(a.Organizations))
+	copy(cpy.Organizations, a.Organizations)
+
+	cpy.RelatedResources = make([]*RelatedResourceAnnotation, len(a.RelatedResources))
+	for i := range a.RelatedResources {
+		cpy.RelatedResources[i] = a.RelatedResources[i].Copy()
+	}
+
+	cpy.Authors = make([]*AuthorAnnotation, len(a.Authors))
+	for i := range a.Authors {
+		cpy.Authors[i] = a.Authors[i].Copy()
+	}
+
+	cpy.Schemas = make([]*SchemaAnnotation, len(a.Schemas))
+	for i := range a.Schemas {
+		cpy.Schemas[i] = a.Schemas[i].Copy()
+	}
+
+	cpy.Custom = deepcopy.Map(a.Custom)
+
+	cpy.node = node
+
+	return &cpy
+}
+
+// toObject constructs an AST Object from a.
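+// Scalar fields become string or boolean terms, list-valued fields become
+// arrays, and schema definitions and custom data are converted with
+// InterfaceToValue, which is why construction can fail with an *Error.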
+func (a *Annotations) toObject() (*Object, *Error) { + obj := NewObject() + + if a == nil { + return &obj, nil + } + + if len(a.Scope) > 0 { + obj.Insert(StringTerm("scope"), StringTerm(a.Scope)) + } + + if len(a.Title) > 0 { + obj.Insert(StringTerm("title"), StringTerm(a.Title)) + } + + if a.Entrypoint { + obj.Insert(StringTerm("entrypoint"), BooleanTerm(true)) + } + + if len(a.Description) > 0 { + obj.Insert(StringTerm("description"), StringTerm(a.Description)) + } + + if len(a.Organizations) > 0 { + orgs := make([]*Term, 0, len(a.Organizations)) + for _, org := range a.Organizations { + orgs = append(orgs, StringTerm(org)) + } + obj.Insert(StringTerm("organizations"), ArrayTerm(orgs...)) + } + + if len(a.RelatedResources) > 0 { + rrs := make([]*Term, 0, len(a.RelatedResources)) + for _, rr := range a.RelatedResources { + rrObj := NewObject(Item(StringTerm("ref"), StringTerm(rr.Ref.String()))) + if len(rr.Description) > 0 { + rrObj.Insert(StringTerm("description"), StringTerm(rr.Description)) + } + rrs = append(rrs, NewTerm(rrObj)) + } + obj.Insert(StringTerm("related_resources"), ArrayTerm(rrs...)) + } + + if len(a.Authors) > 0 { + as := make([]*Term, 0, len(a.Authors)) + for _, author := range a.Authors { + aObj := NewObject() + if len(author.Name) > 0 { + aObj.Insert(StringTerm("name"), StringTerm(author.Name)) + } + if len(author.Email) > 0 { + aObj.Insert(StringTerm("email"), StringTerm(author.Email)) + } + as = append(as, NewTerm(aObj)) + } + obj.Insert(StringTerm("authors"), ArrayTerm(as...)) + } + + if len(a.Schemas) > 0 { + ss := make([]*Term, 0, len(a.Schemas)) + for _, s := range a.Schemas { + sObj := NewObject() + if len(s.Path) > 0 { + sObj.Insert(StringTerm("path"), NewTerm(s.Path.toArray())) + } + if len(s.Schema) > 0 { + sObj.Insert(StringTerm("schema"), NewTerm(s.Schema.toArray())) + } + if s.Definition != nil { + def, err := InterfaceToValue(s.Definition) + if err != nil { + return nil, NewError(CompileErr, a.Location, "invalid definition in schema annotation: %s", err.Error()) + } + sObj.Insert(StringTerm("definition"), NewTerm(def)) + } + ss = append(ss, NewTerm(sObj)) + } + obj.Insert(StringTerm("schemas"), ArrayTerm(ss...)) + } + + if len(a.Custom) > 0 { + c, err := InterfaceToValue(a.Custom) + if err != nil { + return nil, NewError(CompileErr, a.Location, "invalid custom annotation %s", err.Error()) + } + obj.Insert(StringTerm("custom"), NewTerm(c)) + } + + return &obj, nil +} + +func attachAnnotationsNodes(mod *Module) Errors { + var errs Errors + + // Find first non-annotation statement following each annotation and attach + // the annotation to that statement. 
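+	// Statements are visited in order; the first statement located on a row
+	// below the annotation becomes its node.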
+	for _, a := range mod.Annotations {
+		for _, stmt := range mod.stmts {
+			_, ok := stmt.(*Annotations)
+			if !ok {
+				if stmt.Loc().Row > a.Location.Row {
+					a.node = stmt
+					break
+				}
+			}
+		}
+
+		if a.Scope == "" {
+			switch a.node.(type) {
+			case *Rule:
+				a.Scope = annotationScopeRule
+			case *Package:
+				a.Scope = annotationScopePackage
+			case *Import:
+				a.Scope = annotationScopeImport
+			}
+		}
+
+		if err := validateAnnotationScopeAttachment(a); err != nil {
+			errs = append(errs, err)
+		}
+
+		if err := validateAnnotationEntrypointAttachment(a); err != nil {
+			errs = append(errs, err)
+		}
+	}
+
+	return errs
+}
+
+func validateAnnotationScopeAttachment(a *Annotations) *Error {
+
+	switch a.Scope {
+	case annotationScopeRule, annotationScopeDocument:
+		if _, ok := a.node.(*Rule); ok {
+			return nil
+		}
+		return newScopeAttachmentErr(a, "rule")
+	case annotationScopePackage, annotationScopeSubpackages:
+		if _, ok := a.node.(*Package); ok {
+			return nil
+		}
+		return newScopeAttachmentErr(a, "package")
+	}
+
+	return NewError(ParseErr, a.Loc(), "invalid annotation scope '%v'. Use one of '%s', '%s', '%s', or '%s'",
+		a.Scope, annotationScopeRule, annotationScopeDocument, annotationScopePackage, annotationScopeSubpackages)
+}
+
+func validateAnnotationEntrypointAttachment(a *Annotations) *Error {
+	if a.Entrypoint && !(a.Scope == annotationScopeRule || a.Scope == annotationScopePackage) {
+		return NewError(ParseErr, a.Loc(), "annotation entrypoint applied to non-rule or package scope '%v'", a.Scope)
+	}
+	return nil
+}
+
+// Copy returns a deep copy of a.
+func (a *AuthorAnnotation) Copy() *AuthorAnnotation {
+	cpy := *a
+	return &cpy
+}
+
+// Compare returns an integer indicating if a is less than, equal to, or greater
+// than other.
+func (a *AuthorAnnotation) Compare(other *AuthorAnnotation) int {
+	if cmp := strings.Compare(a.Name, other.Name); cmp != 0 {
+		return cmp
+	}
+
+	if cmp := strings.Compare(a.Email, other.Email); cmp != 0 {
+		return cmp
+	}
+
+	return 0
+}
+
+func (a *AuthorAnnotation) String() string {
+	if len(a.Email) == 0 {
+		return a.Name
+	} else if len(a.Name) == 0 {
+		return fmt.Sprintf("<%s>", a.Email)
+	} else {
+		return fmt.Sprintf("%s <%s>", a.Name, a.Email)
+	}
+}
+
+// Copy returns a deep copy of rr.
+func (rr *RelatedResourceAnnotation) Copy() *RelatedResourceAnnotation {
+	cpy := *rr
+	return &cpy
+}
+
+// Compare returns an integer indicating if rr is less than, equal to, or greater
+// than other.
+func (rr *RelatedResourceAnnotation) Compare(other *RelatedResourceAnnotation) int {
+	if cmp := strings.Compare(rr.Description, other.Description); cmp != 0 {
+		return cmp
+	}
+
+	if cmp := strings.Compare(rr.Ref.String(), other.Ref.String()); cmp != 0 {
+		return cmp
+	}
+
+	return 0
+}
+
+func (rr *RelatedResourceAnnotation) String() string {
+	bs, _ := json.Marshal(rr)
+	return string(bs)
+}
+
+func (rr *RelatedResourceAnnotation) MarshalJSON() ([]byte, error) {
+	d := map[string]interface{}{
+		"ref": rr.Ref.String(),
+	}
+
+	if len(rr.Description) > 0 {
+		d["description"] = rr.Description
+	}
+
+	return json.Marshal(d)
+}
+
+// Copy returns a deep copy of s.
+func (s *SchemaAnnotation) Copy() *SchemaAnnotation {
+	cpy := *s
+	return &cpy
+}
+
+// Compare returns an integer indicating if s is less than, equal to, or greater
+// than other.
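+// Paths are compared first, then schemas; an annotation carrying an inline
+// definition sorts before one without.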
+func (s *SchemaAnnotation) Compare(other *SchemaAnnotation) int { + + if cmp := s.Path.Compare(other.Path); cmp != 0 { + return cmp + } + + if cmp := s.Schema.Compare(other.Schema); cmp != 0 { + return cmp + } + + if s.Definition != nil && other.Definition == nil { + return -1 + } else if s.Definition == nil && other.Definition != nil { + return 1 + } else if s.Definition != nil && other.Definition != nil { + return util.Compare(*s.Definition, *other.Definition) + } + + return 0 +} + +func (s *SchemaAnnotation) String() string { + bs, _ := json.Marshal(s) + return string(bs) +} + +func newAnnotationSet() *AnnotationSet { + return &AnnotationSet{ + byRule: map[*Rule][]*Annotations{}, + byPackage: map[int]*Annotations{}, + byPath: newAnnotationTree(), + } +} + +func BuildAnnotationSet(modules []*Module) (*AnnotationSet, Errors) { + as := newAnnotationSet() + var errs Errors + for _, m := range modules { + for _, a := range m.Annotations { + if err := as.add(a); err != nil { + errs = append(errs, err) + } + } + } + if len(errs) > 0 { + return nil, errs + } + as.modules = modules + return as, nil +} + +// NOTE(philipc): During copy propagation, the underlying Nodes can be +// stripped away from the annotations, leading to nil deref panics. We +// silently ignore these cases for now, as a workaround. +func (as *AnnotationSet) add(a *Annotations) *Error { + switch a.Scope { + case annotationScopeRule: + if rule, ok := a.node.(*Rule); ok { + as.byRule[rule] = append(as.byRule[rule], a) + } + case annotationScopePackage: + if pkg, ok := a.node.(*Package); ok { + hash := pkg.Path.Hash() + if exist, ok := as.byPackage[hash]; ok { + return errAnnotationRedeclared(a, exist.Location) + } + as.byPackage[hash] = a + } + case annotationScopeDocument: + if rule, ok := a.node.(*Rule); ok { + path := rule.Path() + x := as.byPath.get(path) + if x != nil { + return errAnnotationRedeclared(a, x.Value.Location) + } + as.byPath.insert(path, a) + } + case annotationScopeSubpackages: + if pkg, ok := a.node.(*Package); ok { + x := as.byPath.get(pkg.Path) + if x != nil && x.Value != nil { + return errAnnotationRedeclared(a, x.Value.Location) + } + as.byPath.insert(pkg.Path, a) + } + } + return nil +} + +func (as *AnnotationSet) GetRuleScope(r *Rule) []*Annotations { + if as == nil { + return nil + } + return as.byRule[r] +} + +func (as *AnnotationSet) GetSubpackagesScope(path Ref) []*Annotations { + if as == nil { + return nil + } + return as.byPath.ancestors(path) +} + +func (as *AnnotationSet) GetDocumentScope(path Ref) *Annotations { + if as == nil { + return nil + } + if node := as.byPath.get(path); node != nil { + return node.Value + } + return nil +} + +func (as *AnnotationSet) GetPackageScope(pkg *Package) *Annotations { + if as == nil { + return nil + } + return as.byPackage[pkg.Path.Hash()] +} + +// Flatten returns a flattened list view of this AnnotationSet. +// The returned slice is sorted, first by the annotations' target path, then by their target location +func (as *AnnotationSet) Flatten() FlatAnnotationsRefSet { + // This preallocation often won't be optimal, but it's superior to starting with a nil slice. 
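+	// len(as.byPath.Children) only counts the tree's direct children, so
+	// annotations on deeper paths can still force the slice to grow.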
+	refs := make([]*AnnotationsRef, 0, len(as.byPath.Children)+len(as.byRule)+len(as.byPackage))
+
+	refs = as.byPath.flatten(refs)
+
+	for _, a := range as.byPackage {
+		refs = append(refs, NewAnnotationsRef(a))
+	}
+
+	for _, as := range as.byRule {
+		for _, a := range as {
+			refs = append(refs, NewAnnotationsRef(a))
+		}
+	}
+
+	// Sort by path, then annotation location, for stable output
+	sort.SliceStable(refs, func(i, j int) bool {
+		return refs[i].Compare(refs[j]) < 0
+	})
+
+	return refs
+}
+
+// Chain returns the chain of annotations leading up to the given rule.
+// The returned slice is ordered as follows:
+// 0. Entries for the given rule, ordered from the METADATA block declared immediately above the rule, to the block declared farthest away (always at least one entry)
+// 1. The 'document' scope entry, if any
+// 2. The 'package' scope entry, if any
+// 3. Entries for the 'subpackages' scope, if any; ordered from the closest package path to the farthest. E.g.: 'do.re.mi', 'do.re', 'do'
+// The returned slice is guaranteed to always contain at least one entry, corresponding to the given rule.
+func (as *AnnotationSet) Chain(rule *Rule) AnnotationsRefSet {
+	var refs []*AnnotationsRef
+
+	ruleAnnots := as.GetRuleScope(rule)
+
+	if len(ruleAnnots) >= 1 {
+		for _, a := range ruleAnnots {
+			refs = append(refs, NewAnnotationsRef(a))
+		}
+	} else {
+		// Make sure there is always a leading entry representing the passed rule, even if it has no annotations
+		refs = append(refs, &AnnotationsRef{
+			Location: rule.Location,
+			Path:     rule.Path(),
+			node:     rule,
+		})
+	}
+
+	if len(refs) > 1 {
+		// Sort by annotation location; chain must start with annotations declared closest to rule, then going outward
+		sort.SliceStable(refs, func(i, j int) bool {
+			return refs[i].Annotations.Location.Compare(refs[j].Annotations.Location) > 0
+		})
+	}
+
+	docAnnots := as.GetDocumentScope(rule.Path())
+	if docAnnots != nil {
+		refs = append(refs, NewAnnotationsRef(docAnnots))
+	}
+
+	pkg := rule.Module.Package
+	pkgAnnots := as.GetPackageScope(pkg)
+	if pkgAnnots != nil {
+		refs = append(refs, NewAnnotationsRef(pkgAnnots))
+	}
+
+	subPkgAnnots := as.GetSubpackagesScope(pkg.Path)
+	// We need to reverse the order, as subPkgAnnots ordering will start at the root,
+	// whereas we want to end at the root.
+	for i := len(subPkgAnnots) - 1; i >= 0; i-- {
+		refs = append(refs, NewAnnotationsRef(subPkgAnnots[i]))
+	}
+
+	return refs
+}
+
+func (ars FlatAnnotationsRefSet) Insert(ar *AnnotationsRef) FlatAnnotationsRefSet {
+	result := make(FlatAnnotationsRefSet, 0, len(ars)+1)
+
+	// insertion sort, first by path, then location
+	for i, current := range ars {
+		if ar.Compare(current) < 0 {
+			result = append(result, ar)
+			result = append(result, ars[i:]...)
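+			// Entries after the insertion point are already sorted; we are done.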
+ break + } + result = append(result, current) + } + + if len(result) < len(ars)+1 { + result = append(result, ar) + } + + return result +} + +func newAnnotationTree() *annotationTreeNode { + return &annotationTreeNode{ + Value: nil, + Children: map[Value]*annotationTreeNode{}, + } +} + +func (t *annotationTreeNode) insert(path Ref, value *Annotations) { + node := t + for _, k := range path { + child, ok := node.Children[k.Value] + if !ok { + child = newAnnotationTree() + node.Children[k.Value] = child + } + node = child + } + node.Value = value +} + +func (t *annotationTreeNode) get(path Ref) *annotationTreeNode { + node := t + for _, k := range path { + if node == nil { + return nil + } + child, ok := node.Children[k.Value] + if !ok { + return nil + } + node = child + } + return node +} + +// ancestors returns a slice of annotations in ascending order, starting with the root of ref; e.g.: 'root', 'root.foo', 'root.foo.bar'. +func (t *annotationTreeNode) ancestors(path Ref) (result []*Annotations) { + node := t + for _, k := range path { + if node == nil { + return result + } + child, ok := node.Children[k.Value] + if !ok { + return result + } + if child.Value != nil { + result = append(result, child.Value) + } + node = child + } + return result +} + +func (t *annotationTreeNode) flatten(refs []*AnnotationsRef) []*AnnotationsRef { + if a := t.Value; a != nil { + refs = append(refs, NewAnnotationsRef(a)) + } + for _, c := range t.Children { + refs = c.flatten(refs) + } + return refs +} + +func (ar *AnnotationsRef) Compare(other *AnnotationsRef) int { + if c := ar.Path.Compare(other.Path); c != 0 { + return c + } + + if c := ar.Annotations.Location.Compare(other.Annotations.Location); c != 0 { + return c + } + + return ar.Annotations.Compare(other.Annotations) +} diff --git a/ast/annotations_test.go b/ast/annotations_test.go new file mode 100644 index 0000000000..05563d5197 --- /dev/null +++ b/ast/annotations_test.go @@ -0,0 +1,1106 @@ +// Copyright 2022 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "encoding/json" + "fmt" + "testing" +) + +// Test of example code in docs/content/annotations.md +func ExampleAnnotationSet_Flatten() { + modules := [][]string{ + { + "foo.rego", `# METADATA +# scope: subpackages +# organizations: +# - Acme Corp. +package foo`}, + { + "mod", `# METADATA +# description: A couple of useful rules +package foo.bar + +# METADATA +# title: My Rule P +p := 7`}, + } + + parsed := make([]*Module, 0, len(modules)) + for _, entry := range modules { + pm, err := ParseModuleWithOpts(entry[0], entry[1], ParserOptions{ProcessAnnotation: true}) + if err != nil { + panic(err) + } + parsed = append(parsed, pm) + } + + as, err := BuildAnnotationSet(parsed) + if err != nil { + panic(err) + } + + flattened := as.Flatten() + for _, entry := range flattened { + fmt.Printf("%v at %v has annotations %v\n", + entry.Path, + entry.Location, + entry.Annotations) + } + + // Output: + // data.foo at foo.rego:5 has annotations {"organizations":["Acme Corp."],"scope":"subpackages"} + // data.foo.bar at mod:3 has annotations {"description":"A couple of useful rules","scope":"package"} + // data.foo.bar.p at mod:7 has annotations {"scope":"rule","title":"My Rule P"} +} + +// Test of example code in docs/content/annotations.md +func ExampleAnnotationSet_Chain() { + modules := [][]string{ + { + "foo.rego", `# METADATA +# scope: subpackages +# organizations: +# - Acme Corp. 
+package foo`}, + { + "mod", `# METADATA +# description: A couple of useful rules +package foo.bar + +# METADATA +# title: My Rule P +p := 7`}, + } + + parsed := make([]*Module, 0, len(modules)) + for _, entry := range modules { + pm, err := ParseModuleWithOpts(entry[0], entry[1], ParserOptions{ProcessAnnotation: true}) + if err != nil { + panic(err) + } + parsed = append(parsed, pm) + } + + as, err := BuildAnnotationSet(parsed) + if err != nil { + panic(err) + } + + rule := parsed[1].Rules[0] + + flattened := as.Chain(rule) + for _, entry := range flattened { + fmt.Printf("%v at %v has annotations %v\n", + entry.Path, + entry.Location, + entry.Annotations) + } + + // Output: + // data.foo.bar.p at mod:7 has annotations {"scope":"rule","title":"My Rule P"} + // data.foo.bar at mod:3 has annotations {"description":"A couple of useful rules","scope":"package"} + // data.foo at foo.rego:5 has annotations {"organizations":["Acme Corp."],"scope":"subpackages"} +} + +func TestAnnotationSet_Flatten(t *testing.T) { + tests := []struct { + note string + modules map[string]string + expected []AnnotationsRef + }{ + { + note: "no modules", + modules: map[string]string{}, + expected: []AnnotationsRef{}, + }, + { + note: "simple module (all annotation types)", + modules: map[string]string{ + "module": `# METADATA +# title: pkg +# description: pkg +# organizations: +# - pkg +# related_resources: +# - https://pkg +# authors: +# - pkg +# schemas: +# - input: {"type": "boolean"} +# custom: +# pkg: pkg +package test + +# METADATA +# scope: document +# title: doc +# description: doc +# organizations: +# - doc +# related_resources: +# - https://doc +# authors: +# - doc +# schemas: +# - input: {"type": "integer"} +# custom: +# doc: doc + +# METADATA +# title: rule +# description: rule +# organizations: +# - rule +# related_resources: +# - https://rule +# authors: +# - rule +# schemas: +# - input: {"type": "string"} +# custom: +# rule: rule +p = 1`, + }, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.test"), + Location: &Location{File: "module", Row: 14}, + Annotations: &Annotations{ + Scope: "package", + Title: "pkg", + Description: "pkg", + Organizations: []string{"pkg"}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://pkg"), + }, + }, + Authors: []*AuthorAnnotation{ + { + Name: "pkg", + }, + }, + Schemas: []*SchemaAnnotation{ + schemaAnnotationFromMap("input", map[string]interface{}{ + "type": "boolean", + }), + }, + Custom: map[string]interface{}{ + "pkg": "pkg", + }, + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 44}, + Annotations: &Annotations{ + Scope: "document", + Title: "doc", + Description: "doc", + Organizations: []string{"doc"}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://doc"), + }, + }, + Authors: []*AuthorAnnotation{ + { + Name: "doc", + }, + }, + Schemas: []*SchemaAnnotation{ + schemaAnnotationFromMap("input", map[string]interface{}{ + "type": "integer", + }), + }, + Custom: map[string]interface{}{ + "doc": "doc", + }, + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 44}, + Annotations: &Annotations{ + Scope: "rule", + Title: "rule", + Description: "rule", + Organizations: []string{"rule"}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://rule"), + }, + }, + Authors: []*AuthorAnnotation{ + { + Name: "rule", + }, + }, + Schemas: []*SchemaAnnotation{ + schemaAnnotationFromMap("input", 
map[string]interface{}{ + "type": "string", + }), + }, + Custom: map[string]interface{}{ + "rule": "rule", + }, + }, + }, + }, + }, + { + note: "multiple subpackages", + modules: map[string]string{ + "root": `# METADATA +# scope: subpackages +# title: ROOT +package root`, + "root.foo": `# METADATA +# title: FOO +# scope: subpackages +package root.foo`, + "root.foo.baz": `# METADATA +# title: BAZ +package root.foo.baz`, + "root.bar": `# METADATA +# title: BAR +# scope: subpackages +package root.bar`, + "root.bar.baz": `# METADATA +# title: BAZ +package root.bar.baz`, + "root2": `# METADATA +# scope: subpackages +# title: ROOT2 +package root2`, + }, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.root"), + Location: &Location{File: "root", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "ROOT", + }, + }, + { + Path: MustParseRef("data.root.bar"), + Location: &Location{File: "root.bar", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "BAR", + }, + }, + { + Path: MustParseRef("data.root.bar.baz"), + Location: &Location{File: "root.bar.baz", Row: 3}, + Annotations: &Annotations{ + Scope: "package", + Title: "BAZ", + }, + }, + { + Path: MustParseRef("data.root.foo"), + Location: &Location{File: "root.foo", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "FOO", + }, + }, + { + Path: MustParseRef("data.root.foo.baz"), + Location: &Location{File: "root.foo.baz", Row: 3}, + Annotations: &Annotations{ + Scope: "package", + Title: "BAZ", + }, + }, + { + Path: MustParseRef("data.root2"), + Location: &Location{File: "root2", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "ROOT2", + }, + }, + }, + }, + { + note: "overlapping rule paths (same module)", + modules: map[string]string{ + "mod": `package test + +# METADATA +# title: P1 +p[v] {v = 1} + +# METADATA +# title: P2 +p[v] {v = 2}`, + }, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "mod", Row: 5}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P1", + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "mod", Row: 9}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P2", + }, + }, + }, + }, + { + note: "overlapping rule paths (different modules)", + modules: map[string]string{ + "mod1": `package test +# METADATA +# title: P1 +p[v] {v = 1}`, + "mod2": `package test +# METADATA +# title: P2 +p[v] {v = 2}`, + }, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "mod1", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P1", + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "mod2", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P2", + }, + }, + }, + }, + { + note: "overlapping rule paths (different modules, rule head refs)", + modules: map[string]string{ + "mod1": `package test.a +# METADATA +# title: P1 +b.c.p[v] {v = 1}`, + "mod2": `package test +# METADATA +# title: P2 +a.b.c.p[v] {v = 2}`, + }, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.test.a.b.c.p"), + Location: &Location{File: "mod1", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P1", + }, + }, + { + Path: MustParseRef("data.test.a.b.c.p"), + Location: &Location{File: "mod2", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P2", + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + compiler := 
MustCompileModulesWithOpts(tc.modules, + CompileOpts{ParserOptions: ParserOptions{ProcessAnnotation: true}}) + + as := compiler.GetAnnotationSet() + if as == nil { + t.Fatalf("Expected compiled AnnotationSet, got nil") + } + + flattened := as.Flatten() + + if len(flattened) != len(tc.expected) { + t.Fatalf("flattened AnnotationSet\n%v\ndoesn't match expected\n%v", + toJSON(flattened), toJSON(tc.expected)) + } + + for i, expected := range tc.expected { + a := flattened[i] + if !expected.Path.Equal(a.Path) { + t.Fatalf("path of AnnotationRef at %d '%v' doesn't match expected '%v'", + i, a.Path, expected.Path) + } + if expected.Location.File != a.Location.File || expected.Location.Row != a.Location.Row { + t.Fatalf("location of AnnotationRef at %d '%v' doesn't match expected '%v'", + i, a.Location, expected.Location) + } + if expected.Annotations.Compare(a.Annotations) != 0 { + t.Fatalf("annotations of AnnotationRef at %d\n%v\ndoesn't match expected\n%v", + i, a.Annotations, expected.Annotations) + } + } + }) + } +} + +func TestAnnotationSet_Chain(t *testing.T) { + tests := []struct { + note string + modules map[string]string + moduleToAnalyze string + ruleOnLineToAnalyze int + expected []AnnotationsRef + }{ + { + note: "simple module (all annotation types)", + modules: map[string]string{ + "module": `# METADATA +# title: pkg +# description: pkg +# organizations: +# - pkg +# related_resources: +# - https://pkg +# authors: +# - pkg +# schemas: +# - input.foo: {"type": "boolean"} +# custom: +# pkg: pkg +package test + +# METADATA +# scope: document +# title: doc +# description: doc +# organizations: +# - doc +# related_resources: +# - https://doc +# authors: +# - doc +# schemas: +# - input.bar: {"type": "integer"} +# custom: +# doc: doc + +# METADATA +# title: rule +# description: rule +# organizations: +# - rule +# related_resources: +# - https://rule +# authors: +# - rule +# schemas: +# - input.baz: {"type": "string"} +# custom: +# rule: rule +p = 1`, + }, + moduleToAnalyze: "module", + ruleOnLineToAnalyze: 44, + expected: []AnnotationsRef{ + { // Rule annotation is always first + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 44}, + Annotations: &Annotations{ + Scope: "rule", + Title: "rule", + Description: "rule", + Organizations: []string{"rule"}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://rule"), + }, + }, + Authors: []*AuthorAnnotation{ + { + Name: "rule", + }, + }, + Schemas: []*SchemaAnnotation{ + schemaAnnotationFromMap("input.baz", map[string]interface{}{ + "type": "string", + }), + }, + Custom: map[string]interface{}{ + "rule": "rule", + }, + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 44}, + Annotations: &Annotations{ + Scope: "document", + Title: "doc", + Description: "doc", + Organizations: []string{"doc"}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://doc"), + }, + }, + Authors: []*AuthorAnnotation{ + { + Name: "doc", + }, + }, + Schemas: []*SchemaAnnotation{ + schemaAnnotationFromMap("input.bar", map[string]interface{}{ + "type": "integer", + }), + }, + Custom: map[string]interface{}{ + "doc": "doc", + }, + }, + }, + { + Path: MustParseRef("data.test"), + Location: &Location{File: "module", Row: 14}, + Annotations: &Annotations{ + Scope: "package", + Title: "pkg", + Description: "pkg", + Organizations: []string{"pkg"}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://pkg"), + }, + }, + 
Authors: []*AuthorAnnotation{ + { + Name: "pkg", + }, + }, + Schemas: []*SchemaAnnotation{ + schemaAnnotationFromMap("input.foo", map[string]interface{}{ + "type": "boolean", + }), + }, + Custom: map[string]interface{}{ + "pkg": "pkg", + }, + }, + }, + }, + }, + { + note: "no annotations on rule", + modules: map[string]string{ + "module": `# METADATA +# title: pkg +# description: pkg +package test + +# METADATA +# scope: document +# title: doc +# description: doc + +p = 1`, + }, + moduleToAnalyze: "module", + ruleOnLineToAnalyze: 11, + expected: []AnnotationsRef{ + { // Rule entry is always first, even if no annotations are present + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 11}, + Annotations: nil, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 11}, + Annotations: &Annotations{ + Scope: "document", + Title: "doc", + Description: "doc", + }, + }, + + { + Path: MustParseRef("data.test"), + Location: &Location{File: "module", Row: 4}, + Annotations: &Annotations{ + Scope: "package", + Title: "pkg", + Description: "pkg", + }, + }, + }, + }, + { + note: "multiple subpackages", + modules: map[string]string{ + "root": `# METADATA +# scope: subpackages +# title: ROOT +package root`, + "root.foo": `# METADATA +# title: FOO +# scope: subpackages +package root.foo`, + "root.foo.bar": `# METADATA +# scope: subpackages +# description: subpackages scope applied to rule in other module +# title: BAR-sub + +# METADATA +# title: BAR-other +# description: This metadata is on the path of the queried rule, and should show up in the result even though it's in a different module. +package root.foo.bar + +# METADATA +# scope: document +# description: document scope applied to rule in other module +# title: P-doc +p = 1`, + "rule": `package root.foo.bar + +# METADATA +# title: P +p = 1`, + }, + moduleToAnalyze: "rule", + ruleOnLineToAnalyze: 5, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "rule", Row: 5}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P", + }, + }, + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "root.foo.bar", Row: 15}, + Annotations: &Annotations{ + Scope: "document", + Title: "P-doc", + Description: "document scope applied to rule in other module", + }, + }, + { + Path: MustParseRef("data.root.foo.bar"), + Location: &Location{File: "root.foo.bar", Row: 9}, + Annotations: &Annotations{ + Scope: "package", + Title: "BAR-other", + Description: "This metadata is on the path of the queried rule, and should show up in the result even though it's in a different module.", + }, + }, + { + Path: MustParseRef("data.root.foo.bar"), + Location: &Location{File: "root.foo.bar", Row: 9}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "BAR-sub", + Description: "subpackages scope applied to rule in other module", + }, + }, + { + Path: MustParseRef("data.root.foo"), + Location: &Location{File: "root.foo", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "FOO", + }, + }, + { + Path: MustParseRef("data.root"), + Location: &Location{File: "root", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "ROOT", + }, + }, + }, + }, + { + note: "multiple subpackages, refs in rule heads", // NOTE(sr): same as above, but last module's rule is `foo.bar.p` in package `root` + modules: map[string]string{ + "root": `# METADATA +# scope: subpackages +# title: ROOT +package root`, + "root.foo": `# 
METADATA +# title: FOO +# scope: subpackages +package root.foo`, + "root.foo.bar": `# METADATA +# scope: subpackages +# description: subpackages scope applied to rule in other module +# title: BAR-sub + +# METADATA +# title: BAR-other +# description: This metadata is on the path of the queried rule, but shouldn't show up in the result as it's in a different module. +package root.foo.bar + +# METADATA +# scope: document +# description: document scope applied to rule in other module +# title: P-doc +p = 1`, + "rule": `# METADATA +# title: BAR +package root + +# METADATA +# title: P +foo.bar.p = 1`, + }, + moduleToAnalyze: "rule", + ruleOnLineToAnalyze: 7, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "rule", Row: 7}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P", + }, + }, + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "root.foo.bar", Row: 15}, + Annotations: &Annotations{ + Scope: "document", + Title: "P-doc", + Description: "document scope applied to rule in other module", + }, + }, + { + Path: MustParseRef("data.root"), + Location: &Location{File: "rule", Row: 3}, + Annotations: &Annotations{ + Scope: "package", + Title: "BAR", + }, + }, + { + Path: MustParseRef("data.root"), + Location: &Location{File: "root", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "ROOT", + }, + }, + }, + }, + { + note: "multiple metadata blocks for single rule (order)", + modules: map[string]string{ + "module": `package test + +# METADATA +# title: One + +# METADATA +# title: Two + +# METADATA +# title: Three + +# METADATA +# title: Four +p = true`, + }, + moduleToAnalyze: "module", + ruleOnLineToAnalyze: 14, + expected: []AnnotationsRef{ // Rule annotations order is expected to start closest to the rule, moving out + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 14}, + Annotations: &Annotations{ + Scope: "rule", + Title: "Four", + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 14}, + Annotations: &Annotations{ + Scope: "rule", + Title: "Three", + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 14}, + Annotations: &Annotations{ + Scope: "rule", + Title: "Two", + }, + }, + { + Path: MustParseRef("data.test.p"), + Location: &Location{File: "module", Row: 14}, + Annotations: &Annotations{ + Scope: "rule", + Title: "One", + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + compiler := MustCompileModulesWithOpts(tc.modules, + CompileOpts{ParserOptions: ParserOptions{ProcessAnnotation: true}}) + + as := compiler.GetAnnotationSet() + if as == nil { + t.Fatalf("Expected compiled AnnotationSet, got nil") + } + + m := compiler.Modules[tc.moduleToAnalyze] + if m == nil { + t.Fatalf("no such module: %s", tc.moduleToAnalyze) + } + + var rule *Rule + for _, r := range m.Rules { + if r.Location.Row == tc.ruleOnLineToAnalyze { + rule = r + break + } + } + if rule == nil { + t.Fatalf("no rule found on line %d in module '%s'", + tc.ruleOnLineToAnalyze, tc.moduleToAnalyze) + } + + chain := as.Chain(rule) + + if len(chain) != len(tc.expected) { + t.Errorf("expected %d elements, got %d:", len(tc.expected), len(chain)) + t.Fatalf("chained AnnotationSet\n%v\n\ndoesn't match expected\n\n%v", + toJSON(chain), toJSON(tc.expected)) + } + + for i, expected := range tc.expected { + a := chain[i] + if !expected.Path.Equal(a.Path) { + t.Fatalf("path 
of AnnotationRef at %d '%v' doesn't match expected '%v'", + i, a.Path, expected.Path) + } + if expected.Location.File != a.Location.File || expected.Location.Row != a.Location.Row { + t.Fatalf("location of AnnotationRef at %d '%v' doesn't match expected '%v'", + i, a.Location, expected.Location) + } + if expected.Annotations.Compare(a.Annotations) != 0 { + t.Fatalf("annotations of AnnotationRef at %d\n%v\n\ndoesn't match expected\n\n%v", + i, a.Annotations, expected.Annotations) + } + } + }) + } +} + +func TestAnnotations_toObject(t *testing.T) { + annotations := Annotations{ + Scope: annotationScopeRule, + Title: "A title", + Description: "A description", + Organizations: []string{ + "Acme Corp.", + "Tyrell Corp.", + }, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://example.com"), + Description: "An example", + }, + { + Ref: mustParseURL("https://another.example.com"), + }, + }, + Authors: []*AuthorAnnotation{ + { + Name: "John Doe", + Email: "john@example.com", + }, + { + Name: "Jane Doe", + }, + { + Email: "jeff@example.com", + }, + }, + Schemas: []*SchemaAnnotation{ + { + Path: MustParseRef("input.foo"), + Schema: MustParseRef("schema.a"), + }, + schemaAnnotationFromMap("input.bar", map[string]interface{}{ + "type": "boolean", + }), + }, + Custom: map[string]interface{}{ + "number": 42, + "float": 2.2, + "string": "foo bar baz", + "bool": true, + "list": []interface{}{ + "a", "b", + }, + "list_of_lists": []interface{}{ + []interface{}{ + "a", "b", + }, + []interface{}{ + "b", "c", + }, + }, + "list_of_maps": []interface{}{ + map[string]interface{}{ + "one": 1, + "two": 2, + }, + map[string]interface{}{ + "two": 2, + "three": 3, + }, + }, + "map": map[string]interface{}{ + "nested_number": 1, + "nested_map": map[string]interface{}{ + "do": "re", + "mi": "fa", + }, + "nested_list": []interface{}{ + 1, 2, 3, + }, + }, + }, + } + + expected := NewObject( + Item(StringTerm("scope"), StringTerm(annotationScopeRule)), + Item(StringTerm("title"), StringTerm("A title")), + Item(StringTerm("description"), StringTerm("A description")), + Item(StringTerm("organizations"), ArrayTerm( + StringTerm("Acme Corp."), + StringTerm("Tyrell Corp."), + )), + Item(StringTerm("related_resources"), ArrayTerm( + ObjectTerm( + Item(StringTerm("ref"), StringTerm("https://example.com")), + Item(StringTerm("description"), StringTerm("An example")), + ), + ObjectTerm( + Item(StringTerm("ref"), StringTerm("https://another.example.com")), + ), + )), + Item(StringTerm("authors"), ArrayTerm( + ObjectTerm( + Item(StringTerm("name"), StringTerm("John Doe")), + Item(StringTerm("email"), StringTerm("john@example.com")), + ), + ObjectTerm( + Item(StringTerm("name"), StringTerm("Jane Doe")), + ), + ObjectTerm( + Item(StringTerm("email"), StringTerm("jeff@example.com")), + ), + )), + Item(StringTerm("schemas"), ArrayTerm( + ObjectTerm( + Item(StringTerm("path"), ArrayTerm(StringTerm("input"), StringTerm("foo"))), + Item(StringTerm("schema"), ArrayTerm(StringTerm("schema"), StringTerm("a"))), + ), + ObjectTerm( + Item(StringTerm("path"), ArrayTerm(StringTerm("input"), StringTerm("bar"))), + Item(StringTerm("definition"), ObjectTerm( + Item(StringTerm("type"), StringTerm("boolean")), + )), + ), + )), + Item(StringTerm("custom"), ObjectTerm( + Item(StringTerm("number"), NumberTerm("42")), + Item(StringTerm("float"), NumberTerm("2.2")), + Item(StringTerm("string"), StringTerm("foo bar baz")), + Item(StringTerm("bool"), BooleanTerm(true)), + Item(StringTerm("list"), ArrayTerm( + StringTerm("a"), + 
StringTerm("b"), + )), + Item(StringTerm("list_of_lists"), ArrayTerm( + ArrayTerm( + StringTerm("a"), + StringTerm("b"), + ), + ArrayTerm( + StringTerm("b"), + StringTerm("c"), + ), + )), + Item(StringTerm("list_of_maps"), ArrayTerm( + ObjectTerm( + Item(StringTerm("one"), NumberTerm("1")), + Item(StringTerm("two"), NumberTerm("2")), + ), + ObjectTerm( + Item(StringTerm("two"), NumberTerm("2")), + Item(StringTerm("three"), NumberTerm("3")), + ), + )), + Item(StringTerm("map"), ObjectTerm( + Item(StringTerm("nested_number"), NumberTerm("1")), + Item(StringTerm("nested_map"), ObjectTerm( + Item(StringTerm("do"), StringTerm("re")), + Item(StringTerm("mi"), StringTerm("fa")), + )), + Item(StringTerm("nested_list"), ArrayTerm( + NumberTerm("1"), + NumberTerm("2"), + NumberTerm("3"), + )), + )), + )), + ) + + obj, err := annotations.toObject() + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if Compare(*obj, expected) != 0 { + t.Fatalf("object generated from annotations\n\n%v\n\ndoesn't match expected\n\n%v", + *obj, expected) + } +} + +func toJSON(v interface{}) string { + b, _ := json.MarshalIndent(v, "", " ") + return string(b) +} + +func schemaAnnotationFromMap(path string, def map[string]interface{}) *SchemaAnnotation { + var p interface{} = def + return &SchemaAnnotation{Path: MustParseRef(path), Definition: &p} +} diff --git a/ast/builtins.go b/ast/builtins.go index ee57ed4413..c696eb1053 100644 --- a/ast/builtins.go +++ b/ast/builtins.go @@ -80,6 +80,7 @@ var DefaultBuiltins = [...]*Builtin{ // Arrays ArrayConcat, ArraySlice, + ArrayReverse, // Conversions ToNumber, @@ -101,6 +102,7 @@ var DefaultBuiltins = [...]*Builtin{ RegexTemplateMatch, RegexFind, RegexFindAllStringSubmatch, + RegexReplace, // Sets SetDiff, @@ -108,9 +110,12 @@ var DefaultBuiltins = [...]*Builtin{ Union, // Strings + AnyPrefixMatch, + AnySuffixMatch, Concat, FormatInt, IndexOf, + IndexOfN, Substring, Lower, Upper, @@ -127,6 +132,7 @@ var DefaultBuiltins = [...]*Builtin{ TrimSuffix, TrimSpace, Sprintf, + StringReverse, // Numbers NumbersRange, @@ -154,9 +160,12 @@ var DefaultBuiltins = [...]*Builtin{ // Object Manipulation ObjectUnion, + ObjectUnionN, ObjectRemove, ObjectFilter, ObjectGet, + ObjectKeys, + ObjectSubset, // JSON Object Manipulation JSONFilter, @@ -186,6 +195,7 @@ var DefaultBuiltins = [...]*Builtin{ ParseNanos, ParseRFC3339Nanos, ParseDurationNanos, + Format, Date, Clock, Weekday, @@ -204,10 +214,12 @@ var DefaultBuiltins = [...]*Builtin{ CryptoHmacSha1, CryptoHmacSha256, CryptoHmacSha512, + CryptoHmacEqual, // Graphs WalkBuiltin, ReachableBuiltin, + ReachablePathsBuiltin, // Sort Sort, @@ -225,8 +237,25 @@ var DefaultBuiltins = [...]*Builtin{ // HTTP HTTPSend, + // GraphQL + GraphQLParse, + GraphQLParseAndVerify, + GraphQLParseQuery, + GraphQLParseSchema, + GraphQLIsValid, + GraphQLSchemaIsValid, + + // JSON Schema + JSONSchemaVerify, + JSONMatchSchema, + + // Cloud Provider Helpers + ProvidersAWSSignReqObj, + // Rego RegoParseModule, + RegoMetadataChain, + RegoMetadataRule, // OPA OPARuntime, @@ -242,18 +271,20 @@ var DefaultBuiltins = [...]*Builtin{ NetCIDRExpand, NetCIDRMerge, NetLookupIPAddr, + NetCIDRIsValid, // Glob GlobMatch, GlobQuoteMeta, // Units + UnitsParse, UnitsParseBytes, // UUIDs UUIDRFC4122, - //SemVers + // SemVers SemVerIsValid, SemVerCompare, @@ -266,14 +297,18 @@ var DefaultBuiltins = [...]*Builtin{ // built-in definitions. 
var BuiltinMap map[string]*Builtin

-// IgnoreDuringPartialEval is a set of built-in functions that should not be
-// evaluated during partial evaluation. These functions are not partially
-// evaluated because they are not pure.
+// Deprecated: Builtins can now be annotated directly with the
+// Nondeterministic property; when it is set to true, the builtin is
+// ignored during partial evaluation.
var IgnoreDuringPartialEval = []*Builtin{
+	RandIntn,
+	UUIDRFC4122,
+	JWTDecodeVerify,
+	JWTEncodeSignRaw,
+	JWTEncodeSign,
	NowNanos,
	HTTPSend,
-	UUIDRFC4122,
-	RandIntn,
+	OPARuntime,
	NetLookupIPAddr,
}

@@ -336,219 +371,283 @@ var MemberWithKey = &Builtin{
/**
 * Comparisons
 */
+var comparison = category("comparison")

-// GreaterThan represents the ">" comparison operator.
var GreaterThan = &Builtin{
-	Name:  "gt",
-	Infix: ">",
+	Name:       "gt",
+	Infix:      ">",
+	Categories: comparison,
	Decl: types.NewFunction(
-		types.Args(types.A, types.A),
-		types.B,
+		types.Args(
+			types.Named("x", types.A),
+			types.Named("y", types.A),
+		),
+		types.Named("result", types.B).Description("true if `x` is greater than `y`; false otherwise"),
	),
}

-// GreaterThanEq represents the ">=" comparison operator.
var GreaterThanEq = &Builtin{
-	Name:  "gte",
-	Infix: ">=",
+	Name:       "gte",
+	Infix:      ">=",
+	Categories: comparison,
	Decl: types.NewFunction(
-		types.Args(types.A, types.A),
-		types.B,
+		types.Args(
+			types.Named("x", types.A),
+			types.Named("y", types.A),
+		),
+		types.Named("result", types.B).Description("true if `x` is greater than or equal to `y`; false otherwise"),
	),
}

// LessThan represents the "<" comparison operator.
var LessThan = &Builtin{
-	Name:  "lt",
-	Infix: "<",
+	Name:       "lt",
+	Infix:      "<",
+	Categories: comparison,
	Decl: types.NewFunction(
-		types.Args(types.A, types.A),
-		types.B,
+		types.Args(
+			types.Named("x", types.A),
+			types.Named("y", types.A),
+		),
+		types.Named("result", types.B).Description("true if `x` is less than `y`; false otherwise"),
	),
}

-// LessThanEq represents the "<=" comparison operator.
var LessThanEq = &Builtin{
-	Name:  "lte",
-	Infix: "<=",
+	Name:       "lte",
+	Infix:      "<=",
+	Categories: comparison,
	Decl: types.NewFunction(
-		types.Args(types.A, types.A),
-		types.B,
+		types.Args(
+			types.Named("x", types.A),
+			types.Named("y", types.A),
+		),
+		types.Named("result", types.B).Description("true if `x` is less than or equal to `y`; false otherwise"),
	),
}

-// NotEqual represents the "!=" comparison operator.
var NotEqual = &Builtin{
-	Name:  "neq",
-	Infix: "!=",
+	Name:       "neq",
+	Infix:      "!=",
+	Categories: comparison,
	Decl: types.NewFunction(
-		types.Args(types.A, types.A),
-		types.B,
+		types.Args(
+			types.Named("x", types.A),
+			types.Named("y", types.A),
+		),
+		types.Named("result", types.B).Description("true if `x` is not equal to `y`; false otherwise"),
	),
}

// Equal represents the "==" comparison operator.
var Equal = &Builtin{
-	Name:  "equal",
-	Infix: "==",
+	Name:       "equal",
+	Infix:      "==",
+	Categories: comparison,
	Decl: types.NewFunction(
-		types.Args(types.A, types.A),
-		types.B,
+		types.Args(
+			types.Named("x", types.A),
+			types.Named("y", types.A),
+		),
+		types.Named("result", types.B).Description("true if `x` is equal to `y`; false otherwise"),
	),
}

/**
 * Arithmetic
 */
+var number = category("numbers")

-// Plus adds two numbers together.
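The declarations below all follow the pattern introduced here: named arguments with per-argument descriptions, a category, and (where applicable) the `Nondeterministic` flag that supersedes `IgnoreDuringPartialEval`. As a minimal sketch, a custom builtin declared in this style might look as follows; the name `mylib.rand` and its behavior are hypothetical, while `category`, `types.Named`, and `Nondeterministic` are the helpers and fields used throughout this file:

var ExampleRand = &Builtin{
	Name:        "mylib.rand", // hypothetical name, for illustration only
	Description: "Example only: returns a pseudo-random number derived from a seed string.",
	Decl: types.NewFunction(
		types.Args(
			types.Named("seed", types.S).Description("seed for the generator"),
		),
		types.Named("y", types.N).Description("a pseudo-random number"),
	),
	Categories:       category("numbers"),
	Nondeterministic: true, // replaces listing the builtin in IgnoreDuringPartialEval
}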
var Plus = &Builtin{ - Name: "plus", - Infix: "+", + Name: "plus", + Infix: "+", + Description: "Plus adds two numbers together.", Decl: types.NewFunction( - types.Args(types.N, types.N), - types.N, + types.Args( + types.Named("x", types.N), + types.Named("y", types.N), + ), + types.Named("z", types.N).Description("the sum of `x` and `y`"), ), + Categories: number, } -// Minus subtracts the second number from the first number or computes the diff -// between two sets. var Minus = &Builtin{ - Name: "minus", - Infix: "-", + Name: "minus", + Infix: "-", + Description: "Minus subtracts the second number from the first number or computes the difference between two sets.", Decl: types.NewFunction( types.Args( - types.NewAny(types.N, types.NewSet(types.A)), - types.NewAny(types.N, types.NewSet(types.A)), + types.Named("x", types.NewAny(types.N, types.NewSet(types.A))), + types.Named("y", types.NewAny(types.N, types.NewSet(types.A))), ), - types.NewAny(types.N, types.NewSet(types.A)), + types.Named("z", types.NewAny(types.N, types.NewSet(types.A))).Description("the difference of `x` and `y`"), ), + Categories: category("sets", "numbers"), } -// Multiply multiplies two numbers together. var Multiply = &Builtin{ - Name: "mul", - Infix: "*", + Name: "mul", + Infix: "*", + Description: "Multiplies two numbers.", Decl: types.NewFunction( - types.Args(types.N, types.N), - types.N, + types.Args( + types.Named("x", types.N), + types.Named("y", types.N), + ), + types.Named("z", types.N).Description("the product of `x` and `y`"), ), + Categories: number, } -// Divide divides the first number by the second number. var Divide = &Builtin{ - Name: "div", - Infix: "/", + Name: "div", + Infix: "/", + Description: "Divides the first number by the second number.", Decl: types.NewFunction( - types.Args(types.N, types.N), - types.N, + types.Args( + types.Named("x", types.N).Description("the dividend"), + types.Named("y", types.N).Description("the divisor"), + ), + types.Named("z", types.N).Description("the result of `x` divided by `y`"), ), + Categories: number, } -// Round rounds the number to the nearest integer. var Round = &Builtin{ - Name: "round", + Name: "round", + Description: "Rounds the number to the nearest integer.", Decl: types.NewFunction( - types.Args(types.N), - types.N, + types.Args( + types.Named("x", types.N).Description("the number to round"), + ), + types.Named("y", types.N).Description("the result of rounding `x`"), ), + Categories: number, } -// Ceil rounds the number up to the nearest integer. var Ceil = &Builtin{ - Name: "ceil", + Name: "ceil", + Description: "Rounds the number _up_ to the nearest integer.", Decl: types.NewFunction( - types.Args(types.N), - types.N, + types.Args( + types.Named("x", types.N).Description("the number to round"), + ), + types.Named("y", types.N).Description("the result of rounding `x` _up_"), ), + Categories: number, } -// Floor rounds the number down to the nearest integer. var Floor = &Builtin{ - Name: "floor", + Name: "floor", + Description: "Rounds the number _down_ to the nearest integer.", Decl: types.NewFunction( - types.Args(types.N), - types.N, + types.Args( + types.Named("x", types.N).Description("the number to round"), + ), + types.Named("y", types.N).Description("the result of rounding `x` _down_"), ), + Categories: number, } -// Abs returns the number without its sign. 
var Abs = &Builtin{
-	Name: "abs",
+	Name:        "abs",
+	Description: "Returns the number without its sign.",
	Decl: types.NewFunction(
-		types.Args(types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
+		),
+		types.Named("y", types.N).Description("the absolute value of `x`"),
	),
+	Categories: number,
}

-// Rem returns the remainder for x%y for y != 0.
var Rem = &Builtin{
-	Name:  "rem",
-	Infix: "%",
+	Name:        "rem",
+	Infix:       "%",
+	Description: "Returns the remainder of `x` divided by `y`, for `y != 0`.",
	Decl: types.NewFunction(
-		types.Args(types.N, types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
+			types.Named("y", types.N),
		),
+		types.Named("z", types.N).Description("the remainder"),
	),
+	Categories: number,
}

/**
 * Bitwise
 */

-// BitsOr returns the bitwise "or" of two integers.
var BitsOr = &Builtin{
-	Name: "bits.or",
+	Name:        "bits.or",
+	Description: "Returns the bitwise \"OR\" of two integers.",
	Decl: types.NewFunction(
-		types.Args(types.N, types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
+			types.Named("y", types.N),
		),
+		types.Named("z", types.N),
	),
}

-// BitsAnd returns the bitwise "and" of two integers.
var BitsAnd = &Builtin{
-	Name: "bits.and",
+	Name:        "bits.and",
+	Description: "Returns the bitwise \"AND\" of two integers.",
	Decl: types.NewFunction(
-		types.Args(types.N, types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
+			types.Named("y", types.N),
		),
+		types.Named("z", types.N),
	),
}

-// BitsNegate returns the bitwise "negation" of an integer (i.e. flips each
-// bit).
var BitsNegate = &Builtin{
-	Name: "bits.negate",
+	Name:        "bits.negate",
+	Description: "Returns the bitwise negation (flip) of an integer.",
	Decl: types.NewFunction(
-		types.Args(types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
		),
+		types.Named("z", types.N),
	),
}

-// BitsXOr returns the bitwise "exclusive-or" of two integers.
var BitsXOr = &Builtin{
-	Name: "bits.xor",
+	Name:        "bits.xor",
+	Description: "Returns the bitwise \"XOR\" (exclusive-or) of two integers.",
	Decl: types.NewFunction(
-		types.Args(types.N, types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
+			types.Named("y", types.N),
		),
+		types.Named("z", types.N),
	),
}

-// BitsShiftLeft returns a new integer with its bits shifted some value to the
-// left.
var BitsShiftLeft = &Builtin{
-	Name: "bits.lsh",
+	Name:        "bits.lsh",
+	Description: "Returns a new integer with its bits shifted `s` bits to the left.",
	Decl: types.NewFunction(
-		types.Args(types.N, types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
+			types.Named("s", types.N),
		),
+		types.Named("z", types.N),
	),
}

-// BitsShiftRight returns a new integer with its bits shifted some value to the
-// right.
var BitsShiftRight = &Builtin{
-	Name: "bits.rsh",
+	Name:        "bits.rsh",
+	Description: "Returns a new integer with its bits shifted `s` bits to the right.",
	Decl: types.NewFunction(
-		types.Args(types.N, types.N),
-		types.N,
+		types.Args(
+			types.Named("x", types.N),
+			types.Named("s", types.N),
		),
+		types.Named("z", types.N),
	),
}

@@ -556,528 +655,662 @@ var BitsShiftRight = &Builtin{
/**
 * Sets
 */

-// And performs an intersection operation on sets.
+var sets = category("sets")
+
var And = &Builtin{
-	Name:  "and",
-	Infix: "&",
+	Name:        "and",
+	Infix:       "&",
+	Description: "Returns the intersection of two sets.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewSet(types.A),
-			types.NewSet(types.A),
+			types.Named("x", types.NewSet(types.A)),
+			types.Named("y", types.NewSet(types.A)),
		),
-		types.NewSet(types.A),
+		types.Named("z", types.NewSet(types.A)).Description("the intersection of `x` and `y`"),
	),
+	Categories: sets,
}

// Or performs a union operation on sets.
var Or = &Builtin{
-	Name:  "or",
-	Infix: "|",
+	Name:        "or",
+	Infix:       "|",
+	Description: "Returns the union of two sets.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewSet(types.A),
-			types.NewSet(types.A),
+			types.Named("x", types.NewSet(types.A)),
+			types.Named("y", types.NewSet(types.A)),
		),
-		types.NewSet(types.A),
+		types.Named("z", types.NewSet(types.A)).Description("the union of `x` and `y`"),
+	),
+	Categories: sets,
+}
+
+var Intersection = &Builtin{
+	Name:        "intersection",
+	Description: "Returns the intersection of the given input sets.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("xs", types.NewSet(types.NewSet(types.A))).Description("set of sets to intersect"),
+		),
+		types.Named("y", types.NewSet(types.A)).Description("the intersection of all `xs` sets"),
+	),
+	Categories: sets,
+}
+
+var Union = &Builtin{
+	Name:        "union",
+	Description: "Returns the union of the given input sets.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("xs", types.NewSet(types.NewSet(types.A))).Description("set of sets to merge"),
+		),
+		types.Named("y", types.NewSet(types.A)).Description("the union of all `xs` sets"),
	),
+	Categories: sets,
}

/**
 * Aggregates
 */

-// Count takes a collection or string and counts the number of elements in it.
+var aggregates = category("aggregates")
+
var Count = &Builtin{
-	Name: "count",
+	Name:        "count",
+	Description: "Count takes a collection or string and returns the number of elements (or characters) in it.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewAny(
+			types.Named("collection", types.NewAny(
				types.NewSet(types.A),
				types.NewArray(nil, types.A),
				types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
				types.S,
-			),
+			)).Description("the set/array/object/string to be counted"),
		),
-		types.N,
+		types.Named("n", types.N).Description("the count of elements, key/val pairs, or characters, respectively."),
	),
+	Categories: aggregates,
}

-// Sum takes an array or set of numbers and sums them.
var Sum = &Builtin{
-	Name: "sum",
+	Name:        "sum",
+	Description: "Sums elements of an array or set of numbers.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewAny(
+			types.Named("collection", types.NewAny(
				types.NewSet(types.N),
				types.NewArray(nil, types.N),
-			),
+			)),
		),
-		types.N,
+		types.Named("n", types.N).Description("the sum of all elements"),
	),
+	Categories: aggregates,
}

-// Product takes an array or set of numbers and multiplies them.
var Product = &Builtin{
-	Name: "product",
+	Name:        "product",
+	Description: "Multiplies elements of an array or set of numbers.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewAny(
+			types.Named("collection", types.NewAny(
				types.NewSet(types.N),
				types.NewArray(nil, types.N),
-			),
+			)),
		),
-		types.N,
+		types.Named("n", types.N).Description("the product of all elements"),
	),
+	Categories: aggregates,
}

-// Max returns the maximum value in a collection.
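Because `intersection` and `union` take a single set-of-sets argument while `&` and `|` are binary operators, the two forms are easy to confuse. A small, illustrative check using the `github.com/open-policy-agent/opa/rego` package (not part of this diff; the expected values follow from the descriptions above):

package main

import (
	"context"
	"fmt"

	"github.com/open-policy-agent/opa/rego"
)

func main() {
	// intersection takes one argument: a set of sets.
	rs, err := rego.New(rego.Query(
		`x := intersection({{1, 2, 3}, {2, 3}, {3, 4}}); y := {1, 2} & {2, 3}`,
	)).Eval(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Println(rs[0].Bindings["x"], rs[0].Bindings["y"]) // sets serialize as arrays: [3] [2]
}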
var Max = &Builtin{ - Name: "max", + Name: "max", + Description: "Returns the maximum value in a collection.", Decl: types.NewFunction( types.Args( - types.NewAny( + types.Named("collection", types.NewAny( types.NewSet(types.A), types.NewArray(nil, types.A), - ), + )), ), - types.A, + types.Named("n", types.A).Description("the maximum of all elements"), ), + Categories: aggregates, } -// Min returns the minimum value in a collection. var Min = &Builtin{ - Name: "min", + Name: "min", + Description: "Returns the minimum value in a collection.", Decl: types.NewFunction( types.Args( - types.NewAny( + types.Named("collection", types.NewAny( types.NewSet(types.A), types.NewArray(nil, types.A), - ), + )), ), - types.A, + types.Named("n", types.A).Description("the minimum of all elements"), ), + Categories: aggregates, } -// All takes a list and returns true if all of the items -// are true. A collection of length 0 returns true. -var All = &Builtin{ - Name: "all", - Decl: types.NewFunction( - types.Args( - types.NewAny( - types.NewSet(types.A), - types.NewArray(nil, types.A), - ), - ), - types.B, - ), -} +/** + * Sorting + */ -// Any takes a collection and returns true if any of the items -// is true. A collection of length 0 returns false. -var Any = &Builtin{ - Name: "any", +var Sort = &Builtin{ + Name: "sort", + Description: "Returns a sorted array.", Decl: types.NewFunction( types.Args( - types.NewAny( - types.NewSet(types.A), + types.Named("collection", types.NewAny( types.NewArray(nil, types.A), - ), + types.NewSet(types.A), + )).Description("the array or set to be sorted"), ), - types.B, + types.Named("n", types.NewArray(nil, types.A)).Description("the sorted array"), ), + Categories: aggregates, } /** * Arrays */ -// ArrayConcat returns the result of concatenating two arrays together. var ArrayConcat = &Builtin{ - Name: "array.concat", + Name: "array.concat", + Description: "Concatenates two arrays.", Decl: types.NewFunction( types.Args( - types.NewArray(nil, types.A), - types.NewArray(nil, types.A), + types.Named("x", types.NewArray(nil, types.A)), + types.Named("y", types.NewArray(nil, types.A)), ), - types.NewArray(nil, types.A), + types.Named("z", types.NewArray(nil, types.A)).Description("the concatenation of `x` and `y`"), ), } -// ArraySlice returns a slice of a given array var ArraySlice = &Builtin{ - Name: "array.slice", + Name: "array.slice", + Description: "Returns a slice of a given array. 
If `start` is greater than or equal to `stop`, `slice` is `[]`.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewArray(nil, types.A),
-			types.NewNumber(),
-			types.NewNumber(),
+			types.Named("arr", types.NewArray(nil, types.A)).Description("the array to be sliced"),
+			types.Named("start", types.NewNumber()).Description("the start index of the returned slice; if less than zero, it's clamped to 0"),
+			types.Named("stop", types.NewNumber()).Description("the stop index of the returned slice; if larger than `count(arr)`, it's clamped to `count(arr)`"),
		),
-		types.NewArray(nil, types.A),
+		types.Named("slice", types.NewArray(nil, types.A)).Description("the subslice of `arr`, from `start` to `stop`, including `arr[start]`, but excluding `arr[stop]`"),
+	),
+} // NOTE(sr): this function really needs examples
+
+var ArrayReverse = &Builtin{
+	Name:        "array.reverse",
+	Description: "Returns the reverse of a given array.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("arr", types.NewArray(nil, types.A)).Description("the array to be reversed"),
+		),
+		types.Named("rev", types.NewArray(nil, types.A)).Description("an array containing the elements of `arr` in reverse order"),
	),
}

/**
 * Conversions
 */
+var conversions = category("conversions")

-// ToNumber takes a string, bool, or number value and converts it to a number.
-// Strings are converted to numbers using strconv.Atoi.
-// Boolean false is converted to 0 and boolean true is converted to 1.
var ToNumber = &Builtin{
-	Name: "to_number",
+	Name:        "to_number",
+	Description: "Converts a string, bool, or number value to a number: Strings are converted to numbers using `strconv.Atoi`, Boolean `false` is converted to 0 and `true` is converted to 1.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewAny(
+			types.Named("x", types.NewAny(
				types.N,
				types.S,
				types.B,
				types.NewNull(),
-			),
+			)),
		),
-		types.N,
+		types.Named("num", types.N),
	),
+	Categories: conversions,
}

/**
 * Regular Expressions
 */

-// RegexMatch takes two strings and evaluates to true if the string in the second
-// position matches the pattern in the first position.
var RegexMatch = &Builtin{
-	Name: "regex.match",
+	Name:        "regex.match",
+	Description: "Matches a string against a regular expression.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("pattern", types.S).Description("regular expression"),
+			types.Named("value", types.S).Description("value to match against `pattern`"),
		),
-		types.B,
+		types.Named("result", types.B),
	),
}

-// RegexIsValid returns true if the regex pattern string is valid, otherwise false.
var RegexIsValid = &Builtin{
-	Name: "regex.is_valid",
+	Name:        "regex.is_valid",
+	Description: "Checks if a string is a valid regular expression: the detailed syntax for patterns is defined by https://github.com/google/re2/wiki/Syntax.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
+			types.Named("pattern", types.S).Description("regular expression"),
		),
-		types.B,
+		types.Named("result", types.B),
	),
}

-// RegexFindAllStringSubmatch returns an array of all successive matches of the expression.
-// It takes two strings and a number, the pattern, the value and number of matches to
-// return, -1 means all matches.
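Picking up the NOTE(sr) above: the clamping rules make out-of-range indices safe rather than an error. A sketch of the expected behavior, reusing the imports from the previous example (the bindings follow directly from the argument descriptions):

rs, err := rego.New(rego.Query(
	`a := array.slice([1, 2, 3, 4, 5], 1, 3); b := array.slice([1, 2], -1, 10); c := array.slice([1, 2, 3], 2, 1)`,
)).Eval(context.Background())
// a == [2, 3]  (arr[1] included, arr[3] excluded)
// b == [1, 2]  (start clamped to 0, stop clamped to count(arr))
// c == []      (start >= stop)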
var RegexFindAllStringSubmatch = &Builtin{
-	Name: "regex.find_all_string_submatch_n",
+	Name:        "regex.find_all_string_submatch_n",
+	Description: "Returns all successive matches of the expression.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
-			types.N,
+			types.Named("pattern", types.S).Description("regular expression"),
+			types.Named("value", types.S).Description("string to match"),
+			types.Named("number", types.N).Description("number of matches to return; `-1` means all matches"),
		),
-		types.NewArray(nil, types.NewArray(nil, types.S)),
+		types.Named("output", types.NewArray(nil, types.NewArray(nil, types.S))),
	),
}

-// RegexTemplateMatch takes two strings and evaluates to true if the string in the second
-// position matches the pattern in the first position.
var RegexTemplateMatch = &Builtin{
-	Name: "regex.template_match",
+	Name:        "regex.template_match",
+	Description: "Matches a string against a pattern, where the pattern may be glob-like.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
-			types.S,
-			types.S,
+			types.Named("template", types.S).Description("template expression containing `0..n` regular expressions"),
+			types.Named("value", types.S).Description("string to match"),
+			types.Named("delimiter_start", types.S).Description("start delimiter of the regular expression in `template`"),
+			types.Named("delimiter_end", types.S).Description("end delimiter of the regular expression in `template`"),
		),
-		types.B,
+		types.Named("result", types.B),
	),
-}
+} // TODO(sr): example: `regex.template_match("urn:foo:{.*}", "urn:foo:bar:baz", "{", "}")` returns `true`.

-// RegexSplit splits the input string by the occurrences of the given pattern.
var RegexSplit = &Builtin{
-	Name: "regex.split",
+	Name:        "regex.split",
+	Description: "Splits the input string by the occurrences of the given pattern.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("pattern", types.S).Description("regular expression"),
+			types.Named("value", types.S).Description("string to match"),
		),
-		types.NewArray(nil, types.S),
+		types.Named("output", types.NewArray(nil, types.S)).Description("the parts obtained by splitting `value`"),
	),
}

// RegexFind takes two strings and a number, the pattern, the value and number of match values to
// return, -1 means all match values.
var RegexFind = &Builtin{
-	Name: "regex.find_n",
+	Name:        "regex.find_n",
+	Description: "Returns the specified number of matches when matching the input against the pattern.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
-			types.N,
+			types.Named("pattern", types.S).Description("regular expression"),
+			types.Named("value", types.S).Description("string to match"),
+			types.Named("number", types.N).Description("number of matches to return, if `-1`, returns all matches"),
		),
-		types.NewArray(nil, types.S),
+		types.Named("output", types.NewArray(nil, types.S)).Description("collected matches"),
	),
}

// GlobsMatch takes two strings regexp-style strings and evaluates to true if their
// intersection matches a non-empty set of non-empty strings.
// Examples:
-// - "a.a." and ".b.b" -> true.
-// - "[a-z]*" and [0-9]+" -> not true.
+//   - "a.a." and ".b.b" -> true.
+//   - "[a-z]*" and [0-9]+" -> not true.
var GlobsMatch = &Builtin{
	Name: "regex.globs_match",
+	Description: `Checks if the intersection of two glob-style regular expressions matches a non-empty set of non-empty strings.
+The set of regex symbols is limited for this builtin: only ` + "`.`, `*`, `+`, `[`, `-`, `]` and `\\` are treated as special symbols.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("glob1", types.S), + types.Named("glob2", types.S), ), - types.B, + types.Named("result", types.B), ), } /** * Strings */ +var stringsCat = category("strings") + +var AnyPrefixMatch = &Builtin{ + Name: "strings.any_prefix_match", + Description: "Returns true if any of the search strings begins with any of the base strings.", + Decl: types.NewFunction( + types.Args( + types.Named("search", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("search string(s)"), + types.Named("base", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("base string(s)"), + ), + types.Named("result", types.B).Description("result of the prefix check"), + ), + Categories: stringsCat, +} + +var AnySuffixMatch = &Builtin{ + Name: "strings.any_suffix_match", + Description: "Returns true if any of the search strings ends with any of the base strings.", + Decl: types.NewFunction( + types.Args( + types.Named("search", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("search string(s)"), + types.Named("base", types.NewAny( + types.S, + types.NewSet(types.S), + types.NewArray(nil, types.S), + )).Description("base string(s)"), + ), + types.Named("result", types.B).Description("result of the suffix check"), + ), + Categories: stringsCat, +} -// Concat joins an array of strings with an input string. var Concat = &Builtin{ - Name: "concat", + Name: "concat", + Description: "Joins a set or array of strings with a delimiter.", Decl: types.NewFunction( types.Args( - types.S, - types.NewAny( + types.Named("delimiter", types.S), + types.Named("collection", types.NewAny( types.NewSet(types.S), types.NewArray(nil, types.S), - ), + )).Description("strings to join"), ), - types.S, + types.Named("output", types.S), ), + Categories: stringsCat, } -// FormatInt returns the string representation of the number in the given base after converting it to an integer value. 
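`strings.any_prefix_match` and `strings.any_suffix_match` accept a plain string, a set, or an array on either side, which keeps policies free of explicit iteration. An illustrative query in the same style as the earlier sketch (the values are invented for the example):

rs, err := rego.New(rego.Query(
	`x := strings.any_prefix_match(["foo-bar", "baz"], {"foo-", "bar-"})`,
)).Eval(context.Background())
// x == true: "foo-bar" begins with "foo-".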
var FormatInt = &Builtin{ - Name: "format_int", + Name: "format_int", + Description: "Returns the string representation of the number in the given base after rounding it down to an integer value.", Decl: types.NewFunction( types.Args( - types.N, - types.N, + types.Named("number", types.N).Description("number to format"), + types.Named("base", types.N).Description("base of number representation to use"), ), - types.S, + types.Named("output", types.S).Description("formatted number"), ), + Categories: stringsCat, } -// IndexOf returns the index of a substring contained inside a string var IndexOf = &Builtin{ - Name: "indexof", + Name: "indexof", + Description: "Returns the index of a substring contained inside a string.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("haystack", types.S).Description("string to search in"), + types.Named("needle", types.S).Description("substring to look for"), + ), + types.Named("output", types.N).Description("index of first occurrence, `-1` if not found"), + ), + Categories: stringsCat, +} + +var IndexOfN = &Builtin{ + Name: "indexof_n", + Description: "Returns a list of all the indexes of a substring contained inside a string.", + Decl: types.NewFunction( + types.Args( + types.Named("haystack", types.S).Description("string to search in"), + types.Named("needle", types.S).Description("substring to look for"), ), - types.N, + types.Named("output", types.NewArray(nil, types.N)).Description("all indices at which `needle` occurs in `haystack`, may be empty"), ), + Categories: stringsCat, } -// Substring returns the portion of a string for a given start index and a length. -// If the length is less than zero, then substring returns the remainder of the string. var Substring = &Builtin{ - Name: "substring", + Name: "substring", + Description: "Returns the portion of a string for a given `offset` and a `length`. 
If `length < 0`, `output` is the remainder of the string.", Decl: types.NewFunction( types.Args( - types.S, - types.N, - types.N, + types.Named("value", types.S), + types.Named("offset", types.N).Description("offset, must be positive"), + types.Named("length", types.N).Description("length of the substring starting from `offset`"), ), - types.S, + types.Named("output", types.S).Description("substring of `value` from `offset`, of length `length`"), ), + Categories: stringsCat, } -// Contains returns true if the search string is included in the base string var Contains = &Builtin{ - Name: "contains", + Name: "contains", + Description: "Returns `true` if the search string is included in the base string", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("haystack", types.S).Description("string to search in"), + types.Named("needle", types.S).Description("substring to look for"), ), - types.B, + types.Named("result", types.B).Description("result of the containment check"), ), + Categories: stringsCat, } -// StartsWith returns true if the search string begins with the base string var StartsWith = &Builtin{ - Name: "startswith", + Name: "startswith", + Description: "Returns true if the search string begins with the base string.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("search", types.S).Description("search string"), + types.Named("base", types.S).Description("base string"), ), - types.B, + types.Named("result", types.B).Description("result of the prefix check"), ), + Categories: stringsCat, } -// EndsWith returns true if the search string begins with the base string var EndsWith = &Builtin{ - Name: "endswith", + Name: "endswith", + Description: "Returns true if the search string ends with the base string.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("search", types.S).Description("search string"), + types.Named("base", types.S).Description("base string"), ), - types.B, + types.Named("result", types.B).Description("result of the suffix check"), ), + Categories: stringsCat, } -// Lower returns the input string but with all characters in lower-case var Lower = &Builtin{ - Name: "lower", + Name: "lower", + Description: "Returns the input string but with all characters in lower-case.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S).Description("string that is converted to lower-case"), + ), + types.Named("y", types.S).Description("lower-case of x"), ), + Categories: stringsCat, } -// Upper returns the input string but with all characters in upper-case var Upper = &Builtin{ - Name: "upper", + Name: "upper", + Description: "Returns the input string but with all characters in upper-case.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S).Description("string that is converted to upper-case"), + ), + types.Named("y", types.S).Description("upper-case of x"), ), + Categories: stringsCat, } -// Split returns an array containing elements of the input string split on a delimiter. 
var Split = &Builtin{
-	Name: "split",
+	Name:        "split",
+	Description: "Split returns an array containing elements of the input string split on a delimiter.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("x", types.S).Description("string that is split"),
+			types.Named("delimiter", types.S).Description("delimiter used for splitting"),
		),
-		types.NewArray(nil, types.S),
+		types.Named("ys", types.NewArray(nil, types.S)).Description("the split parts"),
	),
+	Categories: stringsCat,
}

-// Replace returns the given string with all instances of the second argument replaced
-// by the third.
var Replace = &Builtin{
-	Name: "replace",
+	Name:        "replace",
+	Description: "Replace replaces all instances of a sub-string.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
-			types.S,
+			types.Named("x", types.S).Description("string being processed"),
+			types.Named("old", types.S).Description("substring to replace"),
+			types.Named("new", types.S).Description("string to replace `old` with"),
		),
-		types.S,
+		types.Named("y", types.S).Description("string with replaced substrings"),
	),
+	Categories: stringsCat,
}

-// ReplaceN replaces a string from a list of old, new string pairs.
-// Replacements are performed in the order they appear in the target string, without overlapping matches.
-// The old string comparisons are done in argument order.
var ReplaceN = &Builtin{
	Name: "strings.replace_n",
+	Description: `Replaces a string from a list of old, new string pairs.
+Replacements are performed in the order they appear in the target string, without overlapping matches.
+The old string comparisons are done in argument order.`,
	Decl: types.NewFunction(
		types.Args(
-			types.NewObject(
+			types.Named("patterns", types.NewObject(
				nil,
				types.NewDynamicProperty(
					types.S,
					types.S)),
-			types.S,
+			).Description("replacement pairs"),
+			types.Named("value", types.S).Description("string to replace substring matches in"),
		),
-		types.S,
+		types.Named("output", types.S),
+	),
+}
+
+var RegexReplace = &Builtin{
+	Name:        "regex.replace",
+	Description: `Finds and replaces the text using the regular expression pattern.`,
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("s", types.S).Description("string being processed"),
+			types.Named("pattern", types.S).Description("regex pattern to be applied"),
+			types.Named("value", types.S).Description("string to replace matched substrings with"),
+		),
+		types.Named("output", types.S),
	),
}

-// Trim returns the given string with all leading or trailing instances of the second
-// argument removed.
var Trim = &Builtin{
-	Name: "trim",
+	Name:        "trim",
+	Description: "Returns `value` with all leading or trailing instances of the `cutset` characters removed.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("value", types.S).Description("string to trim"),
+			types.Named("cutset", types.S).Description("string of characters that are cut off"),
		),
-		types.S,
+		types.Named("output", types.S).Description("string trimmed of `cutset` characters"),
	),
+	Categories: stringsCat,
}

-// TrimLeft returns the given string with all leading instances of second argument removed.
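The `regex.replace` declaration above is thin on semantics; the builtin replaces every non-overlapping match of `pattern` in `s` with `value`. A sketch in the same style as the earlier examples (the expected output assumes all matches are replaced):

rs, err := rego.New(rego.Query(
	`x := regex.replace("id-123-456", "[0-9]+", "N")`,
)).Eval(context.Background())
// x == "id-N-N": both digit runs match the pattern and are replaced.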
var TrimLeft = &Builtin{
-	Name: "trim_left",
+	Name:        "trim_left",
+	Description: "Returns `value` with all leading instances of the `cutset` characters removed.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("value", types.S).Description("string to trim"),
+			types.Named("cutset", types.S).Description("string of characters that are cut off on the left"),
		),
-		types.S,
+		types.Named("output", types.S).Description("string left-trimmed of `cutset` characters"),
	),
+	Categories: stringsCat,
}

-// TrimPrefix returns the given string without the second argument prefix string.
-// If the given string doesn't start with prefix, it is returned unchanged.
var TrimPrefix = &Builtin{
-	Name: "trim_prefix",
+	Name:        "trim_prefix",
+	Description: "Returns `value` without the prefix. If `value` doesn't start with `prefix`, it is returned unchanged.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("value", types.S).Description("string to trim"),
+			types.Named("prefix", types.S).Description("prefix to cut off"),
		),
-		types.S,
+		types.Named("output", types.S).Description("string with `prefix` cut off"),
	),
+	Categories: stringsCat,
}

-// TrimRight returns the given string with all trailing instances of second argument removed.
var TrimRight = &Builtin{
-	Name: "trim_right",
+	Name:        "trim_right",
+	Description: "Returns `value` with all trailing instances of the `cutset` characters removed.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("value", types.S).Description("string to trim"),
+			types.Named("cutset", types.S).Description("string of characters that are cut off on the right"),
		),
-		types.S,
+		types.Named("output", types.S).Description("string right-trimmed of `cutset` characters"),
	),
+	Categories: stringsCat,
}

-// TrimSuffix returns the given string without the second argument suffix string.
-// If the given string doesn't end with suffix, it is returned unchanged.
var TrimSuffix = &Builtin{
-	Name: "trim_suffix",
+	Name:        "trim_suffix",
+	Description: "Returns `value` without the suffix. If `value` doesn't end with `suffix`, it is returned unchanged.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.S,
+			types.Named("value", types.S).Description("string to trim"),
+			types.Named("suffix", types.S).Description("suffix to cut off"),
		),
-		types.S,
+		types.Named("output", types.S).Description("string with `suffix` cut off"),
	),
+	Categories: stringsCat,
}

-// TrimSpace return the given string with all leading and trailing white space removed.
var TrimSpace = &Builtin{
-	Name: "trim_space",
+	Name:        "trim_space",
+	Description: "Returns the given string with all leading and trailing white space removed.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
+			types.Named("value", types.S).Description("string to trim"),
		),
-		types.S,
+		types.Named("output", types.S).Description("string with leading and trailing white space cut off"),
	),
+	Categories: stringsCat,
}

-// Sprintf returns the given string, formatted.
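The trim family splits into cutset-based and exact-match builtins: `trim`, `trim_left`, and `trim_right` drop any characters contained in `cutset`, while `trim_prefix` and `trim_suffix` remove at most one exact occurrence. The difference shows up quickly (same sketch style as above):

rs, err := rego.New(rego.Query(
	`x := trim_left("abcabc", "abc"); y := trim_prefix("abcabc", "abc")`,
)).Eval(context.Background())
// x == ""    (every leading character is in the cutset)
// y == "abc" (exactly one copy of the prefix is removed)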
var Sprintf = &Builtin{
-	Name: "sprintf",
+	Name:        "sprintf",
+	Description: "Returns the given string, formatted.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.NewArray(nil, types.A),
+			types.Named("format", types.S).Description("string with formatting verbs"),
+			types.Named("values", types.NewArray(nil, types.A)).Description("arguments to format into formatting verbs"),
		),
-		types.S,
+		types.Named("output", types.S).Description("`format` formatted by the values in `values`"),
+	),
+	Categories: stringsCat,
+}
+
+var StringReverse = &Builtin{
+	Name:        "strings.reverse",
+	Description: "Reverses a given string.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("x", types.S),
+		),
+		types.Named("y", types.S),
	),
+	Categories: stringsCat,
}

/**
@@ -1085,26 +1318,30 @@ var Sprintf = &Builtin{
 */

// RandIntn returns a random number 0 - n
+// Marked non-deterministic because it relies on RNG internally.
var RandIntn = &Builtin{
-	Name: "rand.intn",
+	Name:        "rand.intn",
+	Description: "Returns a random integer between `0` and `n` (`n` exclusive). If `n` is `0`, then `y` is always `0`. For any given argument pair (`str`, `n`), the output will be consistent throughout a query evaluation.",
	Decl: types.NewFunction(
		types.Args(
-			types.S,
-			types.N,
+			types.Named("str", types.S),
+			types.Named("n", types.N),
		),
-		types.N,
+		types.Named("y", types.N).Description("random integer in the range `[0, abs(n))`"),
	),
+	Categories:       number,
+	Nondeterministic: true,
}

-// NumbersRange returns an array of numbers in the given inclusive range.
var NumbersRange = &Builtin{
-	Name: "numbers.range",
+	Name:        "numbers.range",
+	Description: "Returns an array of numbers in the given (inclusive) range. If `a==b`, then `range == [a]`; if `a > b`, then `range` is in descending order.",
	Decl: types.NewFunction(
		types.Args(
-			types.N,
-			types.N,
+			types.Named("a", types.N),
+			types.Named("b", types.N),
		),
-		types.NewArray(nil, types.N),
+		types.Named("range", types.NewArray(nil, types.N)).Description("the range between `a` and `b`"),
	),
}

@@ -1112,73 +1349,73 @@ var NumbersRange = &Builtin{
 * Units
 */

-// UnitsParseBytes converts strings like 10GB, 5K, 4mb, and the like into an
-// integer number of bytes.
-var UnitsParseBytes = &Builtin{
-	Name: "units.parse_bytes",
+var UnitsParse = &Builtin{
+	Name: "units.parse",
+	Description: `Converts strings like "10G", "5K", "4M", "1500m" and the like into a number.
+This number can be a non-integer, such as 1.5, 0.22, etc. Supports standard metric decimal and
+binary SI units (e.g., K, Ki, M, Mi, G, Gi, etc.) m, K, M, G, T, P, and E are treated as decimal
+units and Ki, Mi, Gi, Ti, Pi, and Ei are treated as binary units.
+
+Note that 'm' and 'M' are case-sensitive, to allow distinguishing between "milli" and "mega" units respectively. Other units are case-insensitive.`,
	Decl: types.NewFunction(
		types.Args(
-			types.S,
+			types.Named("x", types.S).Description("the unit to parse"),
		),
-		types.N,
+		types.Named("y", types.N).Description("the parsed number"),
	),
}

-//
-/**
- * Type
- */
-
-// UUIDRFC4122 returns a version 4 UUID string.
-var UUIDRFC4122 = &Builtin{
-	Name: "uuid.rfc4122",
+var UnitsParseBytes = &Builtin{
+	Name: "units.parse_bytes",
+	Description: `Converts strings like "10GB", "5K", "4mb" into an integer number of bytes.
+Supports standard byte units (e.g., KB, KiB, etc.) KB, MB, GB, and TB are treated as decimal
+units and KiB, MiB, GiB, and TiB are treated as binary units. The bytes symbol (b/B) in the
+unit is optional and omitting it will give the same result (e.g., Mi and MiB).`,
	Decl: types.NewFunction(
-		types.Args(types.S),
-		types.S,
+		types.Args(
+			types.Named("x", types.S).Description("the byte unit to parse"),
+		),
+		types.Named("y", types.N).Description("the parsed number"),
	),
}

+//
/**
- * JSON
+ * Type
 */

-// JSONMarshal serializes the input term.
-var JSONMarshal = &Builtin{
-	Name: "json.marshal",
+// UUIDRFC4122 returns a version 4 UUID string.
+// Marked non-deterministic because it relies on RNG internally.
+var UUIDRFC4122 = &Builtin{
+	Name:        "uuid.rfc4122",
+	Description: "Returns a new UUIDv4.",
	Decl: types.NewFunction(
-		types.Args(types.A),
-		types.S,
+		types.Args(
+			types.Named("k", types.S),
		),
+		types.Named("output", types.S).Description("a version 4 UUID; for any given `k`, the output will be consistent throughout a query evaluation"),
	),
+	Nondeterministic: true,
}

-// JSONUnmarshal deserializes the input string.
-var JSONUnmarshal = &Builtin{
-	Name: "json.unmarshal",
-	Decl: types.NewFunction(
-		types.Args(types.S),
-		types.A,
-	),
-}
+/**
+ * JSON
+ */

-// JSONIsValid verifies the input string is a valid JSON document.
-var JSONIsValid = &Builtin{
-	Name: "json.is_valid",
-	Decl: types.NewFunction(
-		types.Args(types.S),
-		types.B,
-	),
-}
+var objectCat = category("object")

-// JSONFilter filters the JSON object
var JSONFilter = &Builtin{
	Name: "json.filter",
+	Description: "Filters the object. " +
+		"For example: `json.filter({\"a\": {\"b\": \"x\", \"c\": \"y\"}}, [\"a/b\"])` will result in `{\"a\": {\"b\": \"x\"}}`. " +
+		"Paths are not filtered in-order and are deduplicated before being evaluated.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewObject(
+			types.Named("object", types.NewObject(
				nil,
				types.NewDynamicProperty(types.A, types.A),
-			),
-			types.NewAny(
+			)),
+			types.Named("paths", types.NewAny(
				types.NewArray(
					nil,
					types.NewAny(
@@ -1198,22 +1435,25 @@ var JSONFilter = &Builtin{
					),
				),
			),
-			),
+			)).Description("JSON string paths"),
		),
-		types.A,
+		types.Named("filtered", types.A).Description("remaining data from `object` with only keys specified in `paths`"),
	),
+	Categories: objectCat,
}

-// JSONRemove removes paths in the JSON object
var JSONRemove = &Builtin{
	Name: "json.remove",
+	Description: "Removes paths from an object. " +
+		"For example: `json.remove({\"a\": {\"b\": \"x\", \"c\": \"y\"}}, [\"a/b\"])` will result in `{\"a\": {\"c\": \"y\"}}`. " +
+		"Paths are not removed in-order and are deduplicated before being evaluated.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewObject(
+			types.Named("object", types.NewObject(
				nil,
				types.NewDynamicProperty(types.A, types.A),
-			),
-			types.NewAny(
+			)),
+			types.Named("paths", types.NewAny(
				types.NewArray(
					nil,
					types.NewAny(
@@ -1233,19 +1473,23 @@ var JSONRemove = &Builtin{
					),
				),
			),
-			),
+			)).Description("JSON string paths"),
		),
-		types.A,
+		types.Named("output", types.A).Description("result of removing all keys specified in `paths`"),
	),
+	Categories: objectCat,
}

-// JSONPatch patches a JSON object according to RFC6902
var JSONPatch = &Builtin{
	Name: "json.patch",
+	Description: "Patches an object according to RFC6902. " +
+		"For example: `json.patch({\"a\": {\"foo\": 1}}, [{\"op\": \"add\", \"path\": \"/a/bar\", \"value\": 2}])` results in `{\"a\": {\"foo\": 1, \"bar\": 2}}`. " +
+		"The patches are applied atomically: if any of them fails, the result will be undefined. " +
+		"Additionally works on sets, where a value contained in the set is considered to be its path.",
	Decl: types.NewFunction(
		types.Args(
-			types.A,
-			types.NewArray(
+			types.Named("object", types.A), // TODO(sr): types.A?
+			types.Named("patches", types.NewArray(
				nil,
				types.NewObject(
					[]*types.StaticProperty{
@@ -1254,547 +1498,745 @@ var JSONPatch = &Builtin{
					},
					types.NewDynamicProperty(types.A, types.A),
				),
-			),
+			)),
		),
-		types.A,
+		types.Named("output", types.A).Description("result obtained after consecutively applying all patch operations in `patches`"),
	),
+	Categories: objectCat,
}

-// ObjectGet returns takes an object and returns a value under its key if
-// present, otherwise it returns the default.
-var ObjectGet = &Builtin{
-	Name: "object.get",
+var ObjectSubset = &Builtin{
+	Name: "object.subset",
+	Description: "Determines if an object `sub` is a subset of another object `super`. " +
+		"Object `sub` is a subset of object `super` if and only if every key in `sub` is also in `super`, " +
+		"**and** for all keys which `sub` and `super` share, they have the same value. " +
+		"This function works with objects, sets, arrays, and with a `super` array combined with a `sub` set. " +
+		"If both arguments are objects, then the operation is recursive, e.g. " +
+		"`{\"c\": {\"x\": {10, 15, 20}}}` is a subset of `{\"a\": \"b\", \"c\": {\"x\": {10, 15, 20, 25}, \"y\": \"z\"}}`. " +
+		"If both arguments are sets, then this function checks if every element of `sub` is a member of `super`, " +
+		"but does not attempt to recurse. If both arguments are arrays, " +
+		"then this function checks if `sub` appears contiguously in order within `super`, " +
+		"and also does not attempt to recurse. If `super` is an array and `sub` is a set, " +
+		"then this function checks if `super` contains every element of `sub` with no consideration of ordering, " +
+		"and also does not attempt to recurse.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
-			types.A,
-			types.A,
+			types.Named("super", types.NewAny(types.NewObject(
+				nil,
+				types.NewDynamicProperty(types.A, types.A),
+			),
+				types.NewSet(types.A),
+				types.NewArray(nil, types.A),
+			)).Description("object to test if sub is a subset of"),
+			types.Named("sub", types.NewAny(types.NewObject(
+				nil,
+				types.NewDynamicProperty(types.A, types.A),
+			),
+				types.NewSet(types.A),
+				types.NewArray(nil, types.A),
+			)).Description("object to test if super is a superset of"),
		),
-		types.A,
+		types.Named("result", types.A).Description("`true` if `sub` is a subset of `super`"),
	),
}

-// ObjectUnion creates a new object that is the asymmetric union of two objects
var ObjectUnion = &Builtin{
	Name: "object.union",
+	Description: "Creates a new object of the asymmetric union of two objects. " +
+		"For example: `object.union({\"a\": 1, \"b\": 2, \"c\": {\"d\": 3}}, {\"a\": 7, \"c\": {\"d\": 4, \"e\": 5}})` will result in `{\"a\": 7, \"b\": 2, \"c\": {\"d\": 4, \"e\": 5}}`.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewObject(
+			types.Named("a", types.NewObject(
				nil,
				types.NewDynamicProperty(types.A, types.A),
-			),
-			types.NewObject(
+			)),
+			types.Named("b", types.NewObject(
				nil,
				types.NewDynamicProperty(types.A, types.A),
-			),
+			)),
		),
-		types.A,
+		types.Named("output", types.A).Description("a new object which is the result of an asymmetric recursive union of two objects where conflicts are resolved by choosing the key from the right-hand object `b`"),
+	), // TODO(sr): types.A? ^^^^^^^ (also below)
+}
+
+var ObjectUnionN = &Builtin{
+	Name: "object.union_n",
+	Description: "Creates a new object that is the asymmetric union of all objects merged from left to right. " +
+		"For example: `object.union_n([{\"a\": 1}, {\"b\": 2}, {\"a\": 3}])` will result in `{\"b\": 2, \"a\": 3}`.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("objects", types.NewArray(
+				nil,
+				types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
+			)),
+		),
+		types.Named("output", types.A).Description("asymmetric recursive union of all objects in `objects`, merged from left to right, where conflicts are resolved by choosing the key from the right-hand object"),
	),
}

-// ObjectRemove Removes specified keys from an object
var ObjectRemove = &Builtin{
-	Name: "object.remove",
+	Name:        "object.remove",
+	Description: "Removes specified keys from an object.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewObject(
+			types.Named("object", types.NewObject(
				nil,
				types.NewDynamicProperty(types.A, types.A),
-			),
-			types.NewAny(
+			)).Description("object to remove keys from"),
+			types.Named("keys", types.NewAny(
				types.NewArray(nil, types.A),
				types.NewSet(types.A),
				types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
-			),
+			)).Description("keys to remove from `object`"),
		),
-		types.A,
+		types.Named("output", types.A).Description("result of removing the specified `keys` from `object`"),
	),
}

-// ObjectFilter filters the object by keeping only specified keys
var ObjectFilter = &Builtin{
	Name: "object.filter",
+	Description: "Filters the object by keeping only specified keys. " +
+		"For example: `object.filter({\"a\": {\"b\": \"x\", \"c\": \"y\"}, \"d\": \"z\"}, [\"a\"])` will result in `{\"a\": {\"b\": \"x\", \"c\": \"y\"}}`.",
	Decl: types.NewFunction(
		types.Args(
-			types.NewObject(
+			types.Named("object", types.NewObject(
				nil,
				types.NewDynamicProperty(types.A, types.A),
-			),
-			types.NewAny(
+			)).Description("object to filter keys from"),
+			types.Named("keys", types.NewAny(
				types.NewArray(nil, types.A),
				types.NewSet(types.A),
				types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
-			),
+			)),
		),
-		types.A,
+		types.Named("filtered", types.A).Description("remaining data from `object` with only keys specified in `keys`"),
	),
}

-// Base64Encode serializes the input string into base64 encoding.
+var ObjectGet = &Builtin{
+	Name: "object.get",
+	Description: "Returns the value of an object's key if present, otherwise a default. " +
+		"If the supplied `key` is an `array`, then `object.get` will search through a nested object or array using each key in turn. " +
+		"For example: `object.get({\"a\": [{ \"b\": true }]}, [\"a\", 0, \"b\"], false)` results in `true`.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("object", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("object to get `key` from"),
+			types.Named("key", types.A).Description("key to lookup in `object`"),
+			types.Named("default", types.A).Description("default to use when lookup fails"),
+		),
+		types.Named("value", types.A).Description("`object[key]` if present, otherwise `default`"),
+	),
+}
+
+var ObjectKeys = &Builtin{
+	Name: "object.keys",
+	Description: "Returns a set of an object's keys. " +
+		"For example: `object.keys({\"a\": 1, \"b\": true, \"c\": \"d\"})` results in `{\"a\", \"b\", \"c\"}`.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("object", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("object to get keys from"),
+		),
+		types.Named("value", types.NewSet(types.A)).Description("set of `object`'s keys"),
+	),
+}
+
+/*
+ * Encoding
+ */
+var encoding = category("encoding")
+
+var JSONMarshal = &Builtin{
+	Name:        "json.marshal",
+	Description: "Serializes the input term to JSON.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("x", types.A).Description("the term to serialize"),
+		),
+		types.Named("y", types.S).Description("the JSON string representation of `x`"),
+	),
+	Categories: encoding,
+}
+
+var JSONUnmarshal = &Builtin{
+	Name:        "json.unmarshal",
+	Description: "Deserializes the input string.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("x", types.S).Description("a JSON string"),
+		),
+		types.Named("y", types.A).Description("the term deserialized from `x`"),
+	),
+	Categories: encoding,
+}
+
+var JSONIsValid = &Builtin{
+	Name:        "json.is_valid",
+	Description: "Verifies the input string is a valid JSON document.",
+	Decl: types.NewFunction(
+		types.Args(
+			types.Named("x", types.S).Description("a JSON string"),
+		),
+		types.Named("result", types.B).Description("`true` if `x` is valid JSON, `false` otherwise"),
+	),
+	Categories: encoding,
+}
+
var Base64Encode = &Builtin{
-	Name: "base64.encode",
+	Name:        "base64.encode",
+	Description: "Serializes the input string into base64 encoding.",
	Decl: types.NewFunction(
-		types.Args(types.S),
-		types.S,
+		types.Args(
+			types.Named("x", types.S),
		),
+		types.Named("y", types.S).Description("base64 serialization of `x`"),
	),
+	Categories: encoding,
}

-// Base64Decode deserializes the base64 encoded input string.
var Base64Decode = &Builtin{
-	Name: "base64.decode",
+	Name:        "base64.decode",
+	Description: "Deserializes the base64 encoded input string.",
	Decl: types.NewFunction(
-		types.Args(types.S),
-		types.S,
+		types.Args(
+			types.Named("x", types.S),
		),
+		types.Named("y", types.S).Description("base64 deserialization of `x`"),
	),
+	Categories: encoding,
}

-// Base64IsValid verifies the input string is base64 encoded.
var Base64IsValid = &Builtin{
-	Name: "base64.is_valid",
+	Name:        "base64.is_valid",
+	Description: "Verifies the input string is base64 encoded.",
	Decl: types.NewFunction(
-		types.Args(types.S),
-		types.B,
+		types.Args(
+			types.Named("x", types.S),
		),
+		types.Named("result", types.B).Description("`true` if `x` is valid base64 encoded value, `false` otherwise"),
	),
+	Categories: encoding,
}

-// Base64UrlEncode serializes the input string into base64url encoding.
var Base64UrlEncode = &Builtin{
-	Name: "base64url.encode",
+	Name:        "base64url.encode",
+	Description: "Serializes the input string into base64url encoding.",
	Decl: types.NewFunction(
-		types.Args(types.S),
-		types.S,
+		types.Args(
+			types.Named("x", types.S),
		),
+		types.Named("y", types.S).Description("base64url serialization of `x`"),
	),
+	Categories: encoding,
}

-// Base64UrlEncodeNoPad serializes the input string into base64url encoding without padding.
var Base64UrlEncodeNoPad = &Builtin{ - Name: "base64url.encode_no_pad", + Name: "base64url.encode_no_pad", + Description: "Serializes the input string into base64url encoding without padding.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("base64url serialization of `x`"), ), + Categories: encoding, } -// Base64UrlDecode deserializes the base64url encoded input string. var Base64UrlDecode = &Builtin{ - Name: "base64url.decode", + Name: "base64url.decode", + Description: "Deserializes the base64url encoded input string.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("base64url deserialization of `x`"), ), + Categories: encoding, } -// URLQueryDecode decodes a URL encoded input string. var URLQueryDecode = &Builtin{ - Name: "urlquery.decode", + Name: "urlquery.decode", + Description: "Decodes a URL-encoded input string.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("URL-encoding deserialization of `x`"), ), + Categories: encoding, } -// URLQueryEncode encodes the input string into a URL encoded string. var URLQueryEncode = &Builtin{ - Name: "urlquery.encode", + Name: "urlquery.encode", + Description: "Encodes the input string into a URL-encoded string.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("URL-encoding serialization of `x`"), ), + Categories: encoding, } -// URLQueryEncodeObject encodes the given JSON into a URL encoded query string. var URLQueryEncodeObject = &Builtin{ - Name: "urlquery.encode_object", + Name: "urlquery.encode_object", + Description: "Encodes the given object into a URL encoded query string.", Decl: types.NewFunction( types.Args( - types.NewObject( + types.Named("object", types.NewObject( nil, types.NewDynamicProperty( types.S, types.NewAny( types.S, types.NewArray(nil, types.S), - types.NewSet(types.S))))), - types.S, + types.NewSet(types.S)))))), + types.Named("y", types.S).Description("the URL-encoded serialization of `object`"), ), + Categories: encoding, } -// URLQueryDecodeObject decodes the given URL query string into an object. var URLQueryDecodeObject = &Builtin{ - Name: "urlquery.decode_object", + Name: "urlquery.decode_object", + Description: "Decodes the given URL query string into an object.", Decl: types.NewFunction( - types.Args(types.S), - types.NewObject(nil, types.NewDynamicProperty( + types.Args( + types.Named("x", types.S).Description("the query string"), + ), + types.Named("object", types.NewObject(nil, types.NewDynamicProperty( types.S, - types.NewArray(nil, types.S))), + types.NewArray(nil, types.S)))).Description("the resulting object"), ), + Categories: encoding, } -// YAMLMarshal serializes the input term. var YAMLMarshal = &Builtin{ - Name: "yaml.marshal", + Name: "yaml.marshal", + Description: "Serializes the input term to YAML.", Decl: types.NewFunction( - types.Args(types.A), - types.S, + types.Args( + types.Named("x", types.A).Description("the term to serialize"), + ), + types.Named("y", types.S).Description("the YAML string representation of `x`"), ), + Categories: encoding, } -// YAMLUnmarshal deserializes the input string. 
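To illustrate the URL-query builtins above, a sketch with arbitrarily chosen values (the expected results follow Go's `net/url` encoding rules, with keys sorted and array values repeated):

```rego
package example

# Encode an object into a query string.
qs := urlquery.encode_object({"q": "opa", "tags": ["rego", "policy"]})
# qs should be "q=opa&tags=rego&tags=policy"

# Decode it back into an object of string arrays.
decoded := urlquery.decode_object(qs)
# decoded should be {"q": ["opa"], "tags": ["rego", "policy"]}
```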
var YAMLUnmarshal = &Builtin{ - Name: "yaml.unmarshal", + Name: "yaml.unmarshal", + Description: "Deserializes the input string.", Decl: types.NewFunction( - types.Args(types.S), - types.A, + types.Args( + types.Named("x", types.S).Description("a YAML string"), + ), + types.Named("y", types.A).Description("the term deserialized from `x`"), ), + Categories: encoding, } // YAMLIsValid verifies the input string is a valid YAML document. var YAMLIsValid = &Builtin{ - Name: "yaml.is_valid", + Name: "yaml.is_valid", + Description: "Verifies the input string is a valid YAML document.", Decl: types.NewFunction( - types.Args(types.S), - types.B, + types.Args( + types.Named("x", types.S).Description("a YAML string"), + ), + types.Named("result", types.B).Description("`true` if `x` is valid YAML, `false` otherwise"), ), + Categories: encoding, } -// HexEncode serializes the input string into hex encoding. var HexEncode = &Builtin{ - Name: "hex.encode", + Name: "hex.encode", + Description: "Serializes the input string using hex-encoding.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("serialization of `x` using hex-encoding"), ), + Categories: encoding, } -// HexDecode deserializes the hex encoded input string. var HexDecode = &Builtin{ - Name: "hex.decode", + Name: "hex.decode", + Description: "Deserializes the hex-encoded input string.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S).Description("a hex-encoded string"), + ), + types.Named("y", types.S).Description("deserialized from `x`"), ), + Categories: encoding, } /** * Tokens */ +var tokensCat = category("tokens") -// JWTDecode decodes a JSON Web Token and outputs it as an Object. var JWTDecode = &Builtin{ - Name: "io.jwt.decode", + Name: "io.jwt.decode", + Description: "Decodes a JSON Web Token and outputs it as an object.", Decl: types.NewFunction( - types.Args(types.S), - types.NewArray([]types.Type{ + types.Args( + types.Named("jwt", types.S).Description("JWT token to decode"), + ), + types.Named("output", types.NewArray([]types.Type{ types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), types.S, - }, nil), + }, nil)).Description("`[header, payload, sig]`, where `header` and `payload` are objects; `sig` is the hexadecimal representation of the signature on the token."), ), + Categories: tokensCat, } -// JWTVerifyRS256 verifies if a RS256 JWT signature is valid or not. var JWTVerifyRS256 = &Builtin{ - Name: "io.jwt.verify_rs256", + Name: "io.jwt.verify_rs256", + Description: "Verifies if a RS256 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyRS384 verifies if a RS384 JWT signature is valid or not. 
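A hedged sketch of `io.jwt.decode` in a policy; `input.token` is an assumed input field, and note that this performs no signature check (the `io.jwt.verify_*` builtins below handle that):

```rego
package example

# Split a JWT into header, payload, and signature without verifying it.
claims := payload {
	[header, payload, _] := io.jwt.decode(input.token)
	header.alg == "RS256"
}

deny { claims.role != "admin" }
```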
var JWTVerifyRS384 = &Builtin{ - Name: "io.jwt.verify_rs384", + Name: "io.jwt.verify_rs384", + Description: "Verifies if a RS384 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyRS512 verifies if a RS512 JWT signature is valid or not. var JWTVerifyRS512 = &Builtin{ - Name: "io.jwt.verify_rs512", + Name: "io.jwt.verify_rs512", + Description: "Verifies if a RS512 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyPS256 verifies if a PS256 JWT signature is valid or not. var JWTVerifyPS256 = &Builtin{ - Name: "io.jwt.verify_ps256", + Name: "io.jwt.verify_ps256", + Description: "Verifies if a PS256 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyPS384 verifies if a PS384 JWT signature is valid or not. var JWTVerifyPS384 = &Builtin{ - Name: "io.jwt.verify_ps384", + Name: "io.jwt.verify_ps384", + Description: "Verifies if a PS384 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyPS512 verifies if a PS512 JWT signature is valid or not. var JWTVerifyPS512 = &Builtin{ - Name: "io.jwt.verify_ps512", + Name: "io.jwt.verify_ps512", + Description: "Verifies if a PS512 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyES256 verifies if a ES256 JWT signature is valid or not. 
var JWTVerifyES256 = &Builtin{ - Name: "io.jwt.verify_es256", + Name: "io.jwt.verify_es256", + Description: "Verifies if a ES256 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyES384 verifies if a ES384 JWT signature is valid or not. var JWTVerifyES384 = &Builtin{ - Name: "io.jwt.verify_es384", + Name: "io.jwt.verify_es384", + Description: "Verifies if a ES384 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyES512 verifies if a ES512 JWT signature is valid or not. var JWTVerifyES512 = &Builtin{ - Name: "io.jwt.verify_es512", + Name: "io.jwt.verify_es512", + Description: "Verifies if a ES512 JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyHS256 verifies if a HS256 (secret) JWT signature is valid or not. var JWTVerifyHS256 = &Builtin{ - Name: "io.jwt.verify_hs256", + Name: "io.jwt.verify_hs256", + Description: "Verifies if a HS256 (secret) JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("secret", types.S).Description("plain text secret used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyHS384 verifies if a HS384 (secret) JWT signature is valid or not. var JWTVerifyHS384 = &Builtin{ - Name: "io.jwt.verify_hs384", + Name: "io.jwt.verify_hs384", + Description: "Verifies if a HS384 (secret) JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("secret", types.S).Description("plain text secret used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTVerifyHS512 verifies if a HS512 (secret) JWT signature is valid or not. 
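For instance, a sketch combining HS256 verification with decoding; `input.token` and `input.secret` are assumed fields, and a real policy would source the secret from configuration rather than the input document:

```rego
package example

token_valid { io.jwt.verify_hs256(input.token, input.secret) }

# Only expose claims once the signature has been checked.
claims := c {
	token_valid
	[_, c, _] := io.jwt.decode(input.token)
}
```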
var JWTVerifyHS512 = &Builtin{ - Name: "io.jwt.verify_hs512", + Name: "io.jwt.verify_hs512", + Description: "Verifies if a HS512 (secret) JWT signature is valid.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"), + types.Named("secret", types.S).Description("plain text secret used to verify the signature"), ), - types.B, + types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"), ), + Categories: tokensCat, } -// JWTDecodeVerify verifies a JWT signature under parameterized constraints and decodes the claims if it is valid. +// Marked non-deterministic because it relies on time internally. var JWTDecodeVerify = &Builtin{ Name: "io.jwt.decode_verify", + Description: `Verifies a JWT signature under parameterized constraints and decodes the claims if it is valid. +Supports the following algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384 and PS512.`, Decl: types.NewFunction( types.Args( - types.S, - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), + types.Named("jwt", types.S).Description("JWT token whose signature is to be verified and whose claims are to be checked"), + types.Named("constraints", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("claim verification constraints"), ), - types.NewArray([]types.Type{ + types.Named("output", types.NewArray([]types.Type{ types.B, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), - }, nil), + }, nil)).Description("`[valid, header, payload]`: if the input token is verified and meets the requirements of `constraints` then `valid` is `true`; `header` and `payload` are objects containing the JOSE header and the JWT claim set; otherwise, `valid` is `false`, `header` and `payload` are `{}`"), ), + Categories: tokensCat, + Nondeterministic: true, } -// JWTEncodeSignRaw encodes and optionally sign a JSON Web Token. -// Inputs are protected headers, payload, secret +var tokenSign = category("tokensign") + +// Marked non-deterministic because it relies on RNG internally. var JWTEncodeSignRaw = &Builtin{ - Name: "io.jwt.encode_sign_raw", + Name: "io.jwt.encode_sign_raw", + Description: "Encodes and optionally signs a JSON Web Token.", Decl: types.NewFunction( types.Args( - types.S, - types.S, - types.S, + types.Named("headers", types.S).Description("JWS Protected Header"), + types.Named("payload", types.S).Description("JWS Payload"), + types.Named("key", types.S).Description("JSON Web Key (RFC7517)"), ), - types.S, + types.Named("output", types.S).Description("signed JWT"), ), + Categories: tokenSign, + Nondeterministic: true, } -// JWTEncodeSign encodes and optionally sign a JSON Web Token. -// Inputs are protected headers, payload, secret +// Marked non-deterministic because it relies on RNG internally. var JWTEncodeSign = &Builtin{ - Name: "io.jwt.encode_sign", + Name: "io.jwt.encode_sign", + Description: "Encodes and optionally signs a JSON Web Token. 
Inputs are taken as objects, not encoded strings (see `io.jwt.encode_sign_raw`).", Decl: types.NewFunction( types.Args( - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), + types.Named("headers", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWS Protected Header"), + types.Named("payload", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWS Payload"), + types.Named("key", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JSON Web Key (RFC7517)"), ), - types.S, + types.Named("output", types.S).Description("signed JWT"), ), + Categories: tokenSign, + Nondeterministic: true, } /** * Time */ -// NowNanos returns the current time since epoch in nanoseconds. +// Marked non-deterministic because it relies on time directly. var NowNanos = &Builtin{ - Name: "time.now_ns", + Name: "time.now_ns", + Description: "Returns the current time since epoch in nanoseconds.", Decl: types.NewFunction( nil, - types.N, + types.Named("now", types.N).Description("nanoseconds since epoch"), ), + Nondeterministic: true, } -// ParseNanos returns the time in nanoseconds parsed from the string in the given format. var ParseNanos = &Builtin{ - Name: "time.parse_ns", + Name: "time.parse_ns", + Description: "Returns the time in nanoseconds parsed from the string in the given format. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("layout", types.S).Description("format used for parsing, see the [Go `time` package documentation](https://golang.org/pkg/time/#Parse) for more details"), + types.Named("value", types.S).Description("input to parse according to `layout`"), ), - types.N, + types.Named("ns", types.N).Description("`value` in nanoseconds since epoch"), ), } -// ParseRFC3339Nanos returns the time in nanoseconds parsed from the string in RFC3339 format. var ParseRFC3339Nanos = &Builtin{ - Name: "time.parse_rfc3339_ns", + Name: "time.parse_rfc3339_ns", + Description: "Returns the time in nanoseconds parsed from the string in RFC3339 format. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", Decl: types.NewFunction( - types.Args(types.S), - types.N, + types.Args( + types.Named("value", types.S), + ), + types.Named("ns", types.N).Description("`value` in nanoseconds since epoch"), ), } -// ParseDurationNanos returns the duration in nanoseconds represented by a duration string. 
-// Duration string is similar to the Go time.ParseDuration string var ParseDurationNanos = &Builtin{ - Name: "time.parse_duration_ns", + Name: "time.parse_duration_ns", + Description: "Returns the duration in nanoseconds represented by a string.", Decl: types.NewFunction( - types.Args(types.S), - types.N, + types.Args( + types.Named("duration", types.S).Description("a duration like \"3m\"; see the [Go `time` package documentation](https://golang.org/pkg/time/#ParseDuration) for more details"), + ), + types.Named("ns", types.N).Description("the `duration` in nanoseconds"), + ), +} + +var Format = &Builtin{ + Name: "time.format", + Description: "Returns the formatted timestamp for the nanoseconds since epoch.", + Decl: types.NewFunction( + types.Args( + types.Named("x", types.NewAny( + types.N, + types.NewArray([]types.Type{types.N, types.S}, nil), + types.NewArray([]types.Type{types.N, types.S, types.S}, nil), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string; or a three-element array of ns, timezone string and a layout string (see golang supported time formats)"), + ), + types.Named("formatted timestamp", types.S).Description("the formatted timestamp represented for the nanoseconds since the epoch in the supplied timezone (or UTC)"), ), } -// Date returns the [year, month, day] for the nanoseconds since epoch. var Date = &Builtin{ - Name: "time.date", + Name: "time.date", + Description: "Returns the `[year, month, day]` for the nanoseconds since epoch.", Decl: types.NewFunction( types.Args( - types.NewAny( + types.Named("x", types.NewAny( types.N, types.NewArray([]types.Type{types.N, types.S}, nil), - ), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), ), - types.NewArray([]types.Type{types.N, types.N, types.N}, nil), + types.Named("date", types.NewArray([]types.Type{types.N, types.N, types.N}, nil)).Description("an array of `year`, `month` (1-12), and `day` (1-31)"), ), } -// Clock returns the [hour, minute, second] of the day for the nanoseconds since epoch. var Clock = &Builtin{ - Name: "time.clock", + Name: "time.clock", + Description: "Returns the `[hour, minute, second]` of the day for the nanoseconds since epoch.", Decl: types.NewFunction( types.Args( - types.NewAny( + types.Named("x", types.NewAny( types.N, types.NewArray([]types.Type{types.N, types.S}, nil), - ), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), ), - types.NewArray([]types.Type{types.N, types.N, types.N}, nil), + types.Named("output", types.NewArray([]types.Type{types.N, types.N, types.N}, nil)). + Description("the `hour`, `minute` (0-59), and `second` (0-59) representing the time of day for the nanoseconds since epoch in the supplied timezone (or UTC)"), ), } -// Weekday returns the day of the week (Monday, Tuesday, ...) for the nanoseconds since epoch. var Weekday = &Builtin{ - Name: "time.weekday", + Name: "time.weekday", + Description: "Returns the day of the week (Monday, Tuesday, ...) 
for the nanoseconds since epoch.", Decl: types.NewFunction( types.Args( - types.NewAny( + types.Named("x", types.NewAny( types.N, types.NewArray([]types.Type{types.N, types.S}, nil), - ), + )).Description("a number representing the nanoseconds since the epoch (UTC); or a two-element array of the nanoseconds, and a timezone string"), ), - types.S, + types.Named("day", types.S).Description("the weekday represented by `ns` nanoseconds since the epoch in the supplied timezone (or UTC)"), ), } -// AddDate returns the nanoseconds since epoch after adding years, months and days to nanoseconds. var AddDate = &Builtin{ - Name: "time.add_date", + Name: "time.add_date", + Description: "Returns the nanoseconds since epoch after adding years, months and days to nanoseconds. `undefined` if the result would be outside the valid time range that can fit within an `int64`.", Decl: types.NewFunction( types.Args( - types.N, - types.N, - types.N, - types.N, + types.Named("ns", types.N).Description("nanoseconds since the epoch"), + types.Named("years", types.N), + types.Named("months", types.N), + types.Named("days", types.N), ), - types.N, + types.Named("output", types.N).Description("nanoseconds since the epoch representing the input time, with years, months and days added"), ), } -// Diff returns the difference [years, months, days, hours, minutes, seconds] between two unix timestamps in nanoseconds var Diff = &Builtin{ - Name: "time.diff", + Name: "time.diff", + Description: "Returns the difference between two unix timestamps in nanoseconds (with optional timezone strings).", Decl: types.NewFunction( types.Args( - types.NewAny( + types.Named("ns1", types.NewAny( types.N, types.NewArray([]types.Type{types.N, types.S}, nil), - ), - types.NewAny( + )), + types.Named("ns2", types.NewAny( types.N, types.NewArray([]types.Type{types.N, types.S}, nil), - ), + )), ), - types.NewArray([]types.Type{types.N, types.N, types.N, types.N, types.N, types.N}, nil), + types.Named("output", types.NewArray([]types.Type{types.N, types.N, types.N, types.N, types.N, types.N}, nil)).Description("difference between `ns1` and `ns2` (in their supplied timezones, if supplied, or UTC) as array of numbers: `[years, months, days, hours, minutes, seconds]`"), ), } @@ -1802,161 +2244,185 @@ var Diff = &Builtin{ * Crypto. */ -// CryptoX509ParseCertificates returns one or more certificates from the given -// base64 encoded string containing DER encoded certificates that have been -// concatenated. var CryptoX509ParseCertificates = &Builtin{ Name: "crypto.x509.parse_certificates", + Description: `Returns zero or more certificates from the given encoded string containing +DER certificate data. + +If the input is empty, the function will return null. The input string should be a list of one or more +concatenated PEM blocks. 
The whole input of concatenated PEM blocks can optionally be Base64 encoded.`, Decl: types.NewFunction( - types.Args(types.S), - types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), + types.Args( + types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing one or more certificates or a PEM string of one or more certificates"), + ), + types.Named("output", types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)))).Description("parsed X.509 certificates represented as objects"), ), } -// CryptoX509ParseAndVerifyCertificates returns one or more certificates from the given -// string containing PEM or base64 encoded DER certificates after verifying the supplied -// certificates form a complete certificate chain back to a trusted root. -// -// The first certificate is treated as the root and the last is treated as the leaf, -// with all others being treated as intermediates var CryptoX509ParseAndVerifyCertificates = &Builtin{ Name: "crypto.x509.parse_and_verify_certificates", + Description: `Returns one or more certificates from the given string containing PEM +or base64 encoded DER certificates after verifying the supplied certificates form a complete +certificate chain back to a trusted root. + +The first certificate is treated as the root and the last is treated as the leaf, +with all others being treated as intermediates.`, Decl: types.NewFunction( - types.Args(types.S), - types.NewArray([]types.Type{ + types.Args( + types.Named("certs", types.S).Description("base64 encoded DER or PEM data containing two or more certificates where the first is a root CA, the last is a leaf certificate, and all others are intermediate CAs"), + ), + types.Named("output", types.NewArray([]types.Type{ types.B, types.NewArray(nil, types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), - }, nil), + }, nil)).Description("array of `[valid, certs]`: if the input certificate chain could be verified then `valid` is `true` and `certs` is an array of X.509 certificates represented as objects; if the input certificate chain could not be verified then `valid` is `false` and `certs` is `[]`"), ), } -// CryptoX509ParseCertificateRequest returns a PKCS #10 certificate signing -// request from the given PEM-encoded PKCS#10 certificate signing request. var CryptoX509ParseCertificateRequest = &Builtin{ - Name: "crypto.x509.parse_certificate_request", + Name: "crypto.x509.parse_certificate_request", + Description: "Returns a PKCS #10 certificate signing request from the given PEM-encoded PKCS#10 certificate signing request.", Decl: types.NewFunction( - types.Args(types.S), - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), + types.Args( + types.Named("csr", types.S).Description("base64 string containing either a PEM encoded or DER CSR or a string containing a PEM CSR"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("X.509 CSR represented as an object"), ), } -// CryptoX509ParseRSAPrivateKey returns a JWK for signing a JWT from the given -// PEM-encoded RSA private key. 
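A sketch of chain verification using the builtin above; `input.cert_bundle` is an assumed PEM bundle ordered root-first, leaf-last, and the `Subject.CommonName` path assumes the parsed objects mirror Go's `x509.Certificate` field names:

```rego
package example

# Verify the chain and, if it verifies, inspect the leaf certificate.
leaf_cn := cn {
	[valid, certs] := crypto.x509.parse_and_verify_certificates(input.cert_bundle)
	valid
	i := count(certs) - 1
	cn := certs[i].Subject.CommonName
}
```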
var CryptoX509ParseRSAPrivateKey = &Builtin{ - Name: "crypto.x509.parse_rsa_private_key", + Name: "crypto.x509.parse_rsa_private_key", + Description: "Returns a JWK for signing a JWT from the given PEM-encoded RSA private key.", Decl: types.NewFunction( - types.Args(types.S), - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), + types.Args( + types.Named("pem", types.S).Description("base64 string containing a PEM encoded RSA private key"), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))).Description("JWK as an object"), ), } -// CryptoMd5 returns a string representing the input string hashed with the md5 function var CryptoMd5 = &Builtin{ - Name: "crypto.md5", + Name: "crypto.md5", + Description: "Returns a string representing the input string hashed with the MD5 function.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("MD5-hash of `x`"), ), } -// CryptoSha1 returns a string representing the input string hashed with the sha1 function var CryptoSha1 = &Builtin{ - Name: "crypto.sha1", + Name: "crypto.sha1", + Description: "Returns a string representing the input string hashed with the SHA1 function.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("SHA1-hash of `x`"), ), } -// CryptoSha256 returns a string representing the input string hashed with the sha256 function var CryptoSha256 = &Builtin{ - Name: "crypto.sha256", + Name: "crypto.sha256", + Description: "Returns a string representing the input string hashed with the SHA256 function.", Decl: types.NewFunction( - types.Args(types.S), - types.S, + types.Args( + types.Named("x", types.S), + ), + types.Named("y", types.S).Description("SHA256-hash of `x`"), ), } -// CryptoHmacMd5 returns a string representing the MD-5 HMAC of the input message using the input key -// Inputs are message, key var CryptoHmacMd5 = &Builtin{ - Name: "crypto.hmac.md5", + Name: "crypto.hmac.md5", + Description: "Returns a string representing the MD5 HMAC of the input message using the input key.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), ), - types.S, + types.Named("y", types.S).Description("MD5-HMAC of `x`"), ), } -// CryptoHmacSha1 returns a string representing the SHA-1 HMAC of the input message using the input key -// Inputs are message, key var CryptoHmacSha1 = &Builtin{ - Name: "crypto.hmac.sha1", + Name: "crypto.hmac.sha1", + Description: "Returns a string representing the SHA1 HMAC of the input message using the input key.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), ), - types.S, + types.Named("y", types.S).Description("SHA1-HMAC of `x`"), ), } -// CryptoHmacSha256 returns a string representing the SHA-256 HMAC of the input message using the input key -// Inputs are message, key var CryptoHmacSha256 = &Builtin{ - Name: "crypto.hmac.sha256", + Name: "crypto.hmac.sha256", + Description: "Returns a string representing the SHA256 HMAC of the input message using the input key.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), 
), - types.S, + types.Named("y", types.S).Description("SHA256-HMAC of `x`"), ), } -// CryptoHmacSha512 returns a string representing the SHA-512 HMAC of the input message using the input key -// Inputs are message, key var CryptoHmacSha512 = &Builtin{ - Name: "crypto.hmac.sha512", + Name: "crypto.hmac.sha512", + Description: "Returns a string representing the SHA512 HMAC of the input message using the input key.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("x", types.S).Description("input string"), + types.Named("key", types.S).Description("key to use"), ), - types.S, + types.Named("y", types.S).Description("SHA512-HMAC of `x`"), + ), +} + +var CryptoHmacEqual = &Builtin{ + Name: "crypto.hmac.equal", + Description: "Returns a boolean representing the result of comparing two MACs for equality without leaking timing information.", + Decl: types.NewFunction( + types.Args( + types.Named("mac1", types.S).Description("mac1 to compare"), + types.Named("mac2", types.S).Description("mac2 to compare"), + ), + types.Named("result", types.B).Description("`true` if the MACs are equal, `false` otherwise"), ), } /** * Graphs. */ +var graphs = category("graph") -// WalkBuiltin generates [path, value] tuples for all nested documents -// (recursively). var WalkBuiltin = &Builtin{ - Name: "walk", - Relation: true, + Name: "walk", + Relation: true, + Description: "Generates `[path, value]` tuples for all nested documents of `x` (recursively). Queries can use `walk` to traverse documents nested under `x`.", Decl: types.NewFunction( - types.Args(types.A), - types.NewArray( + types.Args( + types.Named("x", types.A), + ), + types.Named("output", types.NewArray( []types.Type{ types.NewArray(nil, types.A), types.A, }, nil, - ), + )).Description("pairs of `path` and `value`: `path` is an array representing the pointer to `value` in `x`"), ), + Categories: graphs, } -// ReachableBuiltin computes the set of reachable nodes in the graph from a set -// of starting nodes. var ReachableBuiltin = &Builtin{ - Name: "graph.reachable", + Name: "graph.reachable", + Description: "Computes the set of reachable nodes in the graph from a set of starting nodes.", Decl: types.NewFunction( types.Args( - types.NewObject( + types.Named("graph", types.NewObject( nil, types.NewDynamicProperty( types.A, @@ -1964,109 +2430,120 @@ var ReachableBuiltin = &Builtin{ types.NewSet(types.A), types.NewArray(nil, types.A)), )), - types.NewAny(types.NewSet(types.A), types.NewArray(nil, types.A)), + ).Description("object containing a set or array of neighboring vertices"), + types.Named("initial", types.NewAny(types.NewSet(types.A), types.NewArray(nil, types.A))).Description("set or array of root vertices"), ), - types.NewSet(types.A), + types.Named("output", types.NewSet(types.A)).Description("set of vertices reachable from the `initial` vertices in the directed `graph`"), ), } -/** - * Sorting - */ - -// Sort returns a sorted array. 
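`walk` is a relation, so it enumerates every `[path, value]` pair in a document; a sketch that flags nested secrets (the key name `"password"` is an arbitrary choice):

```rego
package example

# Collect the path of every "password" key, however deeply nested.
password_paths[path] {
	walk(input, [path, _])
	i := count(path) - 1
	path[i] == "password"
}
```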
-var Sort = &Builtin{ - Name: "sort", +var ReachablePathsBuiltin = &Builtin{ + Name: "graph.reachable_paths", + Description: "Computes the set of reachable paths in the graph from a set of starting nodes.", Decl: types.NewFunction( types.Args( - types.NewAny( - types.NewArray(nil, types.A), - types.NewSet(types.A), - ), + types.Named("graph", types.NewObject( + nil, + types.NewDynamicProperty( + types.A, + types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A)), + )), + ).Description("object containing a set or array of neighboring vertices"), + types.Named("initial", types.NewAny(types.NewSet(types.A), types.NewArray(nil, types.A))).Description("initial paths"), // TODO(sr): copied. is that correct? ), - types.NewArray(nil, types.A), + types.Named("output", types.NewSet(types.NewArray(nil, types.A))).Description("paths reachable from the `initial` vertices in the directed `graph`"), ), } /** * Type */ +var typesCat = category("types") -// IsNumber returns true if the input value is a number var IsNumber = &Builtin{ - Name: "is_number", + Name: "is_number", + Description: "Returns `true` if the input value is a number.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("x", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `x` is a number, `false` otherwise."), ), + Categories: typesCat, } -// IsString returns true if the input value is a string. var IsString = &Builtin{ - Name: "is_string", + Name: "is_string", + Description: "Returns `true` if the input value is a string.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("x", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `x` is a string, `false` otherwise."), ), + Categories: typesCat, } -// IsBoolean returns true if the input value is a boolean. var IsBoolean = &Builtin{ - Name: "is_boolean", + Name: "is_boolean", + Description: "Returns `true` if the input value is a boolean.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("x", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `x` is a boolean, `false` otherwise."), ), + Categories: typesCat, } -// IsArray returns true if the input value is an array. var IsArray = &Builtin{ - Name: "is_array", + Name: "is_array", + Description: "Returns `true` if the input value is an array.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("x", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `x` is an array, `false` otherwise."), ), + Categories: typesCat, } -// IsSet returns true if the input value is a set. var IsSet = &Builtin{ - Name: "is_set", + Name: "is_set", + Description: "Returns `true` if the input value is a set.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("x", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `x` is a set, `false` otherwise."), ), + Categories: typesCat, } -// IsObject returns true if the input value is an object. var IsObject = &Builtin{ - Name: "is_object", + Name: "is_object", + Description: "Returns `true` if the input value is an object.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("x", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `x` is an object, `false` otherwise."), ), + Categories: typesCat, } -// IsNull returns true if the input value is null. 
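A sketch of the graph builtins defined above, using a tiny adjacency map (arrays are used for the neighbor lists so that empty collections are easy to write):

```rego
package example

deps := {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}

# Should evaluate to {"a", "b", "c", "d"} -- initial nodes are included.
reachable_from_a := graph.reachable(deps, {"a"})
```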
var IsNull = &Builtin{ - Name: "is_null", + Name: "is_null", + Description: "Returns `true` if the input value is null.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("x", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `x` is null, `false` otherwise."), ), + Categories: typesCat, } /** @@ -2075,193 +2552,378 @@ var IsNull = &Builtin{ // TypeNameBuiltin returns the type of the input. var TypeNameBuiltin = &Builtin{ - Name: "type_name", + Name: "type_name", + Description: "Returns the type of its input value.", Decl: types.NewFunction( types.Args( - types.NewAny( - types.A, - ), + types.Named("x", types.A), ), - types.S, + types.Named("type", types.S).Description(`one of "null", "boolean", "number", "string", "array", "object", "set"`), ), + Categories: typesCat, } /** * HTTP Request */ -// HTTPSend returns a HTTP response to the given HTTP request. +// Marked non-deterministic because HTTP request results can be non-deterministic. var HTTPSend = &Builtin{ - Name: "http.send", + Name: "http.send", + Description: "Returns an HTTP response to the given HTTP request.", Decl: types.NewFunction( types.Args( - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), + types.Named("request", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), ), - types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.Named("response", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))), ), + Nondeterministic: true, } /** - * Rego + * GraphQL */ -// RegoParseModule parses the input Rego file and returns a JSON representation -// of the AST. -var RegoParseModule = &Builtin{ - Name: "rego.parse_module", +// GraphQLParse returns a pair of AST objects from parsing/validation. +var GraphQLParse = &Builtin{ + Name: "graphql.parse", + Description: "Returns AST objects for a given GraphQL query and schema after validating the query against the schema. Returns undefined if errors were encountered during parsing or validation. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), ), - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), // TODO(tsandall): import AST schema + types.Named("output", types.NewArray([]types.Type{ + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + }, nil)).Description("`output` is of the form `[query_ast, schema_ast]`. If the GraphQL query is valid given the provided schema, then `query_ast` and `schema_ast` are objects describing the ASTs for the query and schema."), + ), +} + +// GraphQLParseAndVerify returns a boolean and a pair of AST objects from parsing/validation. +var GraphQLParseAndVerify = &Builtin{ + Name: "graphql.parse_and_verify", + Description: "Returns a boolean indicating success or failure alongside the parsed ASTs for a given GraphQL query and schema after validating the query against the schema. 
The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", + Decl: types.NewFunction( + types.Args( + types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + }, nil)).Description(" `output` is of the form `[valid, query_ast, schema_ast]`. If the query is valid given the provided schema, then `valid` is `true`, and `query_ast` and `schema_ast` are objects describing the ASTs for the GraphQL query and schema. Otherwise, `valid` is `false` and `query_ast` and `schema_ast` are `{}`."), + ), +} + +// GraphQLParseQuery parses the input GraphQL query and returns a JSON +// representation of its AST. +var GraphQLParseQuery = &Builtin{ + Name: "graphql.parse_query", + Description: "Returns an AST object for a GraphQL query.", + Decl: types.NewFunction( + types.Args( + types.Named("query", types.S), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("AST object for the GraphQL query."), + ), +} + +// GraphQLParseSchema parses the input GraphQL schema and returns a JSON +// representation of its AST. +var GraphQLParseSchema = &Builtin{ + Name: "graphql.parse_schema", + Description: "Returns an AST object for a GraphQL schema.", + Decl: types.NewFunction( + types.Args( + types.Named("schema", types.S), + ), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))).Description("AST object for the GraphQL schema."), + ), +} + +// GraphQLIsValid returns true if a GraphQL query is valid with a given +// schema, and returns false for all other inputs. +var GraphQLIsValid = &Builtin{ + Name: "graphql.is_valid", + Description: "Checks that a GraphQL query is valid against a given schema. The query and/or schema can be either GraphQL strings or AST objects from the other GraphQL builtin functions.", + Decl: types.NewFunction( + types.Args( + types.Named("query", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), + ), + types.Named("output", types.B).Description("`true` if the query is valid under the given schema. `false` otherwise."), + ), +} + +// GraphQLSchemaIsValid returns true if the input is valid GraphQL schema, +// and returns false for all other inputs. +var GraphQLSchemaIsValid = &Builtin{ + Name: "graphql.schema_is_valid", + Description: "Checks that the input is a valid GraphQL schema. The schema can be either a GraphQL string or an AST object from the other GraphQL builtin functions.", + Decl: types.NewFunction( + types.Args( + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))), + ), + types.Named("output", types.B).Description("`true` if the schema is a valid GraphQL schema. `false` otherwise."), ), } /** - * OPA + * JSON Schema */ -// OPARuntime returns an object containing OPA runtime information such as the -// configuration that OPA was booted with. 
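By way of illustration for the GraphQL builtins above, a sketch in which the schema and query strings are made up:

```rego
package example

employee_schema := `
type Employee { name: String! }
type Query { employee(id: String!): Employee }
`

query_valid {
	graphql.is_valid(`{ employee(id: "1") { name } }`, employee_schema)
}
```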
-var OPARuntime = &Builtin{ - Name: "opa.runtime", +// JSONSchemaVerify returns [true, null] if the input is a valid JSON schema, +// and [false, error string] for all other inputs. +var JSONSchemaVerify = &Builtin{ + Name: "json.verify_schema", + Description: "Checks that the input is a valid JSON schema object. The schema can be either a JSON string or a JSON object.", Decl: types.NewFunction( - nil, - types.NewObject(nil, types.NewDynamicProperty(types.S, types.A)), + types.Args( + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("the schema to verify"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewAny(types.S, types.Null{}), + }, nil)). + Description("`output` is of the form `[valid, error]`. If the schema is valid, then `valid` is `true`, and `error` is `null`. Otherwise, `valid` is `false` and `error` is a string describing the error."), ), + Categories: objectCat, +} + +// JSONMatchSchema returns [true, []] if the document matches the JSON schema, +// and [false, errors] with a non-empty errors array otherwise. +var JSONMatchSchema = &Builtin{ + Name: "json.match_schema", + Description: "Checks that the document matches the JSON schema.", + Decl: types.NewFunction( + types.Args( + types.Named("document", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("document to verify against the schema"), + types.Named("schema", types.NewAny(types.S, types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)))). + Description("schema to verify the document against"), + ), + types.Named("output", types.NewArray([]types.Type{ + types.B, + types.NewArray( + nil, types.NewObject( + []*types.StaticProperty{ + {Key: "error", Value: types.S}, + {Key: "type", Value: types.S}, + {Key: "field", Value: types.S}, + {Key: "desc", Value: types.S}, + }, + nil, + ), + ), + }, nil)). + Description("`output` is of the form `[match, errors]`. If the document is valid given the schema, then `match` is `true`, and `errors` is an empty array. Otherwise, `match` is `false` and `errors` is an array of objects describing the error(s)."), + ), + Categories: objectCat, } /** - * Trace + * Cloud Provider Helper Functions */ +var providersAWSCat = category("providers.aws") -// Trace prints a note that is included in the query explanation. -var Trace = &Builtin{ - Name: "trace", +var ProvidersAWSSignReqObj = &Builtin{ + Name: "providers.aws.sign_req", + Description: "Signs an HTTP request object for Amazon Web Services. 
Currently implements [AWS Signature Version 4 request signing](https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html) by the `Authorization` header method.", Decl: types.NewFunction( types.Args( - types.S, + types.Named("request", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), + types.Named("aws_config", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), + types.Named("time_ns", types.N), ), - types.B, + types.Named("signed_request", types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))), ), + Categories: providersAWSCat, } /** - * Set + * Rego */ -// Intersection returns the intersection of the given input sets -var Intersection = &Builtin{ - Name: "intersection", +var RegoParseModule = &Builtin{ + Name: "rego.parse_module", + Description: "Parses the input Rego string and returns an object representation of the AST.", Decl: types.NewFunction( types.Args( - types.NewSet(types.NewSet(types.A)), + types.Named("filename", types.S).Description("file name to attach to AST nodes' locations"), + types.Named("rego", types.S).Description("Rego module"), ), - types.NewSet(types.A), + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))), // TODO(tsandall): import AST schema ), } -// Union returns the union of the given input sets -var Union = &Builtin{ - Name: "union", +var RegoMetadataChain = &Builtin{ + Name: "rego.metadata.chain", + Description: `Returns the chain of metadata for the active rule. +Ordered starting at the active rule, going outward to the most distant node in its package ancestry. +A chain entry is a JSON document with two members: "path", an array representing the path of the node; and "annotations", a JSON document containing the annotations declared for the node. +The first entry in the chain always points to the active rule, even if it has no declared annotations (in which case the "annotations" member is not present).`, + Decl: types.NewFunction( + types.Args(), + types.Named("chain", types.NewArray(nil, types.A)).Description("each array entry represents a node in the path ancestry (chain) of the active rule that also has declared annotations"), + ), +} + +// RegoMetadataRule returns the metadata for the active rule +var RegoMetadataRule = &Builtin{ + Name: "rego.metadata.rule", + Description: "Returns annotations declared for the active rule and using the _rule_ scope.", + Decl: types.NewFunction( + types.Args(), + types.Named("output", types.A).Description("\"rule\" scope annotations for this rule; empty object if no annotations exist"), + ), +} + +/** + * OPA + */ + +// Marked non-deterministic because of unpredictable config/environment-dependent results. +var OPARuntime = &Builtin{ + Name: "opa.runtime", + Description: "Returns an object that describes the runtime environment where OPA is deployed.", + Decl: types.NewFunction( + nil, + types.Named("output", types.NewObject(nil, types.NewDynamicProperty(types.S, types.A))). + Description("includes a `config` key if OPA was started with a configuration file; an `env` key containing the environment variables that the OPA process was started with; includes `version` and `commit` keys containing the version and build commit of OPA."), + ), + Nondeterministic: true, +} + +/** + * Trace + */ +var tracing = category("tracing") + +var Trace = &Builtin{ + Name: "trace", + Description: "Emits `note` as a `Note` event in the query explanation. 
Query explanations show the exact expressions evaluated by OPA during policy execution. For example, `trace(\"Hello There!\")` includes `Note \"Hello There!\"` in the query explanation. To include variables in the message, use `sprintf`. For example, `person := \"Bob\"; trace(sprintf(\"Hello There! %v\", [person]))` will emit `Note \"Hello There! Bob\"` inside the explanation.", Decl: types.NewFunction( types.Args( - types.S, + types.Named("note", types.S).Description("the note to include"), ), - types.B, + types.Named("result", types.B).Description("always `true`"), ), + Categories: tracing, } /** * Glob */ -// GlobMatch - not to be confused with regex.globs_match - parses and matches strings against the glob notation. var GlobMatch = &Builtin{ - Name: "glob.match", + Name: "glob.match", + Description: "Parses and matches strings against the glob notation. Not to be confused with `regex.globs_match`.", Decl: types.NewFunction( types.Args( - types.S, - types.NewArray(nil, types.S), - types.S, + types.Named("pattern", types.S), + types.Named("delimiters", types.NewAny( + types.NewArray(nil, types.S), + types.NewNull(), + )).Description("glob pattern delimiters, e.g. `[\".\", \":\"]`, defaults to `[\".\"]` if unset. If `delimiters` is `null`, the glob is matched without delimiters."), + types.Named("match", types.S), ), - types.B, + types.Named("result", types.B).Description("`true` if `match` can be found in `pattern`, which is separated by `delimiters`"), ), } -// GlobQuoteMeta returns a string which represents a version of the pattern where all asterisks have been escaped. var GlobQuoteMeta = &Builtin{ - Name: "glob.quote_meta", + Name: "glob.quote_meta", + Description: "Returns a string which represents a version of the pattern where all asterisks have been escaped.", Decl: types.NewFunction( types.Args( - types.S, + types.Named("pattern", types.S), ), - types.S, + types.Named("output", types.S).Description("the escaped string of `pattern`"), ), + // TODO(sr): example for this was: Calling ``glob.quote_meta("*.github.com", output)`` returns ``\\*.github.com`` as ``output``. } /** * Networking */ -// NetCIDRIntersects checks if a cidr intersects with another cidr and returns true or false var NetCIDRIntersects = &Builtin{ - Name: "net.cidr_intersects", + Name: "net.cidr_intersects", + Description: "Checks if a CIDR intersects with another CIDR (e.g. `192.168.0.0/16` overlaps with `192.168.1.0/24`). Supports both IPv4 and IPv6 notations.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("cidr1", types.S), + types.Named("cidr2", types.S), ), - types.B, + types.Named("result", types.B), ), } -// NetCIDRExpand returns a set of hosts inside the specified cidr. var NetCIDRExpand = &Builtin{ - Name: "net.cidr_expand", + Name: "net.cidr_expand", + Description: "Expands CIDR to set of hosts (e.g., `net.cidr_expand(\"192.168.0.0/30\")` generates 4 hosts: `{\"192.168.0.0\", \"192.168.0.1\", \"192.168.0.2\", \"192.168.0.3\"}`).", Decl: types.NewFunction( types.Args( - types.S, + types.Named("cidr", types.S), ), - types.NewSet(types.S), + types.Named("hosts", types.NewSet(types.S)).Description("set of IP addresses the CIDR `cidr` expands to"), ), } -// NetCIDRContains checks if a cidr or ip is contained within another cidr and returns true or false var NetCIDRContains = &Builtin{ - Name: "net.cidr_contains", + Name: "net.cidr_contains", + Description: "Checks if a CIDR or IP is contained within another CIDR. 
`output` is `true` if `cidr_or_ip` (e.g. `127.0.0.64/26` or `127.0.0.1`) is contained within `cidr` (e.g. `127.0.0.1/24`) and `false` otherwise. Supports both IPv4 and IPv6 notations.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("cidr", types.S), + types.Named("cidr_or_ip", types.S), ), - types.B, + types.Named("result", types.B), ), } -// NetCIDRContainsMatches checks if collections of cidrs or ips are contained within another collection of cidrs and returns matches. var NetCIDRContainsMatches = &Builtin{ Name: "net.cidr_contains_matches", + Description: "Checks if collections of cidrs or ips are contained within another collection of cidrs and returns matches. " + + "This function is similar to `net.cidr_contains` except it allows callers to pass collections of CIDRs or IPs as arguments and returns the matches (as opposed to a boolean result indicating a match between two CIDRs/IPs).", Decl: types.NewFunction( - types.Args(netCidrContainsMatchesOperandType, netCidrContainsMatchesOperandType), - types.NewSet(types.NewArray([]types.Type{types.A, types.A}, nil)), + types.Args( + types.Named("cidrs", netCidrContainsMatchesOperandType), + types.Named("cidrs_or_ips", netCidrContainsMatchesOperandType), + ), + types.Named("output", types.NewSet(types.NewArray([]types.Type{types.A, types.A}, nil))).Description("tuples identifying matches where `cidrs_or_ips` are contained within `cidrs`"), ), } -// NetCIDRMerge merges IP addresses and subnets into the smallest possible list of CIDRs. var NetCIDRMerge = &Builtin{ Name: "net.cidr_merge", + Description: "Merges IP addresses and subnets into the smallest possible list of CIDRs (e.g., `net.cidr_merge([\"192.0.128.0/24\", \"192.0.129.0/24\"])` generates `{\"192.0.128.0/23\"}`)." + + `This function merges adjacent subnets where possible, those contained within others, and also removes any duplicates. +Supports both IPv4 and IPv6 notations. IPv6 inputs need a prefix length (e.g. "/128").`, Decl: types.NewFunction( - types.Args(netCidrMergeOperandType), - types.NewSet(types.S), + types.Args( + types.Named("addrs", types.NewAny( + types.NewArray(nil, types.NewAny(types.S)), + types.NewSet(types.S), + )).Description("CIDRs or IP addresses"), + ), + types.Named("output", types.NewSet(types.S)).Description("smallest possible set of CIDRs obtained after merging the provided list of IP addresses and subnets in `addrs`"), ), } -var netCidrMergeOperandType = types.NewAny( - types.NewArray(nil, types.NewAny(types.S)), - types.NewSet(types.S), -) +var NetCIDRIsValid = &Builtin{ + Name: "net.cidr_is_valid", + Description: "Parses an IPv4/IPv6 CIDR and returns a boolean indicating if the provided CIDR is valid.", + Decl: types.NewFunction( + types.Args( + types.Named("cidr", types.S), + ), + types.Named("result", types.B), + ), +} var netCidrContainsMatchesOperandType = types.NewAny( types.S, @@ -2282,44 +2944,43 @@ var netCidrContainsMatchesOperandType = types.NewAny( )), ) -// NetLookupIPAddr returns the set of IP addresses (as strings, both v4 and v6) -// that the passed-in name (string) resolves to using the standard name resolution -// mechanisms available. +// Marked non-deterministic because DNS resolution results can be non-deterministic. 
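A sketch combining the CIDR builtins above; the `10.0.0.0/8` range and `input.client_cidr` field are arbitrary stand-ins for an internal network check:

```rego
package example

internal_client {
	net.cidr_is_valid(input.client_cidr)
	net.cidr_contains("10.0.0.0/8", input.client_cidr)
}

# Adjacent subnets collapse: should yield {"192.0.128.0/23"}.
merged := net.cidr_merge(["192.0.128.0/24", "192.0.129.0/24"])
```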
var NetLookupIPAddr = &Builtin{ - Name: "net.lookup_ip_addr", + Name: "net.lookup_ip_addr", + Description: "Returns the set of IP addresses (both v4 and v6) that the passed-in `name` resolves to using the standard name resolution mechanisms available.", Decl: types.NewFunction( - types.Args(types.S), - types.NewSet(types.S), + types.Args( + types.Named("name", types.S).Description("domain name to resolve"), + ), + types.Named("addrs", types.NewSet(types.S)).Description("IP addresses (v4 and v6) that `name` resolves to"), ), + Nondeterministic: true, } /** * Semantic Versions */ -// SemVerIsValid validiates a the term is a valid SemVer as a string, returns -// false for all other input var SemVerIsValid = &Builtin{ - Name: "semver.is_valid", + Name: "semver.is_valid", + Description: "Validates that the input is a valid SemVer string.", Decl: types.NewFunction( types.Args( - types.A, + types.Named("vsn", types.A), ), - types.B, + types.Named("result", types.B).Description("`true` if `vsn` is a valid SemVer; `false` otherwise"), ), } -// SemVerCompare compares valid SemVer formatted version strings. Given two -// version strings, if A < B returns -1, if A > B returns 1. If A == B, returns -// 0 var SemVerCompare = &Builtin{ - Name: "semver.compare", + Name: "semver.compare", + Description: "Compares valid SemVer formatted version strings.", Decl: types.NewFunction( types.Args( - types.S, - types.S, + types.Named("a", types.S), + types.Named("b", types.S), ), - types.N, + types.Named("result", types.N).Description("`-1` if `a < b`; `1` if `a > b`; `0` if `a == b`"), ), } @@ -2358,6 +3019,7 @@ var SetDiff = &Builtin{ ), types.NewSet(types.A), ), + deprecated: true, } // NetCIDROverlap has been replaced by the `net.cidr_contains` built-in. @@ -2370,6 +3032,7 @@ var NetCIDROverlap = &Builtin{ ), types.B, ), + deprecated: true, } // CastArray checks the underlying type of the input. If it is array or set, an array @@ -2380,6 +3043,7 @@ var CastArray = &Builtin{ types.Args(types.A), types.NewArray(nil, types.A), ), + deprecated: true, } // CastSet checks the underlying type of the input. @@ -2392,6 +3056,7 @@ var CastSet = &Builtin{ types.Args(types.A), types.NewSet(types.A), ), + deprecated: true, } // CastString returns input if it is a string; if not returns error. @@ -2402,6 +3067,7 @@ var CastString = &Builtin{ types.Args(types.A), types.S, ), + deprecated: true, } // CastBoolean returns input if it is a boolean; if not returns error. @@ -2411,6 +3077,7 @@ var CastBoolean = &Builtin{ types.Args(types.A), types.B, ), + deprecated: true, } // CastNull returns null if input is null; if not returns error. @@ -2420,6 +3087,7 @@ var CastNull = &Builtin{ types.Args(types.A), types.NewNull(), ), + deprecated: true, } // CastObject returns the given object if it is null; throws an error otherwise @@ -2429,6 +3097,7 @@ var CastObject = &Builtin{ types.Args(types.A), types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), ), + deprecated: true, } // RegexMatchDeprecated declares `re_match` which has been deprecated. Use `regex.match` instead. @@ -2441,15 +3110,72 @@ var RegexMatchDeprecated = &Builtin{ ), types.B, ), + deprecated: true, +} + +// All takes a list and returns true if all of the items +// are true. A collection of length 0 returns true. 
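+// For example (illustrative; this follows directly from the semantics above): +// all([true, true]) is true, all([true, false]) is false, and all([]) is true.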
+var All = &Builtin{ + Name: "all", + Decl: types.NewFunction( + types.Args( + types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A), + ), + ), + types.B, + ), + deprecated: true, +} + +// Any takes a collection and returns true if any of the items +// is true. A collection of length 0 returns false. +var Any = &Builtin{ + Name: "any", + Decl: types.NewFunction( + types.Args( + types.NewAny( + types.NewSet(types.A), + types.NewArray(nil, types.A), + ), + ), + types.B, + ), + deprecated: true, } // Builtin represents a built-in function supported by OPA. Every built-in // function is uniquely identified by a name. type Builtin struct { - Name string `json:"name"` // Unique name of built-in function, e.g., (arg1,arg2,...,argN) - Decl *types.Function `json:"decl"` // Built-in function type declaration. - Infix string `json:"infix,omitempty"` // Unique name of infix operator. Default should be unset. - Relation bool `json:"relation,omitempty"` // Indicates if the built-in acts as a relation. + Name string `json:"name"` // Unique name of built-in function, e.g., (arg1,arg2,...,argN) + Description string `json:"description,omitempty"` // Description of what the built-in function does. + + // Categories of the built-in function. Omitted for namespaced + // built-ins, i.e. "array.concat" is taken to be of the "array" category. + // "minus" for example, is part of two categories: numbers and sets. (NOTE(sr): aspirational) + Categories []string `json:"categories,omitempty"` + + Decl *types.Function `json:"decl"` // Built-in function type declaration. + Infix string `json:"infix,omitempty"` // Unique name of infix operator. Default should be unset. + Relation bool `json:"relation,omitempty"` // Indicates if the built-in acts as a relation. + deprecated bool // Indicates if the built-in has been deprecated. + Nondeterministic bool `json:"nondeterministic,omitempty"` // Indicates if the built-in returns non-deterministic results. +} + +// category is a helper for specifying a Builtin's Categories +func category(cs ...string) []string { + return cs +} + +// IsDeprecated returns true if the Builtin function is deprecated and will be removed in a future release. +func (b *Builtin) IsDeprecated() bool { + return b.deprecated +} + +// IsNondeterministic returns true if the Builtin function may return non-deterministic results. +func (b *Builtin) IsNondeterministic() bool { + return b.Nondeterministic } // Expr creates a new expression for the built-in with the given operands. diff --git a/ast/capabilities.go b/ast/capabilities.go index a6b162edf7..217e99c999 100644 --- a/ast/capabilities.go +++ b/ast/capabilities.go @@ -5,20 +5,36 @@ package ast import ( + "bytes" + "fmt" "io" + "os" "sort" + "strings" + caps "github.com/open-policy-agent/opa/capabilities" "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/capabilities" "github.com/open-policy-agent/opa/util" ) -// Capabilities defines a structure containing data that describes the capablilities +// In the compiler, we use this to check that we're OK working with ref heads. +// If this isn't present, we'll fail. This is to ensure that older versions of +// OPA can work with policies that we're compiling -- if they don't know ref +// heads, they wouldn't be able to parse them. +const FeatureRefHeadStringPrefixes = "rule_head_ref_string_prefixes" + +// Capabilities defines a structure containing data that describes the capabilities +// or features supported by a particular version of OPA. 
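+// +// Serialized capabilities are plain JSON. A trimmed, illustrative sketch (the values shown are examples, not a complete document): +// +//	{ +//	  "builtins": [{"name": "plus", "infix": "+", "decl": {...}}], +//	  "future_keywords": ["in"], +//	  "features": ["rule_head_ref_string_prefixes"] +//	}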
type Capabilities struct { Builtins []*Builtin `json:"builtins"` FutureKeywords []string `json:"future_keywords"` WasmABIVersions []WasmABIVersion `json:"wasm_abi_versions"` + // Features is a list of the features this OPA version supports. It is + // used to check that an older version of OPA can handle everything the + // policies we compile rely on. + Features []string `json:"features"` + // allow_net is an array of hostnames or IP addresses, that an OPA instance is // allowed to connect to. // If omitted, ANY host can be connected to. If empty, NO host can be connected to. @@ -44,7 +60,8 @@ func CapabilitiesForThisVersion() *Capabilities { f.WasmABIVersions = append(f.WasmABIVersions, WasmABIVersion{Version: vers[0], Minor: vers[1]}) } - f.Builtins = append(f.Builtins, Builtins...) + f.Builtins = make([]*Builtin, len(Builtins)) + copy(f.Builtins, Builtins) sort.Slice(f.Builtins, func(i, j int) bool { return f.Builtins[i].Name < f.Builtins[j].Name }) @@ -54,6 +71,10 @@ func CapabilitiesForThisVersion() *Capabilities { } sort.Strings(f.FutureKeywords) + f.Features = []string{ + FeatureRefHeadStringPrefixes, + } + return f } @@ -63,3 +84,48 @@ func LoadCapabilitiesJSON(r io.Reader) (*Capabilities, error) { var c Capabilities return &c, d.Decode(&c) } + +// LoadCapabilitiesVersion loads the JSON serialized capabilities structure for the given version. +func LoadCapabilitiesVersion(version string) (*Capabilities, error) { + cvs, err := LoadCapabilitiesVersions() + if err != nil { + return nil, err + } + + for _, cv := range cvs { + if cv == version { + cont, err := caps.FS.ReadFile(cv + ".json") + if err != nil { + return nil, err + } + + return LoadCapabilitiesJSON(bytes.NewReader(cont)) + } + + } + return nil, fmt.Errorf("no capabilities version found: %v", version) +} + +// LoadCapabilitiesFile loads a JSON serialized capabilities structure from a file. 
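+// +// A minimal caller-side sketch (the file name is hypothetical): +// +//	caps, err := ast.LoadCapabilitiesFile("capabilities.json") +//	if err != nil { +//		// handle error +//	} +//	c := ast.NewCompiler().WithCapabilities(caps)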
+func LoadCapabilitiesFile(file string) (*Capabilities, error) { + fd, err := os.Open(file) + if err != nil { + return nil, err + } + defer fd.Close() + return LoadCapabilitiesJSON(fd) +} + +// LoadCapabilitiesVersions loads all capabilities versions +func LoadCapabilitiesVersions() ([]string, error) { + ents, err := caps.FS.ReadDir(".") + if err != nil { + return nil, err + } + + capabilitiesVersions := make([]string, 0, len(ents)) + for _, ent := range ents { + capabilitiesVersions = append(capabilitiesVersions, strings.Replace(ent.Name(), ".json", "", 1)) + } + return capabilitiesVersions, nil +} diff --git a/ast/capabilities_test.go b/ast/capabilities_test.go index 0d4d8377da..d6db51f1a6 100644 --- a/ast/capabilities_test.go +++ b/ast/capabilities_test.go @@ -1,7 +1,10 @@ package ast import ( + "path" "testing" + + "github.com/open-policy-agent/opa/util/test" ) func TestParserCatchesIllegalCapabilities(t *testing.T) { @@ -83,3 +86,39 @@ func TestParserCapabilitiesWithWildcardOptInAndOlderOPA(t *testing.T) { t.Fatal("unexpected error:", err) } } + +func TestLoadCapabilitiesVersion(t *testing.T) { + + capabilitiesVersions, err := LoadCapabilitiesVersions() + if err != nil { + t.Fatal("expected success", err) + } + + if len(capabilitiesVersions) == 0 { + t.Fatal("expected a non-empty array of capabilities versions") + } + for _, cv := range capabilitiesVersions { + if _, err := LoadCapabilitiesVersion(cv); err != nil { + t.Fatal("expected success", err) + } + } +} + +func TestLoadCapabilitiesFile(t *testing.T) { + + files := map[string]string{ + "test-capabilities.json": ` + { + "builtins": [] + } + `, + } + + test.WithTempFS(files, func(root string) { + _, err := LoadCapabilitiesFile(path.Join(root, "test-capabilities.json")) + if err != nil { + t.Fatal("expected success", err) + } + }) + +} diff --git a/ast/check.go b/ast/check.go index 1e0971bc89..a2edbe66fa 100644 --- a/ast/check.go +++ b/ast/check.go @@ -136,16 +136,8 @@ func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) { // CheckTypes runs type checking on the rules returns a TypeEnv if no errors // are found. The resulting TypeEnv wraps the provided one. The resulting // TypeEnv will be able to resolve types of refs that refer to rules. 
-func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T) (*TypeEnv, Errors) { +func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T, as *AnnotationSet) (*TypeEnv, Errors) { env = tc.newEnv(env) - var as *annotationSet - if tc.ss != nil { - var errs Errors - as, errs = buildAnnotationSet(sorted) - if len(errs) > 0 { - return env, errs - } - } for _, s := range sorted { tc.checkRule(env, as, s.(*Rule)) } @@ -181,7 +173,7 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { return result } -func (tc *typeChecker) checkRule(env *TypeEnv, as *annotationSet, rule *Rule) { +func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { env = env.wrap() @@ -192,6 +184,9 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *annotationSet, rule *Rule) { tc.err([]*Error{err}) continue } + if ref == nil && refType == nil { + continue + } prefixRef, t := getPrefix(env, ref) if t == nil || len(prefixRef) == len(ref) { env.tree.Put(ref, refType) @@ -208,7 +203,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *annotationSet, rule *Rule) { cpy, err := tc.CheckBody(env, rule.Body) env = env.next - path := rule.Path() + path := rule.Ref() if len(err) > 0 { // if the rule/function contains an error, add it to the type env so @@ -221,7 +216,6 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *annotationSet, rule *Rule) { var tpe types.Type if len(rule.Head.Args) > 0 { - // If args are not referred to in body, infer as any. WalkVars(rule.Head.Args, func(v Var) bool { if cpy.Get(v) == nil { @@ -243,23 +237,28 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *annotationSet, rule *Rule) { tpe = types.Or(exist, f) } else { - switch rule.Head.DocKind() { - case CompleteDoc: - typeV := cpy.Get(rule.Head.Value) - if typeV != nil { - exist := env.tree.Get(path) - tpe = types.Or(typeV, exist) - } - case PartialObjectDoc: - typeK := cpy.Get(rule.Head.Key) + switch rule.Head.RuleKind() { + case SingleValue: typeV := cpy.Get(rule.Head.Value) - if typeK != nil && typeV != nil { - exist := env.tree.Get(path) - typeV = types.Or(types.Values(exist), typeV) - typeK = types.Or(types.Keys(exist), typeK) - tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV)) + if last := path[len(path)-1]; !last.IsGround() { + + // e.g. 
store object[string: whatever] at data.p.q.r, not data.p.q.r[x] + path = path.GroundPrefix() + + typeK := cpy.Get(last) + if typeK != nil && typeV != nil { + exist := env.tree.Get(path) + typeV = types.Or(types.Values(exist), typeV) + typeK = types.Or(types.Keys(exist), typeK) + tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV)) + } + } else { + if typeV != nil { + exist := env.tree.Get(path) + tpe = types.Or(typeV, exist) + } } - case PartialSetDoc: + case MultiValue: typeK := cpy.Get(rule.Head.Key) if typeK != nil { exist := env.tree.Get(path) @@ -275,6 +274,9 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *annotationSet, rule *Rule) { } func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error { + if err := tc.checkExprWith(env, expr, 0); err != nil { + return err + } if !expr.IsCall() { return nil } @@ -319,17 +321,19 @@ func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error { } fargs := ftpe.FuncArgs() + namedFargs := ftpe.NamedFuncArgs() if ftpe.Result() != nil { fargs.Args = append(fargs.Args, ftpe.Result()) + namedFargs.Args = append(namedFargs.Args, ftpe.NamedResult()) } if len(args) > len(fargs.Args) && fargs.Variadic == nil { - return newArgError(expr.Location, name, "too many arguments", pre, fargs) + return newArgError(expr.Location, name, "too many arguments", pre, namedFargs) } if len(args) < len(ftpe.FuncArgs().Args) { - return newArgError(expr.Location, name, "too few arguments", pre, fargs) + return newArgError(expr.Location, name, "too few arguments", pre, namedFargs) } for i := range args { @@ -338,7 +342,7 @@ func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error { for i := range args { post[i] = env.Get(args[i]) } - return newArgError(expr.Location, name, "invalid argument(s)", post, fargs) + return newArgError(expr.Location, name, "invalid argument(s)", post, namedFargs) } } @@ -373,6 +377,27 @@ func (tc *typeChecker) checkExprEq(env *TypeEnv, expr *Expr) *Error { return nil } +func (tc *typeChecker) checkExprWith(env *TypeEnv, expr *Expr, i int) *Error { + if i == len(expr.With) { + return nil + } + + target, value := expr.With[i].Target, expr.With[i].Value + targetType, valueType := env.Get(target), env.Get(value) + + if t, ok := targetType.(*types.Function); ok { // built-in function replacement + switch v := valueType.(type) { + case *types.Function: // ...by function + if !unifies(targetType, valueType) { + return newArgError(expr.With[i].Loc(), target.Value.(Ref), "arity mismatch", v.Args(), t.NamedFuncArgs()) + } + default: // ... by value, nothing to check + } + } + + return tc.checkExprWith(env, expr, i+1) +} + func unify2(env *TypeEnv, a *Term, typeA types.Type, b *Term, typeB types.Type) bool { nilA := types.Nil(typeA) @@ -618,11 +643,11 @@ func (rc *refChecker) Visit(x interface{}) bool { } func (rc *refChecker) checkApply(curr *TypeEnv, ref Ref) *Error { - if tpe := curr.Get(ref); tpe != nil { - if _, ok := tpe.(*types.Function); ok { - return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), len(ref)-1, tpe) - } + switch tpe := curr.Get(ref).(type) { + case *types.Function: // NOTE(sr): We don't support first-class functions, except for `with`. + return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), len(ref)-1, tpe) } + return nil } @@ -634,72 +659,57 @@ func (rc *refChecker) checkRef(curr *TypeEnv, node *typeTreeNode, ref Ref, idx i head := ref[idx] - // Handle constant ref operands, i.e., strings or the ref head. 
- if _, ok := head.Value.(String); ok || idx == 0 { - - child := node.Child(head.Value) - if child == nil { - - if curr.next != nil { - next := curr.next - return rc.checkRef(next, next.tree, ref, 0) - } - - if RootDocumentNames.Contains(ref[0]) { - return rc.checkRefLeaf(types.A, ref, 1) - } - - return rc.checkRefLeaf(types.A, ref, 0) - } - - if child.Leaf() { - return rc.checkRefLeaf(child.Value(), ref, idx+1) + // NOTE(sr): as long as package statements are required, this isn't possible: + // the shortest possible rule ref is data.a.b (b is idx 2), idx 1 and 2 need to + // be strings or vars. + if idx == 1 || idx == 2 { + switch head.Value.(type) { + case Var, String: // OK + default: + have := rc.env.Get(head.Value) + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, have, types.S, getOneOfForNode(node)) } - - return rc.checkRef(curr, child, ref, idx+1) } - // Handle dynamic ref operands. - switch value := head.Value.(type) { - - case Var: - - if exist := rc.env.Get(value); exist != nil { - if !unifies(types.S, exist) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node)) + if v, ok := head.Value.(Var); ok && idx != 0 { + tpe := types.Keys(rc.env.getRefRecExtent(node)) + if exist := rc.env.Get(v); exist != nil { + if !unifies(tpe, exist) { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, tpe, getOneOfForNode(node)) } } else { - rc.env.tree.PutOne(value, types.S) + rc.env.tree.PutOne(v, tpe) } + } - case Ref: + child := node.Child(head.Value) + if child == nil { + // NOTE(sr): idx is reset on purpose: we start over + switch { + case curr.next != nil: + next := curr.next + return rc.checkRef(next, next.tree, ref, 0) - exist := rc.env.Get(value) - if exist == nil { - // If ref type is unknown, an error will already be reported so - // stop here. - return nil - } + case RootDocumentNames.Contains(ref[0]): + if idx != 0 { + node.Children().Iter(func(_, child util.T) bool { + _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error + return false + }) + return nil + } + return rc.checkRefLeaf(types.A, ref, 1) - if !unifies(types.S, exist) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node)) + default: + return rc.checkRefLeaf(types.A, ref, 0) } - - // Catch other ref operand types here. Non-leaf nodes must be referred to - // with string values. - default: - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, nil, types.S, getOneOfForNode(node)) } - // Run checking on remaining portion of the ref. Note, since the ref - // potentially refers to data for which no type information exists, - // checking should never fail. - node.Children().Iter(func(_, child util.T) bool { - _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error - return false - }) + if child.Leaf() { + return rc.checkRefLeaf(child.Value(), ref, idx+1) + } - return nil + return rc.checkRef(curr, child, ref, idx+1) } func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error { @@ -805,7 +815,17 @@ func unifies(a, b types.Type) bool { } return unifies(types.Values(a), types.Values(b)) case *types.Function: - // TODO(tsandall): revisit once functions become first-class values. 
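+ // Two function types unify only if their arities match and each pair of + // corresponding argument types unifies (checked by the loop below).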
+ // NOTE(sr): variadic functions can only be internal ones, and we've forbidden + // their replacement via `with`; so we disregard variadic here + if types.Arity(a) == types.Arity(b) { + b := b.(*types.Function) + for i := range a.FuncArgs().Args { + if !unifies(a.FuncArgs().Arg(i), b.FuncArgs().Arg(i)) { + return false + } + } + return true + } return false default: panic("unreachable") @@ -1170,7 +1190,7 @@ func getObjectType(ref Ref, o types.Type, rule *Rule, d *types.DynamicProperty) return getObjectTypeRec(keys, o, d), nil } -func getRuleAnnotation(as *annotationSet, rule *Rule) (result []*SchemaAnnotation) { +func getRuleAnnotation(as *AnnotationSet, rule *Rule) (result []*SchemaAnnotation) { for _, x := range as.GetSubpackagesScope(rule.Module.Package.Path) { result = append(result, x.Schemas...) @@ -1196,6 +1216,9 @@ func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allow var schema interface{} if annot.Schema != nil { + if ss == nil { + return nil, nil, nil + } schema = ss.Get(annot.Schema) if schema == nil { return nil, nil, NewError(TypeErr, rule.Location, "undefined schema: %v", annot.Schema) @@ -1215,158 +1238,3 @@ func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allow func errAnnotationRedeclared(a *Annotations, other *Location) *Error { return NewError(TypeErr, a.Location, "%v annotation redeclared: %v", a.Scope, other) } - -type annotationSet struct { - byRule map[*Rule][]*Annotations - byPackage map[*Package]*Annotations - byPath *annotationTreeNode -} - -func buildAnnotationSet(rules []util.T) (*annotationSet, Errors) { - as := newAnnotationSet() - processed := map[*Module]struct{}{} - var errs Errors - for _, x := range rules { - module := x.(*Rule).Module - if _, ok := processed[module]; ok { - continue - } - processed[module] = struct{}{} - for _, a := range module.Annotations { - if err := as.Add(a); err != nil { - errs = append(errs, err) - } - } - } - if len(errs) > 0 { - return nil, errs - } - return as, nil -} - -func newAnnotationSet() *annotationSet { - return &annotationSet{ - byRule: map[*Rule][]*Annotations{}, - byPackage: map[*Package]*Annotations{}, - byPath: newAnnotationTree(), - } -} - -func (as *annotationSet) Add(a *Annotations) *Error { - switch a.Scope { - case annotationScopeRule: - rule := a.node.(*Rule) - as.byRule[rule] = append(as.byRule[rule], a) - case annotationScopePackage: - pkg := a.node.(*Package) - if exist, ok := as.byPackage[pkg]; ok { - return errAnnotationRedeclared(a, exist.Location) - } - as.byPackage[pkg] = a - case annotationScopeDocument: - rule := a.node.(*Rule) - path := rule.Path() - x := as.byPath.Get(path) - if x != nil { - return errAnnotationRedeclared(a, x.Value.Location) - } - as.byPath.Insert(path, a) - case annotationScopeSubpackages: - pkg := a.node.(*Package) - x := as.byPath.Get(pkg.Path) - if x != nil { - return errAnnotationRedeclared(a, x.Value.Location) - } - as.byPath.Insert(pkg.Path, a) - } - return nil -} - -func (as *annotationSet) GetRuleScope(r *Rule) []*Annotations { - if as == nil { - return nil - } - return as.byRule[r] -} - -func (as *annotationSet) GetSubpackagesScope(path Ref) []*Annotations { - if as == nil { - return nil - } - return as.byPath.Ancestors(path) -} - -func (as *annotationSet) GetDocumentScope(path Ref) *Annotations { - if as == nil { - return nil - } - if node := as.byPath.Get(path); node != nil { - return node.Value - } - return nil -} - -func (as *annotationSet) GetPackageScope(pkg *Package) *Annotations { - if as == nil { - 
return nil - } - return as.byPackage[pkg] -} - -type annotationTreeNode struct { - Value *Annotations - Children map[Value]*annotationTreeNode // we assume key elements are hashable (vars and strings only!) -} - -func newAnnotationTree() *annotationTreeNode { - return &annotationTreeNode{ - Value: nil, - Children: map[Value]*annotationTreeNode{}, - } -} - -func (t *annotationTreeNode) Insert(path Ref, value *Annotations) { - node := t - for _, k := range path { - child, ok := node.Children[k.Value] - if !ok { - child = newAnnotationTree() - node.Children[k.Value] = child - } - node = child - } - node.Value = value -} - -func (t *annotationTreeNode) Get(path Ref) *annotationTreeNode { - node := t - for _, k := range path { - if node == nil { - return nil - } - child, ok := node.Children[k.Value] - if !ok { - return nil - } - node = child - } - return node -} - -func (t *annotationTreeNode) Ancestors(path Ref) (result []*Annotations) { - node := t - for _, k := range path { - if node == nil { - return result - } - child, ok := node.Children[k.Value] - if !ok { - return result - } - if child.Value != nil { - result = append(result, child.Value) - } - node = child - } - return result -} diff --git a/ast/check_test.go b/ast/check_test.go index 8b309d09ce..ea01767067 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -7,6 +7,8 @@ package ast import ( "encoding/json" "fmt" + "net/http" + "net/http/httptest" "reflect" "strings" "testing" @@ -344,6 +346,12 @@ func TestCheckInferenceRules(t *testing.T) { {`number_key`, `q[x] = y { a = ["a", "b"]; y = a[x] }`}, {`non_leaf`, `p[x] { data.prefix.i[x][_] }`}, } + ruleset2 := [][2]string{ + {`ref_rule_single`, `p.q.r { true }`}, + {`ref_rule_single_with_number_key`, `p.q[3] { true }`}, + {`ref_regression_array_key`, + `walker[[p, v]] = o { l = input; walk(l, k); [p, v] = k; o = {} }`}, + } tests := []struct { note string @@ -465,6 +473,37 @@ func TestCheckInferenceRules(t *testing.T) { {"non-leaf", ruleset1, "data.non_leaf.p", types.NewSet( types.S, )}, + + {"ref-rules single value, full ref", ruleset2, "data.ref_rule_single.p.q.r", types.B}, + {"ref-rules single value, prefix", ruleset2, "data.ref_rule_single.p", + types.NewObject( + []*types.StaticProperty{{ + Key: "q", Value: types.NewObject( + []*types.StaticProperty{{Key: "r", Value: types.B}}, + types.NewDynamicProperty(types.S, types.A), + ), + }}, + types.NewDynamicProperty(types.S, types.A), + )}, + + {"ref-rules single value, number key, full ref", ruleset2, "data.ref_rule_single_with_number_key.p.q[3]", types.B}, + {"ref-rules single value, number key, prefix", ruleset2, "data.ref_rule_single_with_number_key.p", + types.NewObject( + []*types.StaticProperty{{ + Key: "q", Value: types.NewObject( + []*types.StaticProperty{{Key: json.Number("3"), Value: types.B}}, + types.NewDynamicProperty(types.S, types.A), + ), + }}, + types.NewDynamicProperty(types.S, types.A), + )}, + + {"ref_regression_array_key", ruleset2, "data.ref_regression_array_key.walker", + types.NewObject( + nil, + types.NewDynamicProperty(types.NewArray([]types.Type{types.NewArray(types.A, types.A), types.A}, nil), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))), + )}, } for _, tc := range tests { @@ -489,7 +528,7 @@ func TestCheckInferenceRules(t *testing.T) { ref := MustParseRef(tc.ref) checker := newTypeChecker() - env, err := checker.CheckTypes(nil, elems) + env, err := checker.CheckTypes(newTypeChecker().Env(map[string]*Builtin{"walk": BuiltinMap["walk"]}), elems, nil) if err != nil { 
t.Fatalf("Unexpected error %v:", err) @@ -512,6 +551,87 @@ func TestCheckInferenceRules(t *testing.T) { } +func TestCheckInferenceOverlapWithRules(t *testing.T) { + ruleset1 := [][2]string{ + {`prefix.i.j.k`, `p = 1 { true }`}, + {`prefix.i.j.k`, `p = "foo" { true }`}, + } + tests := []struct { + note string + rules [][2]string + ref string + expected types.Type // ref's type + query string + extra map[Var]types.Type + }{ + { + note: "non-leaf, extra vars", + rules: ruleset1, + ref: "data.prefix.i.j[k]", + expected: types.A, + query: "data.prefix.i.j[k][b]", + extra: map[Var]types.Type{ + Var("k"): types.S, + Var("b"): types.S, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + var elems []util.T + + // Convert test rules into rule slice for "warmup" call. + for i := range tc.rules { + pkg := MustParsePackage(`package ` + tc.rules[i][0]) + rule := MustParseRule(tc.rules[i][1]) + module := &Module{ + Package: pkg, + Rules: []*Rule{rule}, + } + rule.Module = module + elems = append(elems, rule) + for next := rule.Else; next != nil; next = next.Else { + next.Module = module + elems = append(elems, next) + } + } + + ref := MustParseRef(tc.ref) + checker := newTypeChecker() + env, err := checker.CheckTypes(nil, elems, nil) + if err != nil { + t.Fatalf("Unexpected error %v:", err) + } + + result := env.Get(ref) + if tc.expected == nil { + if result != nil { + t.Errorf("Expected %v type to be unset but got: %v", ref, result) + } + } else { + if result == nil { + t.Errorf("Expected to infer %v => %v but got nil", ref, tc.expected) + } else if types.Compare(tc.expected, result) != 0 { + t.Errorf("Expected to infer %v => %v but got %v", ref, tc.expected, result) + } + } + + body := MustParseBody(tc.query) + env, err = checker.CheckBody(env, body) + if len(err) != 0 { + t.Fatalf("Unexpected error: %v", err) + } + for ex, exp := range tc.extra { + act := env.Get(ex) + if types.Compare(act, exp) != 0 { + t.Errorf("Expected to infer extra %v => %v but got %v", ex, exp, act) + } + } + }) + } +} + func TestCheckErrorSuppression(t *testing.T) { query := `arr = [1,2,3]; arr[0].deadbeef = 1` @@ -642,7 +762,7 @@ func TestCheckBuiltinErrors(t *testing.T) { {"objects-any", `fake_builtin_2({"a": a, "c": c})`}, {"objects-bad-input", `sum({"a": 1, "b": 2}, x)`}, {"sets-any", `sum({1,2,"3",4}, x)`}, - {"virtual-ref", `plus(data.test.p, data.deabeef, 0)`}, + {"virtual-ref", `plus(data.test.p, data.coffee, 0)`}, } env := newTestEnv([]string{ @@ -781,6 +901,7 @@ func TestCheckRefErrInvalid(t *testing.T) { env := newTestEnv([]string{ `p { true }`, `q = {"foo": 1, "bar": 2} { true }`, + `a.b.c[3] = x { x = {"x": {"y": 2}} }`, }) tests := []struct { @@ -799,7 +920,7 @@ func TestCheckRefErrInvalid(t *testing.T) { pos: 2, have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("a"), String("p"), String("q")}, }, { note: "bad non-leaf ref", @@ -808,7 +929,7 @@ func TestCheckRefErrInvalid(t *testing.T) { pos: 2, have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("a"), String("p"), String("q")}, }, { note: "bad leaf ref", @@ -819,6 +940,24 @@ func TestCheckRefErrInvalid(t *testing.T) { want: types.S, oneOf: []Value{String("bar"), String("foo")}, }, + { + note: "bad ref hitting last term", + query: `x = true; data.test.a.b.c[x][_]`, + ref: `data.test.a.b.c[x][_]`, + pos: 5, + have: types.B, + want: types.Any{types.N, types.S}, + oneOf: []Value{Number("3")}, + }, + { + note: "bad ref hitting dynamic part", 
+ query: `s = true; data.test.a.b.c[3].x[s][_] = _`, + ref: `data.test.a.b.c[3].x[s][_]`, + pos: 7, + have: types.B, + want: types.S, + oneOf: []Value{String("y")}, + }, { note: "bad leaf var", query: `x = 1; data.test.q[x]`, @@ -851,12 +990,25 @@ func TestCheckRefErrInvalid(t *testing.T) { oneOf: []Value{String("a"), String("c")}, }, { + // NOTE(sr): This one and the next are special: they cannot work with ref heads either, since we need at + // least ONE string term after data.test: a module needs a package line, and the shortest head ref + // possible is thus data.x.y. note: "bad non-leaf value", query: `data.test[1]`, ref: "data.test[1]", pos: 2, + have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("a"), String("p"), String("q")}, + }, + { + note: "bad non-leaf value (package)", // See note above ^^ + query: `data[1]`, + ref: "data[1]", + pos: 1, + have: types.N, + want: types.S, + oneOf: []Value{String("test")}, }, { note: "composite ref operand", @@ -1020,9 +1172,11 @@ func TestFunctionTypeInferenceUnappliedWithObjectVarKey(t *testing.T) { f(x) = y { y = {x: 1} } `) - env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), []util.T{ + elems := []util.T{ module.Rules[0], - }) + } + + env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems, nil) if len(err) > 0 { t.Fatal(err) @@ -1218,8 +1372,8 @@ func TestCheckErrorOrdering(t *testing.T) { inputReversed[1] = inputReversed[2] inputReversed[2] = tmp - _, errs1 := newTypeChecker().CheckTypes(nil, input) - _, errs2 := newTypeChecker().CheckTypes(nil, inputReversed) + _, errs1 := newTypeChecker().CheckTypes(nil, input, nil) + _, errs2 := newTypeChecker().CheckTypes(nil, inputReversed, nil) if errs1.Error() != errs2.Error() { t.Fatalf("Expected error slices to be equal. errs1:\n\n%v\n\nerrs2:\n\n%v\n\n", errs1, errs2) @@ -1253,7 +1407,9 @@ ... package test `) - var elems []util.T + // We preallocate enough for at least the base rules. + // Else cases will cause reallocs, but that's okay. 
+ elems := make([]util.T, 0, len(rs)) for i := range rs { rule := MustParseRule(rs[i]) @@ -1265,7 +1421,7 @@ func newTestEnv(rs []string) *TypeEnv { } } - env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems) + env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems, nil) if len(err) > 0 { panic(err) } @@ -1818,39 +1974,39 @@ whocan[user] { schemaSet.Put(MustParseRef(`schema["acl-schema"]`), dschema) tests := []struct { - note string - module string - err string + note string + modules []string + err string }{ - {note: "data and input annotations", module: module1}, - {note: "correct data override", module: module2}, - {note: "incorrect data override", module: module3, err: "undefined ref: input.user"}, - {note: "missing schema", module: module4, err: "undefined schema: schema.missing"}, - {note: "overriding ref with length greater than one and not existing", module: module8, err: "undefined ref: input.apple.banana"}, - {note: "overriding ref with length greater than one and existing prefix", module: module9}, - {note: "overriding ref with length greater than one and existing prefix with type error", module: module10, err: "undefined ref: input.apple.orange.banana.fruit"}, - {note: "overriding ref with length greater than one and existing ref", module: module11, err: "undefined ref: input.apple.orange.user"}, - {note: "overriding ref of size one", module: module12, err: "undefined ref: input.user"}, - {note: "overriding annotation written with brackets", module: module13, err: "undefined ref: input.apple.orange.fruit"}, - {note: "overriding strict", module: module14, err: "undefined ref: input.request.object.spec.typo"}, - {note: "data annotation but no input schema", module: module15}, - {note: "data schema annotation does not overly restrict data expression", module: module16}, - {note: "correct defer annotation on another rule has no effect base case", module: module17}, - {note: "correct defer annotation on another rule has no effect", module: module18}, - {note: "overriding ref with data prefix", module: module19, err: "data.acl.foo.blah"}, - {note: "data annotation type error", module: module20, err: "data.acl.foo"}, - {note: "more than one rule with metadata", module: module21}, - {note: "more than one rule with metadata with type error", module: module22, err: "undefined ref"}, - {note: "document scope", err: "test.rego:8: rego_type_error: match error", module: `package test + {note: "data and input annotations", modules: []string{module1}}, + {note: "correct data override", modules: []string{module2}}, + {note: "incorrect data override", modules: []string{module3}, err: "undefined ref: input.user"}, + {note: "missing schema", modules: []string{module4}, err: "undefined schema: schema.missing"}, + {note: "overriding ref with length greater than one and not existing", modules: []string{module8}, err: "undefined ref: input.apple.banana"}, + {note: "overriding ref with length greater than one and existing prefix", modules: []string{module9}}, + {note: "overriding ref with length greater than one and existing prefix with type error", modules: []string{module10}, err: "undefined ref: input.apple.orange.banana.fruit"}, + {note: "overriding ref with length greater than one and existing ref", modules: []string{module11}, err: "undefined ref: input.apple.orange.user"}, + {note: "overriding ref of size one", modules: []string{module12}, err: "undefined ref: input.user"}, + {note: "overriding annotation written with brackets", modules: 
[]string{module13}, err: "undefined ref: input.apple.orange.fruit"}, + {note: "overriding strict", modules: []string{module14}, err: "undefined ref: input.request.object.spec.typo"}, + {note: "data annotation but no input schema", modules: []string{module15}}, + {note: "data schema annotation does not overly restrict data expression", modules: []string{module16}}, + {note: "correct defer annotation on another rule has no effect base case", modules: []string{module17}}, + {note: "correct defer annotation on another rule has no effect", modules: []string{module18}}, + {note: "overriding ref with data prefix", modules: []string{module19}, err: "data.acl.foo.blah"}, + {note: "data annotation type error", modules: []string{module20}, err: "data.acl.foo"}, + {note: "more than one rule with metadata", modules: []string{module21}}, + {note: "more than one rule with metadata with type error", modules: []string{module22}, err: "undefined ref"}, + {note: "document scope", err: "test1.rego:8: rego_type_error: match error", modules: []string{`package test # METADATA # scope: document # schemas: # - input.foo: schema.number p { input.foo = 7 } -p { input.foo = [] }`}, +p { input.foo = [] }`}}, - {note: "rule scope overrides document scope", module: `package test + {note: "rule scope overrides document scope", modules: []string{`package test # METADATA # scope: document @@ -1862,9 +2018,9 @@ p { input.foo = 7 } # scope: rule # schemas: # - input.foo: schema.string -p { input.foo = "str" }`}, +p { input.foo = "str" }`}}, - {note: "rule scope merges with document scope", err: "test.rego:15: rego_type_error: match error", module: `package test + {note: "rule scope merges with document scope", err: "test1.rego:15: rego_type_error: match error", modules: []string{`package test # METADATA # scope: document @@ -1879,9 +2035,9 @@ p { input.bar = 7 } p { input.foo = "str" input.bar = "str" -}`}, +}`}}, - {note: "document scope conflict", err: "test.rego:9: rego_type_error: document annotation redeclared: test.rego:3", module: `package test + {note: "document scope conflict", err: "test1.rego:9: rego_type_error: document annotation redeclared: test1.rego:3", modules: []string{`package test # METADATA # scope: document @@ -1893,17 +2049,45 @@ p { input.foo = 7 } # scope: document # schemas: # - input.foo: schema.string -p { input.foo = "str" }`}, +p { input.foo = "str" }`}}, + + {note: "package scope in other module", modules: []string{`# METADATA +# scope: package +# schemas: +# - input.foo: schema.number +package test`, `package test + +p { input.foo = 7 }`}}, - {note: "subpackages scope", err: "test.rego:7: rego_type_error: match error", module: `# METADATA + {note: "package scope in other module type conflict", err: "test2.rego:3: rego_type_error: match error", modules: []string{`# METADATA +# scope: package +# schemas: +# - input.foo: schema.string +package test`, `package test + +p { input.foo = 7 }`}}, + + {note: "package scope conflict", err: "test2.rego:1: rego_type_error: package annotation redeclared: test1.rego:1", modules: []string{`# METADATA +# scope: package +# schemas: +# - input.foo: schema.string +package test`, `# METADATA +# scope: package +# schemas: +# - input.foo: schema.number +package test + +p { input.foo = 7 }`}}, + + {note: "subpackages scope", err: "test1.rego:7: rego_type_error: match error", modules: []string{`# METADATA # scope: subpackages # schemas: # - input: schema.number package test -p { input = "str" }`}, +p { input = "str" }`}}, - {note: "document scope overrides subpackages 
scope", module: `# METADATA + {note: "document scope overrides subpackages scope", modules: []string{`# METADATA # scope: subpackages # schemas: # - input: schema.number @@ -1913,9 +2097,9 @@ package test # scope: document # schemas: # - input: schema.string -p { input = "str" }`}, +p { input = "str" }`}}, - {note: "document scope overrides subpackages scope and finds error", err: "test.rego:11: rego_type_error: match error", module: `# METADATA + {note: "document scope overrides subpackages scope and finds error", err: "test1.rego:11: rego_type_error: match error", modules: []string{`# METADATA # scope: subpackages # schemas: # - input: schema.string @@ -1925,17 +2109,17 @@ package test # scope: rule # schemas: # - input: schema.number -p { input = "str" }`}, +p { input = "str" }`}}, - {note: "package scope", err: "test.rego:7: rego_type_error: match error", module: `# METADATA + {note: "package scope", err: "test1.rego:7: rego_type_error: match error", modules: []string{`# METADATA # scope: package # schemas: # - input: schema.string package test -p { input = 7 }`}, +p { input = 7 }`}}, - {note: "rule scope overrides package scope", module: `# METADATA + {note: "rule scope overrides package scope", modules: []string{`# METADATA # scope: package # schemas: # - input: schema.string @@ -1945,16 +2129,16 @@ package test # scope: rule # schemas: # - input: schema.number -p { input = 7 }`}, +p { input = 7 }`}}, - {note: "inline definition", err: "test.rego:7: rego_type_error: match error", module: `package test + {note: "inline definition", err: "test1.rego:7: rego_type_error: match error", modules: []string{`package test # METADATA # scope: rule # schemas: # - input: {"type": "string"} -p { input = 7 }`}, - {note: "document scope is unordered", err: "test.rego:3: rego_type_error: match error", module: `package test +p { input = 7 }`}}, + {note: "document scope is unordered", err: "test1.rego:3: rego_type_error: match error", modules: []string{`package test p { input = 7 } @@ -1962,28 +2146,35 @@ p { input = 7 } # scope: document # schemas: # - input: schema.string -p { input = "foo" }`}, +p { input = "foo" }`}}, } for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - mod, err := ParseModuleWithOpts("test.rego", tc.module, ParserOptions{ - ProcessAnnotation: true, - }) - if err != nil { - t.Fatal(err) - } - + var modules []*Module var elems []util.T - for _, rule := range mod.Rules { - elems = append(elems, rule) - for next := rule.Else; next != nil; next = next.Else { - elems = append(elems, next) + + for i, module := range tc.modules { + mod, err := ParseModuleWithOpts(fmt.Sprintf("test%d.rego", i+1), module, ParserOptions{ + ProcessAnnotation: true, + }) + if err != nil { + t.Fatal(err) + } + modules = append(modules, mod) + + for _, rule := range mod.Rules { + elems = append(elems, rule) + for next := rule.Else; next != nil; next = next.Else { + elems = append(elems, next) + } } } oldTypeEnv := newTypeChecker().WithSchemaSet(schemaSet).Env(BuiltinMap) - typeenv, errors := newTypeChecker().WithSchemaSet(schemaSet).CheckTypes(oldTypeEnv, elems) + as, errors := BuildAnnotationSet(modules) + typeenv, checkErrors := newTypeChecker().WithSchemaSet(schemaSet).CheckTypes(oldTypeEnv, elems, as) + errors = append(errors, checkErrors...) if len(errors) > 0 { for _, e := range errors { if tc.err == "" || !strings.Contains(e.Error(), tc.err) { @@ -2059,7 +2250,9 @@ q = p`, ss.Put(ref, schema) } - compiler := NewCompiler().WithSchemas(ss) + compiler := NewCompiler(). + WithSchemas(ss). 
+ WithUseTypeCheckAnnotations(true) compiler.Compile(modules) if compiler.Failed() { t.Fatal("unexpected error:", compiler.Errors) @@ -2078,3 +2271,98 @@ q = p`, } } + +func TestRemoteSchema(t *testing.T) { + schema := `{"type": "boolean"}` + + schemaCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + schemaCalled = true + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(schema)) + })) + defer server.Close() + + policy := fmt.Sprintf(` +package test + +# METADATA +# schemas: +# - input: {$ref: "%s"} +p { + input == 42 +}`, server.URL) + + module, err := ParseModuleWithOpts("policy.rego", policy, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + modules := map[string]*Module{"policy.rego": module} + + compiler := NewCompiler(). + WithUseTypeCheckAnnotations(true) + compiler.Compile(modules) + + if !compiler.Failed() { + t.Fatal("expected error, got none") + } + + expectedTypeError := "rego_type_error: match error" + if !strings.Contains(compiler.Errors.Error(), expectedTypeError) { + t.Fatalf("expected error:\n\n%s\n\ngot:\n\n%s", + expectedTypeError, compiler.Errors.Error()) + } + + if !schemaCalled { + t.Fatal("expected schema server to be called, was not") + } +} + +func TestRemoteSchemaHostNotAllowed(t *testing.T) { + capabilities := CapabilitiesForThisVersion() + capabilities.AllowNet = []string{} + schema := `{"type": "boolean"}` + + schemaCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + schemaCalled = true + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(schema)) + })) + defer server.Close() + + policy := fmt.Sprintf(` +package test + +# METADATA +# schemas: +# - input: {$ref: "%s"} +p { + input == 42 +}`, server.URL) + + module, err := ParseModuleWithOpts("policy.rego", policy, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + modules := map[string]*Module{"policy.rego": module} + + compiler := NewCompiler(). + WithUseTypeCheckAnnotations(true). 
+ WithCapabilities(capabilities) + compiler.Compile(modules) + + if !compiler.Failed() { + t.Fatal("expected error, got none") + } + + expectedTypeError := "rego_type_error: unable to compile the schema: remote reference loading disabled" + if !strings.Contains(compiler.Errors.Error(), expectedTypeError) { + t.Fatalf("expected error:\n\n%s\n\ngot:\n\n%s", + expectedTypeError, compiler.Errors.Error()) + } + + if schemaCalled { + t.Fatal("expected schema server to not be called, was") + } +} diff --git a/ast/compare.go b/ast/compare.go index 685082da63..3bb6f2a75d 100644 --- a/ast/compare.go +++ b/ast/compare.go @@ -165,7 +165,12 @@ func Compare(a, b interface{}) int { case *Array: b := b.(*Array) return termSliceCompare(a.elems, b.elems) + case *lazyObj: + return Compare(a.force(), b) case *object: + if x, ok := b.(*lazyObj); ok { + b = x.force() + } b := b.(*object) return a.Compare(b) case Set: @@ -201,6 +206,9 @@ func Compare(a, b interface{}) int { case *SomeDecl: b := b.(*SomeDecl) return a.Compare(b) + case *Every: + b := b.(*Every) + return a.Compare(b) case *With: b := b.(*With) return a.Compare(b) @@ -272,6 +280,8 @@ func sortOrder(x interface{}) int { return 100 case *SomeDecl: return 101 + case *Every: + return 102 case *With: return 110 case *Head: diff --git a/ast/compare_test.go b/ast/compare_test.go index 7f30d58ad5..daa2683bbb 100644 --- a/ast/compare_test.go +++ b/ast/compare_test.go @@ -325,6 +325,445 @@ func TestCompareAnnotations(t *testing.T) { # - input: {"type": "string"}`, exp: 1, }, + { + note: "title", + a: ` +# METADATA +# title: a`, + b: ` +# METADATA +# title: a`, + exp: 0, + }, + { + note: "title - less than", + a: ` +# METADATA +# title: a`, + b: ` +# METADATA +# title: b`, + exp: -1, + }, + { + note: "title - greater than", + a: ` +# METADATA +# title: b`, + b: ` +# METADATA +# title: a`, + exp: 1, + }, + { + note: "description", + a: ` +# METADATA +# description: a`, + b: ` +# METADATA +# description: a`, + exp: 0, + }, + { + note: "description - less than", + a: ` +# METADATA +# description: a`, + b: ` +# METADATA +# description: b`, + exp: -1, + }, + { + note: "description - greater than", + a: ` +# METADATA +# description: b`, + b: ` +# METADATA +# description: a`, + exp: 1, + }, + { + note: "authors", + a: ` +# METADATA +# authors: +# - John Doe +# - Jane Doe`, + b: ` +# METADATA +# authors: +# - John Doe +# - Jane Doe`, + exp: 0, + }, + { + note: "authors - less than", + a: ` +# METADATA +# authors: +# - Jane Doe +# - John Doe +`, + b: ` +# METADATA +# authors: +# - John Doe +# - Jane Doe`, + exp: -1, + }, + { + note: "authors - greater than", + a: ` +# METADATA +# authors: +# - John Doe +# - Jane Doe`, + b: ` +# METADATA +# authors: +# - Jane Doe +# - John Doe`, + exp: 1, + }, + { + note: "authors - less than (fewer)", + a: ` +# METADATA +# scope: rule +# authors: +# - John Doe`, + b: ` +# METADATA +# scope: rule +# authors: +# - John Doe +# - Jane Doe`, + exp: -1, + }, + { + note: "authors - greater than (more)", + a: ` +# METADATA +# scope: rule +# authors: +# - John Doe +# - Jane Doe`, + b: ` +# METADATA +# scope: rule +# authors: +# - John Doe`, + exp: 1, + }, + { + note: "authors - less than (email)", + a: ` +# METADATA +# authors: +# - John Doe `, + b: ` +# METADATA +# authors: +# - John Doe `, + exp: -1, + }, + { + note: "authors - greater than (email)", + a: ` +# METADATA +# authors: +# - John Doe `, + b: ` +# METADATA +# authors: +# - John Doe `, + exp: 1, + }, + { + note: "organizations", + a: ` +# METADATA +# organizations: +# - a +# - b`, + 
b: ` +# METADATA +# organizations: +# - a +# - b`, + exp: 0, + }, + { + note: "organizations - less than", + a: ` +# METADATA +# organizations: +# - a +# - b`, + b: ` +# METADATA +# organizations: +# - c +# - d`, + exp: -1, + }, + { + note: "organizations - greater than", + a: ` +# METADATA +# organizations: +# - c +# - d`, + b: ` +# METADATA +# organizations: +# - a +# - b`, + exp: 1, + }, + { + note: "organizations - less than (fewer)", + a: ` +# METADATA +# scope: rule +# organizations: +# - a`, + b: ` +# METADATA +# scope: rule +# organizations: +# - a +# - b`, + exp: -1, + }, + { + note: "organizations - greater than (more)", + a: ` +# METADATA +# scope: rule +# organizations: +# - a +# - b`, + b: ` +# METADATA +# scope: rule +# organizations: +# - a`, + exp: 1, + }, + { + note: "related_resources", + a: ` +# METADATA +# related_resources: +# - https://a.example.com +# - +# ref: https://b.example.com +# description: foo bar`, + b: ` +# METADATA +# related_resources: +# - https://a.example.com +# - +# ref: https://b.example.com +# description: foo bar`, + exp: 0, + }, + { + note: "related_resources - less than", + a: ` +# METADATA +# related_resources: +# - https://a.example.com +# - https://b.example.com`, + b: ` +# METADATA +# related_resources: +# - https://b.example.com +# - https://c.example.com`, + exp: -1, + }, + { + note: "related_resources - greater than", + a: ` +# METADATA +# related_resources: +# - https://b.example.com +# - https://c.example.com`, + b: ` +# METADATA +# related_resources: +# - https://a.example.com +# - https://b.example.com`, + exp: 1, + }, + { + note: "related_resources - less than (fewer)", + a: ` +# METADATA +# scope: rule +# organizations: +# - https://a.example.com`, + b: ` +# METADATA +# scope: rule +# organizations: +# - https://a.example.com +# - https://b.example.com`, + exp: -1, + }, + { + note: "related_resources - greater than (more)", + a: ` +# METADATA +# scope: rule +# organizations: +# - https://a.example.com +# - https://b.example.com`, + b: ` +# METADATA +# scope: rule +# organizations: +# - https://a.example.com`, + exp: 1, + }, + { + note: "related_resources - less than (description)", + a: ` +# METADATA +# related_resources: +# - +# ref: https://example.com +# description: a`, + b: ` +# METADATA +# related_resources: +# - +# ref: https://example.com +# description: b`, + exp: -1, + }, + { + note: "related_resources - greater than (description)", + a: ` +# METADATA +# related_resources: +# - +# ref: https://example.com +# description: b`, + b: ` +# METADATA +# related_resources: +# - +# ref: https://example.com +# description: a`, + exp: 1, + }, + { + note: "custom", + a: ` +# METADATA +# custom: +# a: 1 +# b: true +# c: +# d: +# - 1 +# - 2 +# e: +# i: 1 +# j: 2`, + b: ` +# METADATA +# custom: +# a: 1 +# b: true +# c: +# d: +# - 1 +# - 2 +# e: +# i: 1 +# j: 2`, + exp: 0, + }, + { + note: "custom - less than", + a: ` +# METADATA +# custom: +# a: 1`, + b: ` +# METADATA +# custom: +# b: 1`, + exp: -1, + }, + { + note: "custom - greater than", + a: ` +# METADATA +# custom: +# b: 1`, + b: ` +# METADATA +# custom: +# a: 1`, + exp: 1, + }, + { + note: "custom - less than (value)", + a: ` +# METADATA +# custom: +# a: 1`, + b: ` +# METADATA +# custom: +# a: 2`, + exp: -1, + }, + { + note: "custom - greater than (value)", + a: ` +# METADATA +# custom: +# a: 2`, + b: ` +# METADATA +# custom: +# a: 1`, + exp: 1, + }, + { + note: "custom - less than (fewer)", + a: ` +# METADATA +# custom: +# a: 1`, + b: ` +# METADATA +# custom: +# a: 1 +# b: 2`, + 
exp: -1, + }, + { + note: "custom - greater than (more)", + a: ` +# METADATA +# custom: +# a: 1 +# b: 2`, + b: ` +# METADATA +# custom: +# a: 1`, + exp: 1, + }, } for _, tc := range tests { diff --git a/ast/compile.go b/ast/compile.go index f446606d16..1f99e09692 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + "github.com/open-policy-agent/opa/ast/location" "github.com/open-policy-agent/opa/internal/debug" "github.com/open-policy-agent/opa/internal/gojsonschema" "github.com/open-policy-agent/opa/metrics" @@ -63,6 +64,7 @@ type Compiler struct { // p[1] { true } // p[2] { true } // q = true + // a.b.c = 3 // // root // | @@ -73,6 +75,12 @@ type Compiler struct { // +--- p (2 rules) // | // +--- q (1 rule) + // | + // +--- a + // | + // +--- b + // | + // +--- c (1 rule) RuleTree *TreeNode // Graph contains dependencies between rules. An edge (u,v) is added to the @@ -94,21 +102,27 @@ type Compiler struct { metricName string f func() } - maxErrs int - sorted []string // list of sorted module names - pathExists func([]string) (bool, error) - after map[string][]CompilerStageDefinition - metrics metrics.Metrics - capabilities *Capabilities // user-supplied capabilities - builtins map[string]*Builtin // universe of built-in functions - customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities) - unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities) - enablePrintStatements bool // indicates if print statements should be elided (default) - comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index - initialized bool // indicates if init() has been called - debug debug.Debug // emits debug information produced during compilation - schemaSet *SchemaSet // user-supplied schemas for input and data documents - inputType types.Type // global input type retrieved from schema set + maxErrs int + sorted []string // list of sorted module names + pathExists func([]string) (bool, error) + after map[string][]CompilerStageDefinition + metrics metrics.Metrics + capabilities *Capabilities // user-supplied capabilities + builtins map[string]*Builtin // universe of built-in functions + customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities) + unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities) + deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions + enablePrintStatements bool // indicates if print statements should be elided (default) + comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index + initialized bool // indicates if init() has been called + debug debug.Debug // emits debug information produced during compilation + schemaSet *SchemaSet // user-supplied schemas for input and data documents + inputType types.Type // global input type retrieved from schema set + annotationSet *AnnotationSet // hierarchical set of annotations + strict bool // enforce strict compilation checks + keepModules bool // whether to keep the unprocessed, parsed modules (below) + parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true + useTypeCheckAnnotations bool // whether to provide schemas from annotations to the type checker } // CompilerStage defines the interface for stages 
in the compiler. @@ -215,6 +229,9 @@ type QueryCompiler interface { // ComprehensionIndex returns an index data structure for the given comprehension // term. If no index is found, returns nil. ComprehensionIndex(term *Term) *ComprehensionIndex + + // WithStrict enables strict mode for the query compiler. + WithStrict(strict bool) QueryCompiler } // QueryCompilerStage defines the interface for stages in the query compiler. @@ -239,11 +256,12 @@ func NewCompiler() *Compiler { }, func(x util.T) int { return x.(Ref).Hash() }), - maxErrs: CompileErrorLimitDefault, - after: map[string][]CompilerStageDefinition{}, - unsafeBuiltinsMap: map[string]struct{}{}, - comprehensionIndices: map[*Term]*ComprehensionIndex{}, - debug: debug.Discard(), + maxErrs: CompileErrorLimitDefault, + after: map[string][]CompilerStageDefinition{}, + unsafeBuiltinsMap: map[string]struct{}{}, + deprecatedBuiltinsMap: map[string]struct{}{}, + comprehensionIndices: map[*Term]*ComprehensionIndex{}, + debug: debug.Discard(), } c.ModuleTree = NewModuleTree(nil) @@ -258,16 +276,23 @@ func NewCompiler() *Compiler { // load additional modules. If any stages run before resolution, they // need to be re-run after resolution. {"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs}, - {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, - {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // The local variable generator must be initialized after references are // resolved and the dynamic module loader has run but before subsequent // stages that need to generate variables. {"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen}, + {"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs}, + {"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides}, + {"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports}, + {"RemoveImports", "compile_stage_remove_imports", c.removeImports}, + {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, + {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs {"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars}, {"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls}, {"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls}, {"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms}, + {"ParseMetadataBlocks", "compile_stage_parse_metadata_blocks", c.parseMetadataBlocks}, + {"SetAnnotationSet", "compile_stage_set_annotationset", c.setAnnotationSet}, + {"RewriteRegoMetadataCalls", "compile_stage_rewrite_rego_metadata_calls", c.rewriteRegoMetadataCalls}, {"SetGraph", "compile_stage_set_graph", c.setGraph}, {"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms}, {"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead}, @@ -279,8 +304,9 @@ func NewCompiler() *Compiler { {"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals}, {"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms}, {"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion}, - {"CheckTypes", "compile_stage_check_types", c.checkTypes}, + {"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion {"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins}, + 
{"CheckDeprecatedBuiltins", "compile_state_check_deprecated_builtins", c.checkDeprecatedBuiltins}, {"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices}, {"BuildComprehensionIndices", "compile_stage_rebuild_comprehension_indices", c.buildComprehensionIndices}, } @@ -336,6 +362,11 @@ func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler { return c } +// Capabilities returns the capabilities enabled during compilation. +func (c *Compiler) Capabilities() *Capabilities { + return c.capabilities +} + // WithDebug sets where debug messages are written to. Passing `nil` has no // effect. func (c *Compiler) WithDebug(sink io.Writer) *Compiler { @@ -362,10 +393,38 @@ func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compi return c } -// QueryCompiler returns a new QueryCompiler object. +// WithStrict enables strict mode in the compiler. +func (c *Compiler) WithStrict(strict bool) *Compiler { + c.strict = strict + return c +} + +// WithKeepModules enables retaining unprocessed modules in the compiler. +// Note that the modules aren't copied on the way in or out -- so when +// accessing them via ParsedModules(), mutations will occur in the module +// map that was passed into Compile().` +func (c *Compiler) WithKeepModules(y bool) *Compiler { + c.keepModules = y + return c +} + +// WithUseTypeCheckAnnotations use schema annotations during type checking +func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler { + c.useTypeCheckAnnotations = enabled + return c +} + +// ParsedModules returns the parsed, unprocessed modules from the compiler. +// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`. +// The map includes all modules loaded via the ModuleLoader, if one was used. +func (c *Compiler) ParsedModules() map[string]*Module { + return c.parsedModules +} + func (c *Compiler) QueryCompiler() QueryCompiler { c.init() - return newQueryCompiler(c) + c0 := *c + return newQueryCompiler(&c0) } // Compile runs the compilation process on the input modules. The compiled @@ -377,10 +436,20 @@ func (c *Compiler) Compile(modules map[string]*Module) { c.init() c.Modules = make(map[string]*Module, len(modules)) + c.sorted = make([]string, 0, len(modules)) + + if c.keepModules { + c.parsedModules = make(map[string]*Module, len(modules)) + } else { + c.parsedModules = nil + } for k, v := range modules { c.Modules[k] = v.Copy() c.sorted = append(c.sorted, k) + if c.parsedModules != nil { + c.parsedModules[k] = v + } } sort.Strings(c.sorted) @@ -424,16 +493,16 @@ func (c *Compiler) GetArity(ref Ref) int { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[k] = v { ... } # rule1 -// p[k1] = v1 { ... } # rule2 +// p[k] = v { ... } # rule1 +// p[k1] = v1 { ... } # rule2 // // The following calls yield the rules on the right. // -// GetRulesExact("data.a.b.c.p") => [rule1, rule2] -// GetRulesExact("data.a.b.c.p.x") => nil -// GetRulesExact("data.a.b.c") => nil +// GetRulesExact("data.a.b.c.p") => [rule1, rule2] +// GetRulesExact("data.a.b.c.p.x") => nil +// GetRulesExact("data.a.b.c") => nil func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) { node := c.RuleTree @@ -451,16 +520,16 @@ func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[k] = v { ... } # rule1 -// p[k1] = v1 { ... } # rule2 +// p[k] = v { ... } # rule1 +// p[k1] = v1 { ... 
} # rule2 // // The following calls yield the rules on the right. // -// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2] -// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2] -// GetRulesForVirtualDocument("data.a.b.c") => nil +// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2] +// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2] +// GetRulesForVirtualDocument("data.a.b.c") => nil func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) { node := c.RuleTree @@ -481,17 +550,17 @@ func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[x] = y { ... } # rule1 -// p[k] = v { ... } # rule2 -// q { ... } # rule3 +// p[x] = y { ... } # rule1 +// p[k] = v { ... } # rule2 +// q { ... } # rule3 // // The following calls yield the rules on the right. // -// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2] -// GetRulesWithPrefix("data.a.b.c.p.a") => nil -// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3] +// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2] +// GetRulesWithPrefix("data.a.b.c.p.a") => nil +// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3] func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { node := c.RuleTree @@ -519,9 +588,10 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { return rules } -func extractRules(s []util.T) (rules []*Rule) { - for _, r := range s { - rules = append(rules, r.(*Rule)) +func extractRules(s []util.T) []*Rule { + rules := make([]*Rule, len(s)) + for i := range s { + rules[i] = s[i].(*Rule) } return rules } @@ -530,18 +600,18 @@ func extractRules(s []util.T) (rules []*Rule) { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[x] = y { q[x] = y; ... } # rule1 -// q[x] = y { ... } # rule2 +// p[x] = y { q[x] = y; ... } # rule1 +// q[x] = y { ... } # rule2 // // The following calls yield the rules on the right. // -// GetRules("data.a.b.c.p") => [rule1] -// GetRules("data.a.b.c.p.x") => [rule1] -// GetRules("data.a.b.c.q") => [rule2] -// GetRules("data.a.b.c") => [rule1, rule2] -// GetRules("data.a.b.d") => nil +// GetRules("data.a.b.c.p") => [rule1] +// GetRules("data.a.b.c.p.x") => [rule1] +// GetRules("data.a.b.c.q") => [rule2] +// GetRules("data.a.b.c") => [rule1, rule2] +// GetRules("data.a.b.d") => nil func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { set := map[*Rule]struct{}{} @@ -576,34 +646,34 @@ func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule { // // E.g., given the following modules: // -// package a.b.c +// package a.b.c // -// r1 = 1 # rule1 +// r1 = 1 # rule1 // // and: // -// package a.d.c +// package a.d.c // -// r2 = 2 # rule2 +// r2 = 2 # rule2 // // The following calls yield the rules on the right. 
// -// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2] -// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2] -// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1] +// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2] +// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2] +// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1] // // Using the RulesOptions parameter, the inclusion of hidden modules can be // controlled: // // With // -// package system.main +// package system.main // -// r3 = 3 # rule3 +// r3 = 3 # rule3 // // We'd get this result: // -// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3] +// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3] // // Without the options, it would be excluded. func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { @@ -612,7 +682,8 @@ func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { set := map[*Rule]struct{}{} var walk func(node *TreeNode, i int) walk = func(node *TreeNode, i int) { - if i >= len(ref) { + switch { + case i >= len(ref): // We've reached the end of the reference and want to collect everything // under this "prefix". node.DepthFirst(func(descendant *TreeNode) bool { @@ -622,7 +693,8 @@ func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { } return descendant.Hide }) - } else if i == 0 || IsConstant(ref[i].Value) { + + case i == 0 || IsConstant(ref[i].Value): // The head of the ref is always grounded. In case another part of the // ref is also grounded, we can lookup the exact child. If it's not found // we can immediately return... @@ -636,7 +708,8 @@ func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { // Otherwise, we continue using the child node. walk(child, i+1) } - } else { + + default: // This part of the ref is a dynamic term. We can't know what it refers // to and will just need to try all of the children. for _, child := range node.Children { @@ -714,13 +787,30 @@ func (c *Compiler) buildRuleIndices() { if len(node.Values) == 0 { return false } + rules := extractRules(node.Values) + hasNonGroundKey := false + for _, r := range rules { + if ref := r.Head.Ref(); len(ref) > 1 { + if !ref[len(ref)-1].IsGround() { + hasNonGroundKey = true + } + } + } + if hasNonGroundKey { + // collect children: as of now, this cannot go deeper than one level, + // so we grab those, and abort the DepthFirst processing for this branch + for _, n := range node.Children { + rules = append(rules, extractRules(n.Values)...) 
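+					// Illustrative sketch (not taken from this change): given the
+					// two rules
+					//
+					//	p.q[x] = 1 { x := input.x }  # non-ground last term => hasNonGroundKey
+					//	p.q.r = 2                    # lives in a child node under p.q
+					//
+					// the child rule is folded into the single index built at p.q,
+					// so one index lookup covers both definitions.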
+ } + } + index := newBaseDocEqIndex(func(ref Ref) bool { return isVirtual(c.RuleTree, ref.GroundPrefix()) }) - if rules := extractRules(node.Values); index.Build(rules) { - c.ruleIndices.Put(rules[0].Path(), index) + if index.Build(rules) { + c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index) } - return false + return hasNonGroundKey // currently, we don't allow those branches to go deeper }) } @@ -757,7 +847,7 @@ func (c *Compiler) checkRecursion() { func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) { tr := NewGraphTraversal(c.Graph) if p := util.DFSPath(tr, eq, a, b); len(p) > 0 { - n := []string{} + n := make([]string, 0, len(p)) for _, x := range p { n = append(n, astNodeToString(x)) } @@ -766,46 +856,103 @@ func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b } func astNodeToString(x interface{}) string { - switch x := x.(type) { - case *Rule: - return string(x.Head.Name) - default: - panic("not reached") - } + return x.(*Rule).Ref().String() } // checkRuleConflicts ensures that rules definitions are not in conflict. func (c *Compiler) checkRuleConflicts() { + rw := rewriteVarsInRef(c.RewrittenVars) + c.RuleTree.DepthFirst(func(node *TreeNode) bool { if len(node.Values) == 0 { - return false + return false // go deeper } - kinds := map[DocKind]struct{}{} + kinds := make(map[RuleKind]struct{}, len(node.Values)) defaultRules := 0 - arities := map[int]struct{}{} - declared := false + arities := make(map[int]struct{}, len(node.Values)) + name := "" + var singleValueConflicts []Ref + var multiValueConflicts []Ref for _, rule := range node.Values { r := rule.(*Rule) - kinds[r.Head.DocKind()] = struct{}{} + ref := r.Ref() + name = rw(ref.Copy()).String() // varRewriter operates in-place + kinds[r.Head.RuleKind()] = struct{}{} arities[len(r.Head.Args)] = struct{}{} - if r.Head.Assign { - declared = true - } if r.Default { defaultRules++ } + + // Single-value rules may not have any other rules in their extent: these pairs are invalid: + // + // data.p.q.r { true } # data.p.q is { "r": true } + // data.p.q.r.s { true } + // + // data.p.q[r] { r := input.r } # data.p.q could be { "r": true } + // data.p.q.r.s { true } + // + // data.p[r] := x { r = input.key; x = input.bar } + // data.p.q[r] := x { r = input.key; x = input.bar } + + // But this is allowed: + // data.p.q[r] = 1 { r := "r" } + // data.p.q.s = 2 + + if r.Head.RuleKind() == SingleValue && len(node.Children) > 0 { + if len(ref) > 1 && !ref[len(ref)-1].IsGround() { // p.q[x] and p.q.s.t => check grandchildren + for _, c := range node.Children { + grandchildrenFound := false + + if len(c.Values) > 0 { + childRules := extractRules(c.Values) + for _, childRule := range childRules { + childRef := childRule.Ref() + if childRule.Head.RuleKind() == SingleValue && !childRef[len(childRef)-1].IsGround() { + // The child is a partial object rule, so it's effectively "generating" grandchildren. + grandchildrenFound = true + break + } + } + } + + if len(c.Children) > 0 { + grandchildrenFound = true + } + + if grandchildrenFound { + singleValueConflicts = node.flattenChildren() + break + } + } + } else { // p.q.s and p.q.s.t => any children are in conflict + singleValueConflicts = node.flattenChildren() + } + } + + // Multi-value rules may not have any other rules in their extent; e.g.: + // + // data.p[v] { v := ... } + // data.p.q := 42 # In direct conflict with data.p[v], which is constructing a set and cannot have values assigned to a sub-path. 
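+			// Hedged sketch of the effect (error wording follows the format
+			// strings below; the example policy is illustrative): compiling
+			//
+			//	package x
+			//	p[v] { v := "a" }
+			//	p.q := 42
+			//
+			// should produce a "multi-value rule data.x.p conflicts with
+			// [data.x.p.q]" style error.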
+ + if r.Head.RuleKind() == MultiValue && len(node.Children) > 0 { + multiValueConflicts = node.flattenChildren() + } } - name := Var(node.Key.(String)) + switch { + case singleValueConflicts != nil: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "single-value rule %v conflicts with %v", name, singleValueConflicts)) + + case multiValueConflicts != nil: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multi-value rule %v conflicts with %v", name, multiValueConflicts)) - if declared && len(node.Values) > 1 { - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule named %v redeclared at %v", name, node.Values[1].(*Rule).Loc())) - } else if len(kinds) > 1 || len(arities) > 1 { - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name)) - } else if defaultRules > 1 { - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name)) + case len(kinds) > 1 || len(arities) > 1: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name)) + + case defaultRules > 1: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name)) } return false @@ -817,12 +964,19 @@ func (c *Compiler) checkRuleConflicts() { } } + // NOTE(sr): depthfirst might better use sorted for stable errs? c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool { for _, mod := range node.Modules { for _, rule := range mod.Rules { - if childNode, ok := node.Children[String(rule.Head.Name)]; ok { + ref := rule.Head.Ref().GroundPrefix() + childNode, tail := node.find(ref) + if childNode != nil && len(tail) == 0 { for _, childMod := range childNode.Modules { - msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc()) + // Avoid recursively checking a module for equality unless we know it's a possible self-match. 
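+					// Example (illustrative only): a rule `p.q = 1` in `package a`
+					// resolves to data.a.p.q; if another file declares
+					// `package a.p.q`, that module's node is found here and the
+					// conflict error below is raised.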
+ if childMod.Equal(mod) { + continue // don't self-conflict + } + msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc()) c.err(NewError(TypeErr, mod.Package.Loc(), msg)) } } @@ -881,7 +1035,7 @@ func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error { for i, op := range expr.Operands() { have[i] = env.Get(op) } - return newArgError(expr.Loc(), f, "arity mismatch", have, want.FuncArgs()) + return newArgError(expr.Loc(), f, "arity mismatch", have, want.NamedFuncArgs()) } if act != 1 { return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d arguments", f, exp, act) @@ -1005,7 +1159,25 @@ func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema, return result, nil } -func parseSchema(schema interface{}) (types.Type, error) { +type schemaParser struct { + definitionCache map[string]*cachedDef +} + +type cachedDef struct { + properties []*types.StaticProperty +} + +func newSchemaParser() *schemaParser { + return &schemaParser{ + definitionCache: map[string]*cachedDef{}, + } +} + +func (parser *schemaParser) parseSchema(schema interface{}) (types.Type, error) { + return parser.parseSchemaWithPropertyKey(schema, "") +} + +func (parser *schemaParser) parseSchemaWithPropertyKey(schema interface{}, propertyKey string) (types.Type, error) { subSchema, ok := schema.(*gojsonschema.SubSchema) if !ok { return nil, fmt.Errorf("unexpected schema type %v", subSchema) @@ -1013,7 +1185,10 @@ func parseSchema(schema interface{}) (types.Type, error) { // Handle referenced schemas, returns directly when a $ref is found if subSchema.RefSchema != nil { - return parseSchema(subSchema.RefSchema) + if existing, ok := parser.definitionCache[subSchema.Ref.String()]; ok { + return types.NewObject(existing.properties, nil), nil + } + return parser.parseSchemaWithPropertyKey(subSchema.RefSchema, subSchema.Ref.String()) } // Handle anyOf @@ -1025,7 +1200,7 @@ func parseSchema(schema interface{}) (types.Type, error) { copySchema := *subSchema copySchemaRef := ©Schema copySchemaRef.AnyOf = nil - coreType, err := parseSchema(copySchemaRef) + coreType, err := parser.parseSchema(copySchemaRef) if err != nil { return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err) } @@ -1040,7 +1215,7 @@ func parseSchema(schema interface{}) (types.Type, error) { // Iterate through every property of AnyOf and add it to orType for _, pSchema := range subSchema.AnyOf { - newtype, err := parseSchema(pSchema) + newtype, err := parser.parseSchema(pSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) } @@ -1063,12 +1238,12 @@ func parseSchema(schema interface{}) (types.Type, error) { if err != nil { return nil, err } - return parseSchema(objectOrArrayResult) + return parser.parseSchema(objectOrArrayResult) } else if subSchema.Types.String() != allOfResult.Types.String() { return nil, fmt.Errorf("unable to merge these schemas") } } - return parseSchema(allOfResult) + return parser.parseSchema(allOfResult) } if subSchema.Types.IsTyped() { @@ -1083,15 +1258,28 @@ func parseSchema(schema interface{}) (types.Type, error) { } else if subSchema.Types.Contains("object") { if len(subSchema.PropertiesChildren) > 0 { - staticProps := make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren)) + def := &cachedDef{ + properties: make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren)), + } for _, pSchema := range subSchema.PropertiesChildren { - newtype, err := 
parseSchema(pSchema) + def.properties = append(def.properties, types.NewStaticProperty(pSchema.Property, nil)) + } + if propertyKey != "" { + parser.definitionCache[propertyKey] = def + } + for _, pSchema := range subSchema.PropertiesChildren { + newtype, err := parser.parseSchema(pSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) } - staticProps = append(staticProps, types.NewStaticProperty(pSchema.Property, newtype)) + for i, prop := range def.properties { + if prop.Key == pSchema.Property { + def.properties[i].Value = newtype + break + } + } } - return types.NewObject(staticProps, nil), nil + return types.NewObject(def.properties, nil), nil } return types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), nil @@ -1099,7 +1287,7 @@ func parseSchema(schema interface{}) (types.Type, error) { if len(subSchema.ItemsChildren) > 0 { if subSchema.ItemsChildrenIsSingleSchema { iSchema := subSchema.ItemsChildren[0] - newtype, err := parseSchema(iSchema) + newtype, err := parser.parseSchema(iSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v", iSchema) } @@ -1108,7 +1296,7 @@ func parseSchema(schema interface{}) (types.Type, error) { newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren)) for i := 0; i != len(subSchema.ItemsChildren); i++ { iSchema := subSchema.ItemsChildren[i] - newtype, err := parseSchema(iSchema) + newtype, err := parser.parseSchema(iSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v", iSchema) } @@ -1123,27 +1311,46 @@ func parseSchema(schema interface{}) (types.Type, error) { // Assume types if not specified in schema if len(subSchema.PropertiesChildren) > 0 { if err := subSchema.Types.Add("object"); err == nil { - return parseSchema(subSchema) + return parser.parseSchema(subSchema) } } else if len(subSchema.ItemsChildren) > 0 { if err := subSchema.Types.Add("array"); err == nil { - return parseSchema(subSchema) + return parser.parseSchema(subSchema) } } return types.A, nil } +func (c *Compiler) setAnnotationSet() { + // Sorting modules by name for stable error reporting + sorted := make([]*Module, 0, len(c.Modules)) + for _, mName := range c.sorted { + sorted = append(sorted, c.Modules[mName]) + } + + as, errs := BuildAnnotationSet(sorted) + for _, err := range errs { + c.err(err) + } + c.annotationSet = as +} + // checkTypes runs the type checker on all rules. The type checker builds a // TypeEnv that is stored on the compiler. func (c *Compiler) checkTypes() { // Recursion is caught in earlier step, so this cannot fail. sorted, _ := c.Graph.Sort() checker := newTypeChecker(). + WithAllowNet(c.capabilities.AllowNet). WithSchemaSet(c.schemaSet). WithInputType(c.inputType). 
WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)) - env, errs := checker.CheckTypes(c.TypeEnv, sorted) + var as *AnnotationSet + if c.useTypeCheckAnnotations { + as = c.annotationSet + } + env, errs := checker.CheckTypes(c.TypeEnv, sorted, as) for _, err := range errs { c.err(err) } @@ -1159,6 +1366,15 @@ func (c *Compiler) checkUnsafeBuiltins() { } } +func (c *Compiler) checkDeprecatedBuiltins() { + for _, name := range c.sorted { + errs := checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, c.Modules[name], c.strict) + for _, err := range errs { + c.err(err) + } + } +} + func (c *Compiler) runStage(metricName string, f func()) { if c.metrics != nil { c.metrics.Timer(metricName).Start() @@ -1188,10 +1404,10 @@ func (c *Compiler) compile() { if c.Failed() { return } - for _, s := range c.after[s.name] { - err := c.runStageAfter(s.MetricName, s.Stage) - if err != nil { + for _, a := range c.after[s.name] { + if err := c.runStageAfter(a.MetricName, a.Stage); err != nil { c.err(err) + return } } } @@ -1211,6 +1427,9 @@ func (c *Compiler) init() { for _, bi := range c.capabilities.Builtins { c.builtins[bi.Name] = bi + if c.strict && bi.IsDeprecated() { + c.deprecatedBuiltinsMap[bi.Name] = struct{}{} + } } for name, bi := range c.customBuiltins { @@ -1248,30 +1467,99 @@ func (c *Compiler) err(err *Error) { func (c *Compiler) getExports() *util.HashMap { rules := util.NewHashMap(func(a, b util.T) bool { - r1 := a.(Ref) - r2 := a.(Ref) - return r1.Equal(r2) + return a.(Ref).Equal(b.(Ref)) }, func(v util.T) int { return v.(Ref).Hash() }) for _, name := range c.sorted { mod := c.Modules[name] - rv, ok := rules.Get(mod.Package.Path) - if !ok { - rv = []Var{} - } - rvs := rv.([]Var) for _, rule := range mod.Rules { - rvs = append(rvs, rule.Head.Name) + hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix()) } - rules.Put(mod.Package.Path, rvs) } return rules } +func hashMapAdd(rules *util.HashMap, pkg, rule Ref) { + prev, ok := rules.Get(pkg) + if !ok { + rules.Put(pkg, []Ref{rule}) + return + } + for _, p := range prev.([]Ref) { + if p.Equal(rule) { + return + } + } + rules.Put(pkg, append(prev.([]Ref), rule)) +} + +func (c *Compiler) GetAnnotationSet() *AnnotationSet { + return c.annotationSet +} + +func (c *Compiler) checkDuplicateImports() { + if !c.strict { + return + } + + for _, name := range c.sorted { + mod := c.Modules[name] + processedImports := map[Var]*Import{} + + for _, imp := range mod.Imports { + name := imp.Name() + + if processed, conflict := processedImports[name]; conflict { + c.err(NewError(CompileErr, imp.Location, "import must not shadow %v", processed)) + } else { + processedImports[name] = imp + } + } + } +} + +func (c *Compiler) checkKeywordOverrides() { + for _, name := range c.sorted { + mod := c.Modules[name] + errs := checkKeywordOverrides(mod, c.strict) + for _, err := range errs { + c.err(err) + } + } +} + +func checkKeywordOverrides(node interface{}, strict bool) Errors { + if !strict { + return nil + } + + errors := Errors{} + + WalkRules(node, func(rule *Rule) bool { + name := rule.Head.Name.String() + if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) { + errors = append(errors, NewError(CompileErr, rule.Location, "rules must not shadow %v (use a different rule name)", name)) + } + return true + }) + + WalkExprs(node, func(expr *Expr) bool { + if expr.IsAssignment() { + name := expr.Operand(0).String() + if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) { + errors = append(errors, NewError(CompileErr, expr.Location, "variables must not shadow %v 
(use a different variable name)", name)) + } + } + return false + }) + + return errors +} + // resolveAllRefs resolves references in expressions to their fully qualified values. // // For instance, given the following module: @@ -1281,6 +1569,15 @@ func (c *Compiler) getExports() *util.HashMap { // p[x] { bar[_] = x } // // The reference "bar[_]" would be resolved to "data.foo.bar[_]". +// +// Ref rules are resolved, too: +// +// package a.b +// q { c.d.e == 1 } +// c.d[e] := 1 if e := "e" +// +// The reference "c.d.e" would be resolved to "data.a.b.c.d.e". + func (c *Compiler) resolveAllRefs() { rules := c.getExports() @@ -1288,9 +1585,9 @@ func (c *Compiler) resolveAllRefs() { for _, name := range c.sorted { mod := c.Modules[name] - var ruleExports []Var + var ruleExports []Ref if x, ok := rules.Get(mod.Package.Path); ok { - ruleExports = x.([]Var) + ruleExports = x.([]Ref) } globals := getGlobals(mod.Package, ruleExports, mod.Imports) @@ -1303,8 +1600,20 @@ func (c *Compiler) resolveAllRefs() { return false }) - // Once imports have been resolved, they are no longer needed. - mod.Imports = nil + if c.strict { // check for unused imports + for _, imp := range mod.Imports { + path := imp.Path.Value.(Ref) + if FutureRootDocument.Equal(path[0]) { + continue // ignore future imports + } + + for v, u := range globals { + if v.Equal(imp.Name()) && !u.used { + c.err(NewError(CompileErr, imp.Location, "%s unused", imp.String())) + } + } + } + } } if c.moduleLoader != nil { @@ -1322,6 +1631,9 @@ func (c *Compiler) resolveAllRefs() { for id, module := range parsed { c.Modules[id] = module.Copy() c.sorted = append(c.sorted, id) + if c.parsedModules != nil { + c.parsedModules[id] = module + } } sort.Strings(c.sorted) @@ -1329,6 +1641,12 @@ func (c *Compiler) resolveAllRefs() { } } +func (c *Compiler) removeImports() { + for name := range c.Modules { + c.Modules[name].Imports = nil + } +} + func (c *Compiler) initLocalVarGen() { c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules) } @@ -1352,6 +1670,65 @@ func (c *Compiler) rewriteExprTerms() { } } +func (c *Compiler) rewriteRuleHeadRefs() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + WalkRules(c.Modules[name], func(rule *Rule) bool { + + ref := rule.Head.Ref() + // NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but + // it's possible to construct Module{} instances from Golang code, so we need + // to accommodate for that, too. + if len(rule.Head.Reference) == 0 { + rule.Head.Reference = ref + } + + cannotSpeakRefs := true + for _, f := range c.capabilities.Features { + if f == FeatureRefHeadStringPrefixes { + cannotSpeakRefs = false + break + } + } + + if cannotSpeakRefs && rule.Head.Name == "" { + c.err(NewError(CompileErr, rule.Loc(), "rule heads with refs are not supported: %v", rule.Head.Reference)) + return true + } + + for i := 1; i < len(ref); i++ { + // NOTE(sr): In the first iteration, non-string values in the refs are forbidden + // except for the last position, e.g. + // OK: p.q.r[s] + // NOT OK: p[q].r.s + // TODO(sr): This is stricter than necessary. We could allow any non-var values there, + // but we'll also have to adjust the type tree, for example. 
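+			// A hedged usage sketch (the API names exist elsewhere in this
+			// package; the exact wiring is an assumption): ref heads are gated
+			// on capabilities, e.g.
+			//
+			//	c := NewCompiler().WithCapabilities(CapabilitiesForThisVersion())
+			//
+			// and a compiler whose capabilities lack FeatureRefHeadStringPrefixes
+			// rejects such heads with the "rule heads with refs are not
+			// supported" error above.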
+ if i != len(ref)-1 { // last + if _, ok := ref[i].Value.(String); !ok { + c.err(NewError(TypeErr, rule.Loc(), "rule head must only contain string terms (except for last): %v", ref[i])) + continue + } + } + + // Rewrite so that any non-scalar elements in the last position of + // the rule become vars: + // p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z } + // because that's what the RuleTree knows how to deal with. + if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) { + expr := f.Generate(ref[i]) + if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) { + rule.Head.Key = expr.Operand(0) + } + rule.Head.Reference[i] = expr.Operand(0) + rule.Body.Append(expr) + } + } + + return true + }) + } +} + func (c *Compiler) checkVoidCalls() { for _, name := range c.sorted { mod := c.Modules[name] @@ -1373,12 +1750,14 @@ func (c *Compiler) rewritePrintCalls() { WalkRules(mod, func(r *Rule) bool { safe := r.Head.Args.Vars() safe.Update(ReservedVars) - WalkBodies(r, func(b Body) bool { + vis := func(b Body) bool { for _, err := range rewritePrintCalls(c.localvargen, c.GetArity, safe, b) { c.err(err) } return false - }) + } + WalkBodies(r.Head, vis) + WalkBodies(r.Body, vis) return false }) } @@ -1407,11 +1786,11 @@ func checkVoidCalls(env *TypeEnv, x interface{}) Errors { // // For example, given the following print statement: // -// print("the value of x is:", input.x) +// print("the value of x is:", input.x) // // The expression would be rewritten to: // -// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x}) +// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x}) func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) Errors { var errs Errors @@ -1419,7 +1798,7 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V // Visit comprehension bodies recursively to ensure print statements inside // those bodies only close over variables that are safe. for i := range body { - if ContainsComprehensions(body[i]) { + if ContainsClosures(body[i]) { safe := outputVarsForBody(body[:i], getArity, globals) safe.Update(globals) WalkClosures(body[i], func(x interface{}) bool { @@ -1430,6 +1809,9 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V errs = rewritePrintCalls(gen, getArity, safe, x.Body) case *ObjectComprehension: errs = rewritePrintCalls(gen, getArity, safe, x.Body) + case *Every: + safe.Update(x.KeyValueVars()) + errs = rewritePrintCalls(gen, getArity, safe, x.Body) } return true }) @@ -1491,6 +1873,8 @@ func erasePrintCalls(node interface{}) { x.Body = erasePrintCallsInBody(x.Body) case *ObjectComprehension: x.Body = erasePrintCallsInBody(x.Body) + case *Every: + x.Body = erasePrintCallsInBody(x.Body) } return false }).Walk(node) @@ -1540,7 +1924,7 @@ func isPrintCall(x *Expr) bool { return x.IsCall() && x.Operator().Equal(Print.Ref()) } -// rewriteTermsInHead will rewrite rules so that the head does not contain any +// rewriteRefsInHead will rewrite rules so that the head does not contain any // terms that require evaluation (e.g., refs or comprehensions). If the key or // value contains one or more of these terms, the key or value will be moved // into the body and assigned to a new variable. 
The new variable will replace @@ -1598,87 +1982,350 @@ func (c *Compiler) rewriteDynamicTerms() { } } -func (c *Compiler) rewriteLocalVars() { +func (c *Compiler) parseMetadataBlocks() { + // Only parse annotations if rego.metadata built-ins are called + regoMetadataCalled := false + for _, name := range c.sorted { + mod := c.Modules[name] + WalkExprs(mod, func(expr *Expr) bool { + if isRegoMetadataChainCall(expr) || isRegoMetadataRuleCall(expr) { + regoMetadataCalled = true + } + return regoMetadataCalled + }) + + if regoMetadataCalled { + break + } + } + + if regoMetadataCalled { + // NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module + for _, name := range c.sorted { + mod := c.Modules[name] + + if len(mod.Annotations) == 0 { + var errs Errors + mod.Annotations, errs = parseAnnotations(mod.Comments) + errs = append(errs, attachAnnotationsNodes(mod)...) + for _, err := range errs { + c.err(err) + } + } + } + } +} + +func (c *Compiler) rewriteRegoMetadataCalls() { + eqFactory := newEqualityFactory(c.localvargen) + + _, chainFuncAllowed := c.builtins[RegoMetadataChain.Name] + _, ruleFuncAllowed := c.builtins[RegoMetadataRule.Name] for _, name := range c.sorted { mod := c.Modules[name] - gen := c.localvargen WalkRules(mod, func(rule *Rule) bool { + var firstChainCall *Expr + var firstRuleCall *Expr + + WalkExprs(rule, func(expr *Expr) bool { + if chainFuncAllowed && firstChainCall == nil && isRegoMetadataChainCall(expr) { + firstChainCall = expr + } else if ruleFuncAllowed && firstRuleCall == nil && isRegoMetadataRuleCall(expr) { + firstRuleCall = expr + } + return firstChainCall != nil && firstRuleCall != nil + }) - // Rewrite assignments contained in head of rule. Assignments can - // occur in rule head if they're inside a comprehension. Note, - // assigned vars in comprehensions in the head will be rewritten - // first to preserve scoping rules. For example: - // - // p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 } - // - // This behaviour is consistent scoping inside the body. For example: - // - // p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] } - nestedXform := &rewriteNestedHeadVarLocalTransform{ - gen: gen, - RewrittenVars: c.RewrittenVars, - } + chainCalled := firstChainCall != nil + ruleCalled := firstRuleCall != nil - NewGenericVisitor(nestedXform.Visit).Walk(rule.Head) + if chainCalled || ruleCalled { + body := make(Body, 0, len(rule.Body)+2) - for _, err := range nestedXform.errs { - c.err(err) - } + var metadataChainVar Var + if chainCalled { + // Create and inject metadata chain for rule - // Rewrite assignments in body. 
- used := NewVarSet() + chain, err := createMetadataChain(c.annotationSet.Chain(rule)) + if err != nil { + c.err(err) + return false + } - if rule.Head.Key != nil { - used.Update(rule.Head.Key.Vars()) - } + chain.Location = firstChainCall.Location + eq := eqFactory.Generate(chain) + metadataChainVar = eq.Operands()[0].Value.(Var) + body.Append(eq) + } - if rule.Head.Value != nil { - used.Update(rule.Head.Value.Vars()) - } + var metadataRuleVar Var + if ruleCalled { + // Create and inject metadata for rule - stack := newLocalDeclaredVars() + var metadataRuleTerm *Term - c.rewriteLocalArgVars(gen, stack, rule) + a := getPrimaryRuleAnnotations(c.annotationSet, rule) + if a != nil { + annotObj, err := a.toObject() + if err != nil { + c.err(err) + return false + } + metadataRuleTerm = NewTerm(*annotObj) + } else { + // If rule has no annotations, assign an empty object + metadataRuleTerm = ObjectTerm() + } - body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body) - for _, err := range errs { - c.err(err) + metadataRuleTerm.Location = firstRuleCall.Location + eq := eqFactory.Generate(metadataRuleTerm) + metadataRuleVar = eq.Operands()[0].Value.(Var) + body.Append(eq) + } + + for _, expr := range rule.Body { + body.Append(expr) + } + rule.Body = body + + vis := func(b Body) bool { + for _, err := range rewriteRegoMetadataCalls(&metadataChainVar, &metadataRuleVar, b, &c.RewrittenVars) { + c.err(err) + } + return false + } + WalkBodies(rule.Head, vis) + WalkBodies(rule.Body, vis) } - // For rewritten vars use the collection of all variables that - // were in the stack at some point in time. - for k, v := range stack.rewritten { - c.RewrittenVars[k] = v + return false + }) + } +} + +func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations { + annots := as.GetRuleScope(rule) + + if len(annots) == 0 { + return nil + } + + // Sort by annotation location; chain must start with annotations declared closest to rule, then going outward + sort.SliceStable(annots, func(i, j int) bool { + return annots[i].Location.Compare(annots[j].Location) > 0 + }) + + return annots[0] +} + +func rewriteRegoMetadataCalls(metadataChainVar *Var, metadataRuleVar *Var, body Body, rewrittenVars *map[Var]Var) Errors { + var errs Errors + + WalkClosures(body, func(x interface{}) bool { + switch x := x.(type) { + case *ArrayComprehension: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + case *SetComprehension: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + case *ObjectComprehension: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + case *Every: + errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars) + } + return true + }) + + for i := range body { + expr := body[i] + var metadataVar Var + + if metadataChainVar != nil && isRegoMetadataChainCall(expr) { + metadataVar = *metadataChainVar + } else if metadataRuleVar != nil && isRegoMetadataRuleCall(expr) { + metadataVar = *metadataRuleVar + } else { + continue + } + + // NOTE(johanfylling): An alternative strategy would be to walk the body and replace all operands[0] + // usages with *metadataChainVar + operands := expr.Operands() + var newExpr *Expr + if len(operands) > 0 { // There is an output var to rewrite + rewrittenVar := operands[0] + newExpr = Equality.Expr(rewrittenVar, NewTerm(metadataVar)) + } else { // No output var, just rewrite expr to metadataVar + newExpr = 
NewExpr(NewTerm(metadataVar)) + } + + newExpr.Generated = true + newExpr.Location = expr.Location + body.Set(newExpr, i) + } + + return errs +} + +func isRegoMetadataChainCall(x *Expr) bool { + return x.IsCall() && x.Operator().Equal(RegoMetadataChain.Ref()) +} + +func isRegoMetadataRuleCall(x *Expr) bool { + return x.IsCall() && x.Operator().Equal(RegoMetadataRule.Ref()) +} + +func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) { + + metaArray := NewArray() + for _, link := range chain { + p := link.Path.toArray(). + Slice(1, -1) // Dropping leading 'data' element of path + obj := NewObject( + Item(StringTerm("path"), NewTerm(p)), + ) + if link.Annotations != nil { + annotObj, err := link.Annotations.toObject() + if err != nil { + return nil, err } + obj.Insert(StringTerm("annotations"), NewTerm(*annotObj)) + } + metaArray = metaArray.Append(NewTerm(obj)) + } + + return NewTerm(metaArray), nil +} - rule.Body = body +func (c *Compiler) rewriteLocalVars() { - // Rewrite vars in head that refer to locally declared vars in the body. - localXform := rewriteHeadVarLocalTransform{declared: declared} + for _, name := range c.sorted { + mod := c.Modules[name] + gen := c.localvargen + + WalkRules(mod, func(rule *Rule) bool { + argsStack := newLocalDeclaredVars() - for i := range rule.Head.Args { - rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) + args := NewVarVisitor() + if c.strict { + args.Walk(rule.Head.Args) } + unusedArgs := args.Vars() - if rule.Head.Key != nil { - rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) + c.rewriteLocalArgVars(gen, argsStack, rule) + + // Rewrite local vars in each else-branch of the rule. + // Note: this is done instead of a walk so that we can capture any unused function arguments + // across else-branches. + for rule := rule; rule != nil; rule = rule.Else { + stack, errs := c.rewriteLocalVarsInRule(rule, unusedArgs, argsStack, gen) + + for arg := range unusedArgs { + if stack.Count(arg) > 1 { + delete(unusedArgs, arg) + } + } + + for _, err := range errs { + c.err(err) + } } - if rule.Head.Value != nil { - rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value) + if c.strict { + // Report an error for each unused function argument + for arg := range unusedArgs { + if !arg.IsWildcard() { + c.err(NewError(CompileErr, rule.Head.Location, "unused argument %v", arg)) + } + } } - return false + return true }) } } +func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) { + // Rewrite assignments contained in head of rule. Assignments can + // occur in rule head if they're inside a comprehension. Note, + // assigned vars in comprehensions in the head will be rewritten + // first to preserve scoping rules. For example: + // + // p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 } + // + // This behaviour is consistent with scoping inside the body. For example: + // + // p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] } + nestedXform := &rewriteNestedHeadVarLocalTransform{ + gen: gen, + RewrittenVars: c.RewrittenVars, + strict: c.strict, + } + + NewGenericVisitor(nestedXform.Visit).Walk(rule.Head) + + for _, err := range nestedXform.errs { + c.err(err) + } + + // Rewrite assignments in body. 
+ used := NewVarSet() + + last := rule.Head.Ref()[len(rule.Head.Ref())-1] + used.Update(last.Vars()) + + if rule.Head.Key != nil { + used.Update(rule.Head.Key.Vars()) + } + + if rule.Head.Value != nil { + valueVars := rule.Head.Value.Vars() + used.Update(valueVars) + for arg := range unusedArgs { + if valueVars.Contains(arg) { + delete(unusedArgs, arg) + } + } + } + + stack := argsStack.Copy() + + body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict) + + // For rewritten vars use the collection of all variables that + // were in the stack at some point in time. + for k, v := range stack.rewritten { + c.RewrittenVars[k] = v + } + + rule.Body = body + + // Rewrite vars in head that refer to locally declared vars in the body. + localXform := rewriteHeadVarLocalTransform{declared: declared} + + for i := range rule.Head.Args { + rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) + } + + for i := 1; i < len(rule.Head.Ref()); i++ { + rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i]) + } + if rule.Head.Key != nil { + rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) + } + + if rule.Head.Value != nil { + rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value) + } + return stack, errs +} + type rewriteNestedHeadVarLocalTransform struct { gen *localVarGenerator errs Errors RewrittenVars map[Var]Var + strict bool } func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool { @@ -1708,13 +2355,13 @@ func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool { term.Value = cpy stop = true case *ArrayComprehension: - xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs) + xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs, xform.strict) stop = true case *SetComprehension: - xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs) + xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs, xform.strict) stop = true case *ObjectComprehension: - xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs) + xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs, xform.strict) stop = true } @@ -1773,7 +2420,9 @@ func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor { switch v := t.Value.(type) { case Var: gv, ok := vis.stack.Declared(v) - if !ok { + if ok { + vis.stack.Seen(v) + } else { gv = vis.gen.Generate() vis.stack.Insert(v, gv, argVar) } @@ -1813,7 +2462,7 @@ func (c *Compiler) rewriteWithModifiers() { if !ok { return x, nil } - body, err := rewriteWithModifiersInBody(c, f, body) + body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body) if err != nil { c.err(err) } @@ -1860,6 +2509,11 @@ func newQueryCompiler(compiler *Compiler) QueryCompiler { return qc } +func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler { + qc.compiler.WithStrict(strict) + return qc +} + func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler { qc.enablePrintStatements = yes return qc @@ -1921,6 +2575,7 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) { metricName string f func(*QueryContext, Body) (Body, error) }{ + {"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides}, {"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs}, {"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars}, 
{"CheckVoidCalls", "query_compile_stage_check_void_calls", qc.checkVoidCalls}, @@ -1933,6 +2588,7 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) { {"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms}, {"CheckTypes", "query_compile_stage_check_types", qc.checkTypes}, {"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins}, + {"CheckDeprecatedBuiltins", "query_compile_stage_check_deprecated_builtins", qc.checkDeprecatedBuiltins}, {"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices}, } @@ -1968,9 +2624,16 @@ func (qc *queryCompiler) applyErrorLimit(err error) error { return err } +func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) { + if errs := checkKeywordOverrides(body, qc.compiler.strict); len(errs) > 0 { + return nil, errs + } + return body, nil +} + func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) { - var globals map[Var]Ref + var globals map[Var]*usedRef if qctx != nil { pkg := qctx.Package @@ -1980,10 +2643,10 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)} } if pkg != nil { - var ruleExports []Var + var ruleExports []Ref rules := qc.compiler.getExports() if exist, ok := rules.Get(pkg.Path); ok { - ruleExports = exist.([]Var) + ruleExports = exist.([]Ref) } globals = getGlobals(qctx.Package, ruleExports, qctx.Imports) @@ -2020,7 +2683,7 @@ func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, err func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) { gen := newLocalVarGenerator("q", body) stack := newLocalDeclaredVars() - body, _, err := rewriteLocalVars(gen, stack, nil, body) + body, _, err := rewriteLocalVars(gen, stack, nil, body, qc.compiler.strict) if len(err) != 0 { return nil, err } @@ -2083,13 +2746,22 @@ func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) { } func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) { - var unsafe map[string]struct{} + errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body) + if len(errs) > 0 { + return nil, errs + } + return body, nil +} + +func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} { if qc.unsafeBuiltins != nil { - unsafe = qc.unsafeBuiltins - } else { - unsafe = qc.compiler.unsafeBuiltinsMap + return qc.unsafeBuiltins } - errs := checkUnsafeBuiltins(unsafe, body) + return qc.compiler.unsafeBuiltinsMap +} + +func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) { + errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body, qc.compiler.strict) if len(errs) > 0 { return nil, errs } @@ -2098,7 +2770,7 @@ func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) { f := newEqualityFactory(newLocalVarGenerator("q", body)) - body, err := rewriteWithModifiersInBody(qc.compiler, f, body) + body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body) if err != nil { return nil, Errors{err} } @@ -2357,13 +3029,29 @@ type ModuleTreeNode struct { Hide bool } +func (n *ModuleTreeNode) String() string { + var rules []string + for _, m := range n.Modules { + for _, r := range m.Rules { + rules = append(rules, r.Head.String()) + } + } + 
return fmt.Sprintf("", n.Key, n.Children, rules, n.Hide) +} + // NewModuleTree returns a new ModuleTreeNode that represents the root // of the module tree populated with the given modules. func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { root := &ModuleTreeNode{ Children: map[Value]*ModuleTreeNode{}, } - for _, m := range mods { + names := make([]string, 0, len(mods)) + for name := range mods { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + m := mods[name] node := root for i, x := range m.Package.Path { c, ok := node.Children[x.Value] @@ -2395,13 +3083,43 @@ func (n *ModuleTreeNode) Size() int { return s } +// Child returns n's child with key k. +func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode { + switch k.(type) { + case String, Var: + return n.Children[k] + } + return nil +} + +// Find dereferences ref along the tree. ref[0] is converted to a String +// for convenience. +func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) { + if v, ok := ref[0].Value.(Var); ok { + ref = Ref{StringTerm(string(v))}.Concat(ref[1:]) + } + node := n + for i, r := range ref { + next := node.child(r.Value) + if next == nil { + tail := make(Ref, len(ref)-i) + tail[0] = VarTerm(string(ref[i].Value.(String))) + copy(tail[1:], ref[i+1:]) + return node, tail + } + node = next + } + return node, nil +} + // DepthFirst performs a depth-first traversal of the module tree rooted at n. // If f returns true, traversal will not continue to the children of n. -func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) { - if !f(n) { - for _, node := range n.Children { - node.DepthFirst(f) - } +func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) } } @@ -2415,78 +3133,164 @@ type TreeNode struct { Hide bool } +func (n *TreeNode) String() string { + return fmt.Sprintf("", n.Key, n.Values, n.Sorted, n.Hide) +} + // NewRuleTree returns a new TreeNode that represents the root // of the rule tree populated with the given rules. func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { + root := TreeNode{ + Key: mtree.Key, + } + + mtree.DepthFirst(func(m *ModuleTreeNode) bool { + for _, mod := range m.Modules { + if len(mod.Rules) == 0 { + root.add(mod.Package.Path, nil) + } + for _, rule := range mod.Rules { + root.add(rule.Ref().GroundPrefix(), rule) + } + } + return false + }) - ruleSets := map[String][]util.T{} + // ensure that data.system's TreeNode is hidden + node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey))) + if len(tail) == 0 { // found + node.Hide = true + } - // Build rule sets for this package. - for _, mod := range mtree.Modules { - for _, rule := range mod.Rules { - key := String(rule.Head.Name) - ruleSets[key] = append(ruleSets[key], rule) + root.DepthFirst(func(x *TreeNode) bool { + x.sort() + return false + }) + + return &root +} + +func (n *TreeNode) add(path Ref, rule *Rule) { + node, tail := n.find(path) + if len(tail) > 0 { + sub := treeNodeFromRef(tail, rule) + if node.Children == nil { + node.Children = make(map[Value]*TreeNode, 1) + } + node.Children[sub.Key] = sub + node.Sorted = append(node.Sorted, sub.Key) + } else { + if rule != nil { + node.Values = append(node.Values, rule) } } +} + +// Size returns the number of rules in the tree. +func (n *TreeNode) Size() int { + s := len(n.Values) + for _, c := range n.Children { + s += c.Size() + } + return s +} + +// Child returns n's child with key k. 
+func (n *TreeNode) Child(k Value) *TreeNode { + switch k.(type) { + case Ref, Call: + return nil + default: + return n.Children[k] + } +} - // Each rule set becomes a leaf node. - children := map[Value]*TreeNode{} - sorted := make([]Value, 0, len(ruleSets)) +// Find dereferences ref along the tree +func (n *TreeNode) Find(ref Ref) *TreeNode { + node := n + for _, r := range ref { + node = node.Child(r.Value) + if node == nil { + return nil + } + } + return node +} - for key, rules := range ruleSets { - sorted = append(sorted, key) - children[key] = &TreeNode{ - Key: key, - Children: nil, - Values: rules, +// Iteratively dereferences ref along the node's subtree. +// - If matching fails immediately, the tail will contain the full ref. +// - Partial matching will result in a tail of non-zero length. +// - A complete match will result in a 0 length tail. +func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) { + node := n + for i := range ref { + next := node.Child(ref[i].Value) + if next == nil { + tail := make(Ref, len(ref)-i) + copy(tail, ref[i:]) + return node, tail } + node = next } + return node, nil +} - // Each module in subpackage becomes child node. - for key, child := range mtree.Children { - sorted = append(sorted, key) - children[child.Key] = NewRuleTree(child) +// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If +// f returns true, traversal will not continue to the children of n. +func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) } +} - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Compare(sorted[j]) < 0 +func (n *TreeNode) sort() { + sort.Slice(n.Sorted, func(i, j int) bool { + return n.Sorted[i].Compare(n.Sorted[j]) < 0 }) - - return &TreeNode{ - Key: mtree.Key, - Values: nil, - Children: children, - Sorted: sorted, - Hide: mtree.Hide, - } } -// Size returns the number of rules in the tree. -func (n *TreeNode) Size() int { - s := len(n.Values) - for _, c := range n.Children { - s += c.Size() +func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode { + depth := len(ref) - 1 + key := ref[depth].Value + node := &TreeNode{ + Key: key, + Children: nil, + } + if rule != nil { + node.Values = []util.T{rule} } - return s -} -// Child returns n's child with key k. -func (n *TreeNode) Child(k Value) *TreeNode { - switch k.(type) { - case String, Var: - return n.Children[k] + for i := len(ref) - 2; i >= 0; i-- { + key := ref[i].Value + node = &TreeNode{ + Key: key, + Children: map[Value]*TreeNode{ref[i+1].Value: node}, + Sorted: []Value{ref[i+1].Value}, + } } - return nil + return node } -// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If -// f returns true, traversal will not continue to the children of n. -func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) { - if !f(n) { - for _, node := range n.Children { - node.DepthFirst(f) - } +// flattenChildren flattens all children's rule refs into a sorted array. +func (n *TreeNode) flattenChildren() []Ref { + ret := newRefSet() + for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away + sub.DepthFirst(func(x *TreeNode) bool { + for _, r := range x.Values { + rule := r.(*Rule) + ret.AddPrefix(rule.Ref()) + } + return false + }) } + + sort.Slice(ret.s, func(i, j int) bool { + return ret.s[i].Compare(ret.s[j]) < 0 + }) + return ret.s } // Graph represents the graph of dependencies between rules. 
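+// For instance (illustrative): with `p { q }` and `q { true }` in one package,
+// an edge is added from p's rule to q's rule, and Sort() yields an order the
+// type checker can process.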
@@ -2532,7 +3336,7 @@ func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph { }) } - // Walk over all rules, add them to graph, and build adjencency lists. + // Walk over all rules, add them to graph, and build adjacency lists. for _, module := range modules { WalkRules(module, func(a *Rule) bool { graph.addNode(a) @@ -2752,13 +3556,10 @@ func (vs unsafeVars) Slice() (result []unsafePair) { // contains a mapping of expressions to unsafe variables in those expressions. func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { - body, unsafe := reorderBodyForClosures(arity, globals, body) - if len(unsafe) != 0 { - return nil, unsafe - } - - reordered := Body{} + bodyVars := body.Vars(SafetyCheckVisitorParams) + reordered := make(Body, 0, len(body)) safe := VarSet{} + unsafe := unsafeVars{} for _, e := range body { for v := range e.Vars(SafetyCheckVisitorParams) { @@ -2778,10 +3579,23 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo continue } - safe.Update(outputVarsForExpr(e, arity, safe)) + ovs := outputVarsForExpr(e, arity, safe) + + // check closures: is this expression closing over variables that + // haven't been made safe by what's already included in `reordered`? + vs := unsafeVarsInClosures(e, arity, safe) + cv := vs.Intersect(bodyVars).Diff(globals) + uv := cv.Diff(outputVarsForBody(reordered, arity, safe)) + + if len(uv) > 0 { + if uv.Equal(ovs) { // special case "closure-self" + continue + } + unsafe.Set(e, uv) + } for v := range unsafe[e] { - if safe.Contains(v) { + if ovs.Contains(v) || safe.Contains(v) { delete(unsafe[e], v) } } @@ -2789,10 +3603,11 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo if len(unsafe[e]) == 0 { delete(unsafe, e) reordered.Append(e) + safe.Update(ovs) // this expression's outputs are safe } } - if len(reordered) == n { + if len(reordered) == n { // fixed point, could not add any expr of body break } } @@ -2827,7 +3642,8 @@ type bodySafetyTransformer struct { } func (xform *bodySafetyTransformer) Visit(x interface{}) bool { - if term, ok := x.(*Term); ok { + switch term := x.(type) { + case *Term: switch x := term.Value.(type) { case *object: cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) { @@ -2857,6 +3673,12 @@ func (xform *bodySafetyTransformer) Visit(x interface{}) bool { xform.reorderSetComprehensionSafety(x) return true } + case *Expr: + if ev, ok := term.Terms.(*Every); ok { + xform.globals.Update(ev.KeyValueVars()) + ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body) + return true + } } return false } @@ -2892,51 +3714,20 @@ func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetCompreh sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body) } -// reorderBodyForClosures returns a copy of the body ordered such that -// expressions (such as array comprehensions) that close over variables are ordered -// after other expressions that contain the same variable in an output position. -func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { - - reordered := Body{} - unsafe := unsafeVars{} - - for { - n := len(reordered) - - for _, e := range body { - if reordered.Contains(e) { - continue - } - - // Collect vars that are contained in closures within this - // expression. 
- vs := VarSet{} - WalkClosures(e, func(x interface{}) bool { - vis := &VarVisitor{vars: vs} - vis.Walk(x) - return true - }) - - // Compute vars that are closed over from the body but not yet - // contained in the output position of an expression in the reordered - // body. These vars are considered unsafe. - cv := vs.Intersect(body.Vars(SafetyCheckVisitorParams)).Diff(globals) - uv := cv.Diff(outputVarsForBody(reordered, arity, globals)) - - if len(uv) == 0 { - reordered = append(reordered, e) - delete(unsafe, e) - } else { - unsafe.Set(e, uv) - } - } - - if len(reordered) == n { - break +// unsafeVarsInClosures collects vars that are contained in closures within +// this expression. +func unsafeVarsInClosures(e *Expr, arity func(Ref) int, safe VarSet) VarSet { + vs := VarSet{} + WalkClosures(e, func(x interface{}) bool { + vis := &VarVisitor{vars: vs} + if ev, ok := x.(*Every); ok { + vis.Walk(ev.Body) + return true } - } - - return reordered, unsafe + vis.Walk(x) + return true + }) + return vs } // OutputVarsFromBody returns all variables which are the "output" for @@ -2970,15 +3761,11 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet { // With modifier inputs must be safe. for _, with := range expr.With { - unsafe := false - WalkVars(with, func(v Var) bool { - if !safe.Contains(v) { - unsafe = true - return true - } - return false - }) - if unsafe { + vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) + vis.Walk(with) + vars := vis.Vars() + unsafe := vars.Diff(safe) + if len(unsafe) > 0 { return VarSet{} } } @@ -3002,6 +3789,8 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet { } return outputVarsForExprCall(expr, ar, safe, terms) + case *Every: + return outputVarsForTerms(terms.Domain, safe) default: panic("illegal expression") } @@ -3029,13 +3818,13 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va return output } - vis := NewVarVisitor().WithParams(VarVisitorParams{ + params := VarVisitorParams{ SkipClosures: true, SkipSets: true, SkipObjectKeys: true, SkipRefHead: true, - }) - + } + vis := NewVarVisitor().WithParams(params) vis.Walk(Args(terms[:numInputTerms])) unsafe := vis.Vars().Diff(output).Diff(safe) @@ -3043,19 +3832,13 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va return VarSet{} } - vis = NewVarVisitor().WithParams(VarVisitorParams{ - SkipRefHead: true, - SkipSets: true, - SkipObjectKeys: true, - SkipClosures: true, - }) - + vis = NewVarVisitor().WithParams(params) vis.Walk(Args(terms[numInputTerms:])) output.Update(vis.vars) return output } -func outputVarsForTerms(expr *Expr, safe VarSet) VarSet { +func outputVarsForTerms(expr interface{}, safe VarSet) VarSet { output := VarSet{} WalkTerms(expr, func(x *Term) bool { switch r := x.Value.(type) { @@ -3121,31 +3904,23 @@ func (l *localVarGenerator) Generate() Var { } } -func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]Ref { +func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef { - globals := map[Var]Ref{} + globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports // Populate globals with exports within the package. - for _, v := range rules { - global := append(Ref{}, pkg.Path...) - global = append(global, &Term{Value: String(v)}) - globals[v] = global + for _, ref := range rules { + v := ref[0].Value.(Var) + globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))} } // Populate globals with imports. 
- for _, i := range imports { - if len(i.Alias) > 0 { - path := i.Path.Value.(Ref) - globals[i.Alias] = path - } else { - path := i.Path.Value.(Ref) - if len(path) == 1 { - globals[path[0].Value.(Var)] = path - } else { - v := path[len(path)-1].Value.(String) - globals[Var(v)] = path - } + for _, imp := range imports { + path := imp.Path.Value.(Ref) + if FutureRootDocument.Equal(path[0]) { + continue // ignore future imports } + globals[imp.Name()] = &usedRef{ref: path} } return globals @@ -3158,14 +3933,14 @@ func requiresEval(x *Term) bool { return ContainsRefs(x) || ContainsComprehensions(x) } -func resolveRef(globals map[Var]Ref, ignore *declaredVarStack, ref Ref) Ref { +func resolveRef(globals map[Var]*usedRef, ignore *declaredVarStack, ref Ref) Ref { r := Ref{} for i, x := range ref { switch v := x.Value.(type) { case Var: if g, ok := globals[v]; ok && !ignore.Contains(v) { - cpy := g.Copy() + cpy := g.ref.Copy() for i := range cpy { cpy[i].SetLocation(x.Location) } @@ -3174,6 +3949,7 @@ func resolveRef(globals map[Var]Ref, ignore *declaredVarStack, ref Ref) Ref { } else { r = append(r, NewTerm(cpy).SetLocation(x.Location)) } + g.used = true } else { r = append(r, x) } @@ -3187,7 +3963,12 @@ func resolveRef(globals map[Var]Ref, ignore *declaredVarStack, ref Ref) Ref { return r } -func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error { +type usedRef struct { + ref Ref + used bool +} + +func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error { ignore := &declaredVarStack{} vars := NewVarSet() @@ -3238,6 +4019,10 @@ func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error { ignore.Push(vars) ignore.Push(declaredVars(rule.Body)) + ref := rule.Head.Ref() + for i := 1; i < len(ref); i++ { + ref[i] = resolveRefsInTerm(globals, ignore, ref[i]) + } if rule.Head.Key != nil { rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key) } @@ -3250,15 +4035,15 @@ func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error { return nil } -func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body { - r := Body{} +func resolveRefsInBody(globals map[Var]*usedRef, ignore *declaredVarStack, body Body) Body { + r := make([]*Expr, 0, len(body)) for _, expr := range body { r = append(r, resolveRefsInExpr(globals, ignore, expr)) } return r } -func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr) *Expr { +func resolveRefsInExpr(globals map[Var]*usedRef, ignore *declaredVarStack, expr *Expr) *Expr { cpy := *expr switch ts := expr.Terms.(type) { case *Term: @@ -3273,6 +4058,20 @@ func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr if val, ok := ts.Symbols[0].Value.(Call); ok { cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}} } + case *Every: + locals := NewVarSet() + if ts.Key != nil { + locals.Update(ts.Key.Vars()) + } + locals.Update(ts.Value.Vars()) + ignore.Push(locals) + cpy.Terms = &Every{ + Key: ts.Key.Copy(), // TODO(sr): do more? + Value: ts.Value.Copy(), // TODO(sr): do more? 
+ Domain: resolveRefsInTerm(globals, ignore, ts.Domain), + Body: resolveRefsInBody(globals, ignore, ts.Body), + } + ignore.Pop() } for _, w := range cpy.With { w.Target = resolveRefsInTerm(globals, ignore, w.Target) @@ -3281,14 +4080,15 @@ func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr return &cpy } -func resolveRefsInTerm(globals map[Var]Ref, ignore *declaredVarStack, term *Term) *Term { +func resolveRefsInTerm(globals map[Var]*usedRef, ignore *declaredVarStack, term *Term) *Term { switch v := term.Value.(type) { case Var: if g, ok := globals[v]; ok && !ignore.Contains(v) { - cpy := g.Copy() + cpy := g.ref.Copy() for i := range cpy { cpy[i].SetLocation(term.Location) } + g.used = true return NewTerm(cpy).SetLocation(term.Location) } return term @@ -3353,7 +4153,7 @@ func resolveRefsInTerm(globals map[Var]Ref, ignore *declaredVarStack, term *Term } } -func resolveRefsInTermArray(globals map[Var]Ref, ignore *declaredVarStack, terms *Array) []*Term { +func resolveRefsInTermArray(globals map[Var]*usedRef, ignore *declaredVarStack, terms *Array) []*Term { cpy := make([]*Term, terms.Len()) for i := 0; i < terms.Len(); i++ { cpy[i] = resolveRefsInTerm(globals, ignore, terms.Elem(i)) @@ -3361,7 +4161,7 @@ func resolveRefsInTermArray(globals map[Var]Ref, ignore *declaredVarStack, terms return cpy } -func resolveRefsInTermSlice(globals map[Var]Ref, ignore *declaredVarStack, terms []*Term) []*Term { +func resolveRefsInTermSlice(globals map[Var]*usedRef, ignore *declaredVarStack, terms []*Term) []*Term { cpy := make([]*Term, len(terms)) for i := 0; i < len(terms); i++ { cpy[i] = resolveRefsInTerm(globals, ignore, terms[i]) @@ -3520,11 +4320,14 @@ func rewriteEquals(x interface{}) { func rewriteDynamics(f *equalityFactory, body Body) Body { result := make(Body, 0, len(body)) for _, expr := range body { - if expr.IsEquality() { + switch { + case expr.IsEquality(): result = rewriteDynamicsEqExpr(f, expr, result) - } else if expr.IsCall() { + case expr.IsCall(): result = rewriteDynamicsCallExpr(f, expr, result) - } else { + case expr.IsEvery(): + result = rewriteDynamicsEveryExpr(f, expr, result) + default: result = rewriteDynamicsTermExpr(f, expr, result) } } @@ -3554,6 +4357,13 @@ func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body { return appendExpr(result, expr) } +func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body { + ev := expr.Terms.(*Every) + result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result) + ev.Body = rewriteDynamics(f, ev.Body) + return appendExpr(result, expr) +} + func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body { term := expr.Terms.(*Term) result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result) @@ -3700,6 +4510,21 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) { result = append(result, extras...) } result = append(result, expr) + case *Every: + var extras []*Expr + if _, ok := terms.Domain.Value.(Call); ok { + extras, terms.Domain = expandExprTerm(gen, terms.Domain) + } else { + term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location) + eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location) + eq.Generated = true + eq.With = expr.With + extras = append(extras, eq) + terms.Domain = term + } + terms.Body = rewriteExprTermsInBody(gen, terms.Body) + result = append(result, extras...) 
+ result = append(result, expr) } return } @@ -3817,8 +4642,8 @@ type localDeclaredVars struct { // rewritten contains a mapping of *all* user-defined variables // that have been rewritten whereas vars contains the state - // from the current query (not not any nested queries, and all - // vars seen). + // from the current query (not any nested queries, and all vars + // seen). rewritten map[Var]Var } @@ -3836,6 +4661,7 @@ type declaredVarSet struct { vs map[Var]Var reverse map[Var]Var occurrence map[Var]varOccurrence + count map[Var]int } func newDeclaredVarSet() *declaredVarSet { @@ -3843,6 +4669,7 @@ func newDeclaredVarSet() *declaredVarSet { vs: map[Var]Var{}, reverse: map[Var]Var{}, occurrence: map[Var]varOccurrence{}, + count: map[Var]int{}, } } @@ -3853,6 +4680,35 @@ func newLocalDeclaredVars() *localDeclaredVars { } } +func (s *localDeclaredVars) Copy() *localDeclaredVars { + stack := &localDeclaredVars{ + vars: []*declaredVarSet{}, + rewritten: map[Var]Var{}, + } + + for i := range s.vars { + stack.vars = append(stack.vars, newDeclaredVarSet()) + for k, v := range s.vars[i].vs { + stack.vars[0].vs[k] = v + } + for k, v := range s.vars[i].reverse { + stack.vars[0].reverse[k] = v + } + for k, v := range s.vars[i].count { + stack.vars[0].count[k] = v + } + for k, v := range s.vars[i].occurrence { + stack.vars[0].occurrence[k] = v + } + } + + for k, v := range s.rewritten { + stack.rewritten[k] = v + } + + return stack +} + func (s *localDeclaredVars) Push() { s.vars = append(s.vars, newDeclaredVarSet()) } @@ -3874,6 +4730,8 @@ func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) { elem.reverse[y] = x elem.occurrence[x] = occurrence + elem.count[x] = 1 + // If the variable has been rewritten (where x != y, with y being // the generated value), store it in the map of rewritten vars. // Assume that the generated values are unique for the compilation. @@ -3908,6 +4766,30 @@ func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) { return newVar, false } +// Seen marks x as seen by incrementing its counter +func (s localDeclaredVars) Seen(x Var) { + for i := len(s.vars) - 1; i >= 0; i-- { + dvs := s.vars[i] + if c, ok := dvs.count[x]; ok { + dvs.count[x] = c + 1 + return + } + } + + s.vars[len(s.vars)-1].count[x] = 1 +} + +// Count returns how many times x has been seen +func (s localDeclaredVars) Count(x Var) int { + for i := len(s.vars) - 1; i >= 0; i-- { + if c, ok := s.vars[i].count[x]; ok { + return c + } + } + + return 0 +} + // rewriteLocalVars rewrites bodies to remove assignment/declaration // expressions. For example: // @@ -3918,24 +4800,27 @@ func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) { // __local0__ = 1; p[__local0__] // // During rewriting, assignees are validated to prevent use before declaration. 
-func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body) (Body, map[Var]Var, Errors) { +func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) { var errs Errors - body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs) - return body, stack.Pop().vs, errs + body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs, strict) + return body, stack.Peek().vs, errs } -func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors) (Body, Errors) { +func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) { var cpy Body for i := range body { var expr *Expr - if body[i].IsAssignment() { - expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs) - } else if _, ok := body[i].Terms.(*SomeDecl); ok { - expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs) - } else { - expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs) + switch { + case body[i].IsAssignment(): + expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict) + case body[i].IsSome(): + expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict) + case body[i].IsEvery(): + expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict) + default: + expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict) } if expr != nil { cpy.Append(expr) @@ -3949,10 +4834,56 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u cpy.Append(NewExpr(BooleanTerm(true))) } - return cpy, checkUnusedDeclaredVars(body[0].Loc(), stack, used, cpy, errs) + errs = checkUnusedAssignedVars(body, stack, used, errs, strict) + return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs) +} + +func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors { + + if !strict || len(errs) > 0 { + return errs + } + + dvs := stack.Peek() + unused := NewVarSet() + + for v, occ := range dvs.occurrence { + // A var that was assigned in this scope must have been seen (used) more than once (the time of assignment) in + // the same, or nested, scope to be counted as used. + if !v.IsWildcard() && stack.Count(v) <= 1 && occ == assignedVar { + unused.Add(dvs.vs[v]) + } + } + + rewrittenUsed := NewVarSet() + for v := range used { + if gv, ok := stack.Declared(v); ok { + rewrittenUsed.Add(gv) + } else { + rewrittenUsed.Add(v) + } + } + + unused = unused.Diff(rewrittenUsed) + + for _, gv := range unused.Sorted() { + found := false + for i := range body { + if body[i].Vars(VarVisitorParams{}).Contains(gv) { + errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", dvs.reverse[gv])) + found = true + break + } + } + if !found { + errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", dvs.reverse[gv])) + } + } + + return errs } -func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors { +func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors { // NOTE(tsandall): Do not generate more errors if there are existing // declaration errors. 
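The strict-mode check above is what produces the new "assigned var ... unused" compile errors. A rough sketch of how it surfaces through the public API, assuming the compiler's WithStrict option (the module body is illustrative):

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	c := ast.NewCompiler().WithStrict(true)
	c.Compile(map[string]*ast.Module{
		"example.rego": ast.MustParseModule(`package example

p {
	x := 1 # assigned here, never used afterwards
	true
}
`),
	})

	// In strict mode this fails with "rego_compile_error: assigned var x
	// unused"; without WithStrict(true) the module compiles.
	fmt.Println(c.Failed(), c.Errors)
}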
@@ -3982,13 +4913,69 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe unused := declared.Diff(bodyvars).Diff(used) for _, gv := range unused.Sorted() { - errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", dvs.reverse[gv])) + rv := dvs.reverse[gv] + if !rv.IsGenerated() { + // Scan through body exprs, looking for a match between the + // bad var's original name, and each expr's declared vars. + foundUnusedVarByName := false + for i := range body { + varsDeclaredInExpr := declaredVars(body[i]) + if varsDeclaredInExpr.Contains(dvs.reverse[gv]) { + // TODO(philipc): Clean up the offset logic here when the parser + // reports more accurate locations. + errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv])) + foundUnusedVarByName = true + break + } + } + // Default error location returned. + if !foundUnusedVarByName { + errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv])) + } + } } return errs } -func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) { +func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { + e := expr.Copy() + every := e.Terms.(*Every) + + errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict) + + stack.Push() + defer stack.Pop() + + // if the key exists, rewrite + if every.Key != nil { + if v := every.Key.Value.(Var); !v.IsWildcard() { + gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) + if err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) + } + every.Key.Value = gv + } + } else { // if the key doesn't exist, add dummy local + every.Key = NewTerm(g.Generate()) + } + + // value is always present + if v := every.Value.Value.(Var); !v.IsWildcard() { + gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) + if err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) + } + every.Value.Value = gv + } + + used := NewVarSet() + every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict) + + return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict) +} + +func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { e := expr.Copy() decl := e.Terms.(*SomeDecl) for i := range decl.Symbols { @@ -4026,20 +5013,20 @@ func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, ex return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) } } - return rewriteDeclaredVarsInExpr(g, stack, e, errs) + return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict) } } return nil, errs } -func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) { +func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { vis := NewGenericVisitor(func(x interface{}) bool { var stop bool switch x := x.(type) { case *Term: - stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs) + stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs, strict) case *With: - _, errs = rewriteDeclaredVarsInTerm(g, stack, x.Value, errs) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, x.Value, errs, strict) stop = true } return stop @@ -4048,7 +5035,7 @@ func 
rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, e return expr, errs } -func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) { +func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { if expr.Negated { errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression")) @@ -4064,10 +5051,10 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e // Rewrite terms on right hand side capture seen vars and recursively // process comprehensions before left hand side is processed. Also // rewrite with modifier. - errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs, strict) for _, w := range expr.With { - errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict) } // Rewrite vars on left hand side with unique names. Catch redeclaration @@ -4114,11 +5101,12 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e return expr, errs } -func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) (bool, Errors) { +func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) (bool, Errors) { switch v := term.Value.(type) { case Var: if gv, ok := stack.Declared(v); ok { term.Value = gv + stack.Seen(v) } else if stack.Occurrence(v) == newVar { stack.Insert(v, v, seenVar) } @@ -4133,69 +5121,90 @@ func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, t return true, errs } return false, errs + case Call: + ref := v[0] + WalkVars(ref, func(v Var) bool { + if gv, ok := stack.Declared(v); ok && !gv.Equal(v) { + // We will rewrite the ref of a function call, which is never ok since we don't have first-class functions. 
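+ // For example, given `f := 1; f(1)`, the operator `f` would otherwise be
+ // rewritten to the generated local (say `__local0__`), and `__local0__(1)`
+ // is not callable, so we report the shadowing here instead.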
+ errs = append(errs, NewError(CompileErr, term.Location, "called function %s shadowed", ref)) + return true + } + return false + }) + return false, errs case *object: cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { kcpy := k.Copy() - errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs, strict) return kcpy, v, nil }) term.Value = cpy case Set: cpy, _ := v.Map(func(elem *Term) (*Term, error) { elemcpy := elem.Copy() - errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs, strict) return elemcpy, nil }) term.Value = cpy case *ArrayComprehension: - errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs) + errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs, strict) case *SetComprehension: - errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs) + errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs, strict) case *ObjectComprehension: - errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs) + errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs, strict) default: return false, errs } return true, errs } -func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) Errors { +func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) Errors { WalkNodes(term, func(n Node) bool { var stop bool switch n := n.(type) { case *With: - _, errs = rewriteDeclaredVarsInTerm(g, stack, n.Value, errs) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, n.Value, errs, strict) stop = true case *Term: - stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs) + stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs, strict) } return stop }) return errs } -func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors) Errors { +func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors, strict bool) Errors { + used := NewVarSet() + used.Update(v.Term.Vars()) + stack.Push() - v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs) + v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict) stack.Pop() return errs } -func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors) Errors { +func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors, strict bool) Errors { + used := NewVarSet() + used.Update(v.Term.Vars()) + stack.Push() - v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs) + v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict) stack.Pop() return errs } -func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors) Errors { +func 
rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors, strict bool) Errors { + used := NewVarSet() + used.Update(v.Key.Vars()) + used.Update(v.Value.Vars()) + stack.Push() - v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs) + v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs, strict) stack.Pop() return errs } @@ -4219,10 +5228,10 @@ func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, o // rewriteWithModifiersInBody will rewrite the body so that with modifiers do // not contain terms that require evaluation as values. If this function // encounters an invalid with modifier target then it will raise an error. -func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) { +func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) { var result Body for i := range body { - exprs, err := rewriteWithModifier(c, f, body[i]) + exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, body[i]) if err != nil { return nil, err } @@ -4237,63 +5246,122 @@ func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Bod return result, nil } -func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { +func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { var result []*Expr for i := range expr.With { - err := validateTarget(c, expr.With[i].Target) + eval, err := validateWith(c, unsafeBuiltinsMap, expr, i) if err != nil { return nil, err } - if requiresEval(expr.With[i].Value) { + if eval { eq := f.Generate(expr.With[i].Value) result = append(result, eq) expr.With[i].Value = eq.Operand(0) } } - // If any of the with modifiers in this expression were rewritten then result - // will be non-empty. In this case, the expression will have been modified and - // it should also be added to the result. 
- if len(result) > 0 {
- result = append(result, expr)
- }
- return result, nil
+ return append(result, expr), nil
}

-func validateTarget(c *Compiler, term *Term) *Error {
- if !isInputRef(term) && !isDataRef(term) {
- return NewError(TypeErr, term.Location, "with keyword target must start with %v or %v", InputRootDocument, DefaultRootDocument)
+func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) {
+ target, value := expr.With[i].Target, expr.With[i].Value
+
+ // Ensure that values that are built-ins are rewritten to Ref (not Var)
+ if v, ok := value.Value.(Var); ok {
+ if _, ok := c.builtins[v.String()]; ok {
+ value.Value = Ref([]*Term{NewTerm(v)})
+ }
+ }
+ isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target)
+ if err != nil {
+ return false, err
}
- if isDataRef(term) {
- ref := term.Value.(Ref)
+ switch {
+ case isDataRef(target):
+ ref := target.Value.(Ref)
node := c.RuleTree
for i := 0; i < len(ref)-1; i++ {
child := node.Child(ref[i].Value)
if child == nil {
break
} else if len(child.Values) > 0 {
- return NewError(CompileErr, term.Loc(), "with keyword cannot partially replace virtual document(s)")
+ return false, NewError(CompileErr, target.Loc(), "with keyword cannot partially replace virtual document(s)")
}
node = child
}
if node != nil {
+ // NOTE(sr): at this point in the compiler stages, we don't have a fully-populated
+ // TypeEnv yet -- so we have to make do with this check to see if the replacement
+ // target is a function. It's probably wrong for arity-0 functions, but those are
+ // an edge case anyway.
if child := node.Child(ref[len(ref)-1].Value); child != nil {
- for _, value := range child.Values {
- if len(value.(*Rule).Head.Args) > 0 {
- return NewError(CompileErr, term.Loc(), "with keyword cannot replace functions")
+ for _, v := range child.Values {
+ if len(v.(*Rule).Head.Args) > 0 {
+ if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
+ return false, err // err may be nil
+ }
}
}
}
}
+ case isInputRef(target): // ok, valid
+ case isBuiltinRefOrVar:
+
+ // NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc)
+ // are rewritten to their proper Ref convention
+ if v, ok := target.Value.(Var); ok {
+ target.Value = Ref([]*Term{NewTerm(v)})
+ }
+
+ targetRef := target.Value.(Ref)
+ bi := c.builtins[targetRef.String()] // safe because isBuiltinRefOrVar checked this
+ if err := validateWithBuiltinTarget(bi, targetRef, target.Loc()); err != nil {
+ return false, err
+ }
+
+ if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
+ return false, err // err may be nil
+ }
+ default:
+ return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument)
+ }
+ return requiresEval(value), nil
+}
+
+func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) *Error {
+ switch bi.Name {
+ case Equality.Name,
+ RegoMetadataChain.Name,
+ RegoMetadataRule.Name:
+ return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of %q invalid", bi.Name)
+ }
+ switch {
+ case target.HasPrefix(Ref([]*Term{VarTerm("internal")})):
+ return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of internal function %q invalid", target)
+
+ case bi.Relation:
+ return NewError(CompileErr, loc, "with keyword 
replacing built-in function: target must not be a relation") + + case bi.Decl.Result() == nil: + return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a void function") } return nil } +func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) { + if v, ok := value.Value.(Ref); ok { + if ruleTree.Find(v) != nil { // ref exists in rule tree + return true, nil + } + } + return isBuiltinRefOrVar(bs, unsafeMap, value) +} + func isInputRef(term *Term) bool { if ref, ok := term.Value.(Ref); ok { if ref.HasPrefix(InputRootRef) { @@ -4312,8 +5380,20 @@ func isDataRef(term *Term) bool { return false } +func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) { + switch v := term.Value.(type) { + case Ref, Var: + if _, ok := unsafeBuiltinsMap[v.String()]; ok { + return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v) + } + _, ok := bs[v.String()] + return ok, nil + } + return false, nil +} + func isVirtual(node *TreeNode, ref Ref) bool { - for i := 0; i < len(ref); i++ { + for i := range ref { child := node.Child(ref[i].Value) if child == nil { return false @@ -4391,6 +5471,25 @@ func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{} return errs } +func checkDeprecatedBuiltins(deprecatedBuiltinsMap map[string]struct{}, node interface{}, strict bool) Errors { + // Early out; deprecatedBuiltinsMap is only populated in strict-mode. + if !strict { + return nil + } + + errs := make(Errors, 0) + WalkExprs(node, func(x *Expr) bool { + if x.IsCall() { + operator := x.Operator().String() + if _, ok := deprecatedBuiltinsMap[operator]; ok { + errs = append(errs, NewError(TypeErr, x.Loc(), "deprecated built-in function calls in expression: %v", operator)) + } + } + return false + }) + return errs +} + func rewriteVarsInRef(vars ...map[Var]Var) varRewriter { return func(node Ref) Ref { i, _ := TransformVars(node, func(v Var) (Value, error) { @@ -4404,3 +5503,57 @@ func rewriteVarsInRef(vars ...map[Var]Var) varRewriter { return i.(Ref) } } + +// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location +// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it +// public in the ast package, the compile package could take it from there, but it would also +// increase our public interface. Let's reconsider if we need it in a third place. +type refSet struct { + s []Ref +} + +func newRefSet(x ...Ref) *refSet { + result := &refSet{} + for i := range x { + result.AddPrefix(x[i]) + } + return result +} + +// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set. +func (rs *refSet) ContainsPrefix(r Ref) bool { + for i := range rs.s { + if r.HasPrefix(rs.s[i]) { + return true + } + } + return false +} + +// AddPrefix inserts r into the set if r is not prefixed by any existing +// refs in the set. If any existing refs are prefixed by r, those existing +// refs are removed. +func (rs *refSet) AddPrefix(r Ref) { + if rs.ContainsPrefix(r) { + return + } + cpy := []Ref{r} + for i := range rs.s { + if !rs.s[i].HasPrefix(r) { + cpy = append(cpy, rs.s[i]) + } + } + rs.s = cpy +} + +// Sorted returns a sorted slice of terms for refs in the set. 
+func (rs *refSet) Sorted() []*Term { + terms := make([]*Term, len(rs.s)) + for i := range rs.s { + terms[i] = NewTerm(rs.s[i]) + } + sort.Slice(terms, func(i, j int) bool { + return terms[i].Value.Compare(terms[j].Value) < 0 + }) + return terms +} diff --git a/ast/compile_test.go b/ast/compile_test.go index 223f7eac7b..06890bf7a1 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -6,6 +6,7 @@ package ast import ( "bytes" + "encoding/json" "errors" "fmt" "reflect" @@ -203,12 +204,31 @@ func TestOutputVarsForNode(t *testing.T) { query: `z = "abc"; x = split(z, a)[y]`, exp: `{z}`, }, + { + note: "every: simple: no output vars", + query: `every k, v in [1, 2] { k < v }`, + exp: `set()`, + }, + { + note: "every: output vars in domain", + query: `xs = []; every k, v in xs[i] { k < v }`, + exp: `{xs, i}`, + }, + { + note: "every: output vars in body", + query: `every k, v in [] { k < v; i = 1 }`, + exp: `set()`, + }, } for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - body := MustParseBody(tc.query) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + body, err := ParseBodyWithOpts(tc.query, opts) + if err != nil { + t.Fatal(err) + } arity := func(r Ref) int { a, ok := tc.arities[r.String()] if !ok { @@ -243,7 +263,7 @@ func TestOutputVarsForNode(t *testing.T) { func TestModuleTree(t *testing.T) { - mods := getCompilerTestModules() + mods := getCompilerTestModules() // 7 modules mods["system-mod"] = MustParseModule(` package system.foo @@ -254,8 +274,14 @@ func TestModuleTree(t *testing.T) { p = 1 `) + mods["dots-in-heads"] = MustParseModule(` + package dots + + a.b.c = 12 + d.e.f.g = 34 + `) tree := NewModuleTree(mods) - expectedSize := 9 + expectedSize := 10 if tree.Size() != expectedSize { t.Fatalf("Expected %v but got %v modules", expectedSize, tree.Size()) @@ -274,978 +300,1892 @@ func TestModuleTree(t *testing.T) { } } +func TestCompilerGetExports(t *testing.T) { + tests := []struct { + note string + modules []*Module + exports map[string][]string + }{ + { + note: "simple", + modules: modules(`package p + r = 1`), + exports: map[string][]string{"data.p": {"r"}}, + }, + { + note: "simple single-value ref rule", + modules: modules(`package p + q.r.s = 1`), + exports: map[string][]string{"data.p": {"q.r.s"}}, + }, + { + note: "var key single-value ref rule", + modules: modules(`package p + q.r[s] = 1 { s := "foo" }`), + exports: map[string][]string{"data.p": {"q.r"}}, + }, + { + note: "simple multi-value ref rule", + modules: modules(`package p + import future.keywords -func TestRuleTree(t *testing.T) { - - mods := getCompilerTestModules() - mods["system-mod"] = MustParseModule(` - package system.foo - - p = 1 - `) - mods["non-system-mod"] = MustParseModule(` - package user.system - - p = 1 - `) - mods["mod-incr"] = MustParseModule(`package a.b.c - -s[1] { true } -s[2] { true }`, - ) - - tree := NewRuleTree(NewModuleTree(mods)) - expectedNumRules := 23 - - if tree.Size() != expectedNumRules { - t.Errorf("Expected %v but got %v rules", expectedNumRules, tree.Size()) - } - - // Check that empty packages are represented as leaves with no rules. 
- node := tree.Children[Var("data")].Children[String("a")].Children[String("b")].Children[String("empty")] - - if node == nil || len(node.Children) != 0 || len(node.Values) != 0 { - t.Fatalf("Unexpected nil value or non-empty leaf of non-leaf node: %v", node) - } - - system := tree.Child(Var("data")).Child(String("system")) - if !system.Hide { - t.Fatalf("Expected system node to be hidden") - } - - if system.Child(String("foo")).Hide { - t.Fatalf("Expected system.foo node to be visible") - } - - user := tree.Child(Var("data")).Child(String("user")).Child(String("system")) - if user.Hide { - t.Fatalf("Expected user.system node to be visible") - } - - if !isVirtual(tree, MustParseRef("data.a.b.empty")) { - t.Fatal("Expected data.a.b.empty to be virtual") - } - - abc := tree.Children[Var("data")].Children[String("a")].Children[String("b")].Children[String("c")] - exp := []Value{String("p"), String("q"), String("r"), String("s"), String("z")} - - if len(abc.Sorted) != len(exp) { - t.Fatal("expected", exp, "but got", abc) + q.r.s contains 1 { true }`), + exports: map[string][]string{"data.p": {"q.r.s"}}, + }, + { + note: "two simple, multiple rules", + modules: modules(`package p + r = 1 + s = 11`, + `package q + x = 2 + y = 22`), + exports: map[string][]string{"data.p": {"r", "s"}, "data.q": {"x", "y"}}, + }, + { + note: "ref head + simple, multiple rules", + modules: modules(`package p.a.b.c + r = 1 + s = 11`, + `package q + a.b.x = 2 + a.b.c.y = 22`), + exports: map[string][]string{ + "data.p.a.b.c": {"r", "s"}, + "data.q": {"a.b.x", "a.b.c.y"}, + }, + }, + { + note: "two ref head, multiple rules", + modules: modules(`package p.a.b.c + r = 1 + s = 11`, + `package p + a.b.x = 2 + a.b.c.y = 22`), + exports: map[string][]string{ + "data.p.a.b.c": {"r", "s"}, + "data.p": {"a.b.x", "a.b.c.y"}, + }, + }, + { + note: "single-value rule with number key", + modules: modules(`package p + q[1] = 1 + q[2] = 2`), + exports: map[string][]string{ + "data.p": {"q[1]", "q[2]"}, // TODO(sr): is this really what we want? + }, + }, + { + note: "single-value (ref) rule with number key", + modules: modules(`package p + a.b.q[1] = 1 + a.b.q[2] = 2`), + exports: map[string][]string{ + "data.p": {"a.b.q[1]", "a.b.q[2]"}, + }, + }, + { + note: "single-value (ref) rule with var key", + modules: modules(`package p + a.b.q[x] = y { x := 1; y := true } + a.b.q[2] = 2`), + exports: map[string][]string{ + "data.p": {"a.b.q", "a.b.q[2]"}, // TODO(sr): GroundPrefix? right thing here? + }, + }, + { // NOTE(sr): An ast.Module can be constructed in various ways, this is to assert that + // our compilation process doesn't explode here if we're fed a Rule that has no Ref. + note: "synthetic", + modules: func() []*Module { + ms := modules(`package p + r = 1`) + ms[0].Rules[0].Head.Reference = nil + return ms + }(), + exports: map[string][]string{"data.p": {"r"}}, + }, + // TODO(sr): add multi-val rule, and ref-with-var single-value rule. 
} - for i := range exp { - if exp[i].Compare(abc.Sorted[i]) != 0 { - t.Fatal("expected", exp, "but got", abc) + hashMap := func(ms map[string][]string) *util.HashMap { + rules := util.NewHashMap(func(a, b util.T) bool { + switch a := a.(type) { + case Ref: + return a.Equal(b.(Ref)) + case []Ref: + b := b.([]Ref) + if len(b) != len(a) { + return false + } + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true + default: + panic("unreachable") + } + }, func(v util.T) int { + return v.(Ref).Hash() + }) + for r, rs := range ms { + refs := make([]Ref, len(rs)) + for i := range rs { + refs[i] = toRef(rs[i]) + } + rules.Put(MustParseRef(r), refs) } + return rules } -} - -func TestCompilerEmpty(t *testing.T) { - c := NewCompiler() - c.Compile(nil) - assertNotFailed(t, c) -} -func TestCompilerExample(t *testing.T) { - c := NewCompiler() - m := MustParseModule(testModule) - c.Compile(map[string]*Module{"testMod": m}) - assertNotFailed(t, c) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + for i, m := range tc.modules { + c.Modules[fmt.Sprint(i)] = m + c.sorted = append(c.sorted, fmt.Sprint(i)) + } + if exp, act := hashMap(tc.exports), c.getExports(); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) + } + }) + } } -func TestCompilerWithStageAfter(t *testing.T) { - c := NewCompiler().WithStageAfter( - "CheckRecursion", - CompilerStageDefinition{"MockStage", "mock_stage", mockStageFunctionCall}, - ) - m := MustParseModule(testModule) - c.Compile(map[string]*Module{"testMod": m}) - - if !c.Failed() { - t.Errorf("Expected compilation error") +func toRef(s string) Ref { + switch t := MustParseTerm(s).Value.(type) { + case Var: + return Ref{NewTerm(t)} + case Ref: + return t + default: + panic("unreachable") } } -func TestCompilerFunctions(t *testing.T) { +func TestCompilerCheckRuleHeadRefs(t *testing.T) { + tests := []struct { - note string - modules []string - wantErr bool + note string + modules []*Module + expected *Rule + err string }{ { - note: "multiple input types", - modules: []string{`package x - - f([x]) = y { - y = x - } - - f({"foo": x}) = y { - y = x - }`}, + note: "ref contains var", + modules: modules( + `package x + p.q[i].r = 1 { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): i", }, { - note: "multiple input types", - modules: []string{`package x - - f([x]) = y { - y = x - } - - f([[x]]) = y { - y = x - }`}, + note: "valid: ref is single-value rule with var key", + modules: modules( + `package x + p.q.r[i] { i := 10 }`, + ), }, { - note: "constant input", - modules: []string{`package x - - f(1) = y { - y = "foo" - } - - f(2) = y { - y = "bar" - }`}, + note: "valid: ref is single-value rule with var key and value", + modules: modules( + `package x + p.q.r[i] = j { i := 10; j := 11 }`, + ), }, { - note: "constant input", - modules: []string{`package x - - f(1, x) = y { - y = x - } - - f(x, y) = z { - z = x+y - }`}, + note: "valid: ref is single-value rule with var key and static value", + modules: modules( + `package x + p.q.r[i] = "ten" { i := 10 }`, + ), }, { - note: "constant input", - modules: []string{`package x - - f(x, 1) = y { - y = x - } - - f(x, [y]) = z { - z = x+y - }`}, + note: "valid: ref is single-value rule with number key", + modules: modules( + `package x + p.q.r[1] { true }`, + ), }, { - note: "multiple input types (nested)", - modules: []string{`package x - - f({"foo": {"bar": x}}) = y { - y = x - } - - f({"foo": [x]}) = y { - y = x - 
}`}, + note: "valid: ref is single-value rule with boolean key", + modules: modules( + `package x + p.q.r[true] { true }`, + ), }, { - note: "multiple output types", - modules: []string{`package x - - f(1) = y { - y = "foo" - } - - f(2) = y { - y = 2 - }`}, + note: "valid: ref is single-value rule with null key", + modules: modules( + `package x + p.q.r[null] { true }`, + ), }, { - note: "namespacing", - modules: []string{ + note: "valid: ref is single-value rule with set literal key", + modules: modules( `package x - - f(x) = y { - data.y.f[x] = y - }`, - `package y - - f[x] = y { - y = "bar" - x = "foo" - }`, - }, + p.q.r[set()] { true }`, + ), }, { - note: "implicit value", - modules: []string{ + note: "valid: ref is single-value rule with array literal key", + modules: modules( `package x - - f(x) { - x = "foo" - }`}, + p.q.r[[]] { true }`, + ), }, { - note: "resolving", - modules: []string{ + note: "valid: ref is single-value rule with object literal key", + modules: modules( `package x - - f(x) = x { true }`, - - `package y - - import data.x - import data.x.f as g - - p { g(1, a) } - p { x.f(1, b) } - p { data.x.f(1, c) } - `, - }, + p.q.r[{}] { true }`, + ), }, { - note: "undefined", - modules: []string{ + note: "valid: ref is single-value rule with ref key", + modules: modules( `package x - - p { - f(1) - }`, - }, - wantErr: true, + x := [1,2,3] + p.q.r[x[i]] { i := 0}`, + ), }, { - note: "must apply", - modules: []string{ + note: "invalid: ref in ref", + modules: modules( `package x - - f(1) - - p { - f - } - `, - }, - wantErr: true, + p.q[arr[0]].r { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): arr[0]", }, { - note: "must apply", - modules: []string{ + note: "invalid: non-string in ref (not last position)", + modules: modules( `package x - f(1) - p { f.x }`, - }, - wantErr: true, + p.q[10].r { true }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): 10", }, { - note: "call argument ref output vars", - modules: []string{ + note: "valid: multi-value with var key", + modules: modules( `package x - - f(x) - - p { f(data.foo[i]) }`, - }, - wantErr: false, + p.q.r contains i if i := 10`, + ), + }, + { + note: "rewrite: single-value with non-var key (ref)", + modules: modules( + `package x + p.q.r[y.z] if y := {"z": "a"}`, + ), + expected: MustParseRule(`p.q.r[__local0__] { y := {"z": "a"}; __local0__ = y.z }`), }, } + for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - var err error - modules := map[string]*Module{} - for i, module := range tc.modules { - name := fmt.Sprintf("mod%d", i) - modules[name], err = ParseModule(name, module) - if err != nil { - panic(err) - } + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m } c := NewCompiler() - c.Compile(modules) - if tc.wantErr && !c.Failed() { - t.Errorf("Expected compilation error") - } else if !tc.wantErr && c.Failed() { - t.Errorf("Unexpected compilation error(s): %v", c.Errors) + c.Modules = mods + compileStages(c, c.rewriteRuleHeadRefs) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + if len(c.Errors) > 0 { + t.Fatalf("expected no errors, got %v", c.Errors) + } + if tc.expected != nil { + assertRulesEqual(t, tc.expected, mods["0"].Rules[0]) + } } }) } } -func TestCompilerErrorLimit(t *testing.T) { - modules := map[string]*Module{ - "test": MustParseModule(`package test - r = y { y = true; x = z } +func 
TestRuleTreeWithDotsInHeads(t *testing.T) { - s[x] = y { - z = y + x + // TODO(sr): multi-val with var key in ref + tests := []struct { + note string + modules []*Module + size int // expected tree size = number of leaves + depth int // expected tree depth + }{ + { + note: "two modules, same package, one rule each", + modules: modules( + `package x + p.q.r = 1`, + `package x + p.q.w = 2`, + ), + size: 2, + }, + { + note: "two modules, sub-package, one rule each", + modules: modules( + `package x + p.q.r = 1`, + `package x.p + q.w.z = 2`, + ), + size: 2, + }, + { + note: "three modules, sub-package, incl simple rule", + modules: modules( + `package x + p.q.r = 1`, + `package x.p + q.w.z = 2`, + `package x.p.q.w + y = 3`, + ), + size: 3, + }, + { + note: "simple: two modules", + modules: modules( + `package x + p.q.r = 1`, + `package y + p.q.w = 2`, + ), + size: 2, + }, + { + note: "conflict: one module", + modules: modules( + `package q + p[x] = 1 + p = 2`, + ), + size: 2, + }, + { + note: "conflict: two modules", + modules: modules( + `package q + p.r.s[x] = 1`, + `package q.p + r.s = 2 if true`, + ), + size: 2, + }, + { + note: "simple: two modules, one using ref head, one package path", + modules: modules( + `package x + p.q.r = 1 { input == 1 }`, + `package x.p.q + r = 2 { input == 2 }`, + ), + size: 2, + }, + { + note: "conflict: two modules, both using ref head, different package paths", + modules: modules( + `package x + p.q.r = 1 { input == 1 }`, // x.p.q.r = 1 + `package x.p + q.r.s = 2 { input == 2 }`, // x.p.q.r.s = 2 + ), + size: 2, + }, + { + note: "overlapping: one module, two ref head", + modules: modules( + `package x + p.q.r = 1 + p.q.w.v = 2`, + ), + size: 2, + depth: 6, + }, + { + note: "last ref term != string", + modules: modules( + `package x + p.q.w[1] = 2 + p.q.w[{"foo": "baz"}] = 20 + p.q.x[true] = false + p.q.x[y] = y { y := "y" }`, + ), + size: 4, + depth: 6, + }, } - t[x] { split(x, y, z) } - `), + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + for i, m := range tc.modules { + c.Modules[fmt.Sprint(i)] = m + c.sorted = append(c.sorted, fmt.Sprint(i)) + } + compileStages(c, c.setRuleTree) + if len(c.Errors) > 0 { + t.Fatal(c.Errors) + } + tree := c.RuleTree + tree.DepthFirst(func(n *TreeNode) bool { + t.Log(n) + if !sort.SliceIsSorted(n.Sorted, func(i, j int) bool { + return n.Sorted[i].Compare(n.Sorted[j]) < 0 + }) { + t.Errorf("expected sorted to be sorted: %v", n.Sorted) + } + return false + }) + if tc.depth > 0 { + if exp, act := tc.depth, depth(tree); exp != act { + t.Errorf("expected tree depth %d, got %d", exp, act) + } + } + if exp, act := tc.size, tree.Size(); exp != act { + t.Errorf("expected tree size %d, got %d", exp, act) + } + }) } +} - c := NewCompiler().SetErrorLimit(2) - c.Compile(modules) - - errs := c.Errors - exp := []string{ - "2:20: rego_unsafe_var_error: var x is unsafe", - "2:20: rego_unsafe_var_error: var z is unsafe", - "rego_compile_error: error limit reached", - } +func TestRuleTreeWithVars(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} - var result []string - for _, err := range errs { - result = append(result, err.Error()) - } + t.Run("simple single-value rule", func(t *testing.T) { + mod0 := `package a.b +c.d.e = 1 if true` - sort.Strings(exp) - sort.Strings(result) - if !reflect.DeepEqual(exp, result) { - t.Errorf("Expected errors %v, got %v", exp, result) - } -} + mods := map[string]*Module{"0.rego": MustParseModuleWithOpts(mod0, opts)} + tree := 
NewRuleTree(NewModuleTree(mods)) -func TestCompilerCheckSafetyHead(t *testing.T) { - c := NewCompiler() - c.Modules = getCompilerTestModules() - c.Modules["newMod"] = MustParseModule(`package a.b + node := tree.Find(MustParseRef("data.a.b.c.d.e")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Errorf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := MustParseRef("c.d.e"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) -unboundKey[x] = y { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundVal[y] = x { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundCompositeVal[y] = [{"foo": x, "bar": y}] { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundCompositeKey[[{"x": x}]] { q[y] } -unboundBuiltinOperator = eq { x = 1 } -unboundElse { false } else = else_var { true } -`, - ) - compileStages(c, c.checkSafetyRuleHeads) + t.Run("two single-value rules", func(t *testing.T) { + mod0 := `package a.b +c.d.e = 1 if true` + mod1 := `package a.b.c +d.e = 2 if true` - makeErrMsg := func(v string) string { - return fmt.Sprintf("rego_unsafe_var_error: var %v is unsafe", v) - } + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) - expected := []string{ - makeErrMsg("x"), - makeErrMsg("x"), - makeErrMsg("x"), - makeErrMsg("x"), - makeErrMsg("eq"), - makeErrMsg("else_var"), - } + node := tree.Find(MustParseRef("data.a.b.c.d.e")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Errorf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := MustParseRef("c.d.e"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := MustParseRef("d.e"), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) - result := compilerErrsToStringSlice(c.Errors) - sort.Strings(expected) + t.Run("one multi-value rule, one single-value, with var", func(t *testing.T) { + mod0 := `package a.b +c.d.e.g contains 1 if true` + mod1 := `package a.b.c +d.e.f = 2 if true` - if len(result) != len(expected) { - t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n")) - } + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) - for i := range result { - if expected[i] != result[i] { - t.Errorf("Expected %v but got: %v", expected[i], result[i]) + // var-key rules should be included in the results + node := tree.Find(MustParseRef("data.a.b.c.d.e.g")) + if node == nil { + t.Fatal("expected non-nil leaf node") } - } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Fatalf("expected %d children, found %d", exp, act) + } + node = tree.Find(MustParseRef("data.a.b.c.d.e.f")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, 
len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d", exp, act) + } + if exp, act := MustParseRef("d.e.f"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) -} + t.Run("two multi-value rules, back compat", func(t *testing.T) { + mod0 := `package a +b[c] { c := "foo" }` + mod1 := `package a +b[d] { d := "bar" }` -func TestCompilerCheckSafetyBodyReordering(t *testing.T) { - tests := []struct { - note string - body string - expected string - }{ - {"noop", `x = 1; x != 0`, `x = 1; x != 0`}, - {"var/ref", `a[i] = x; a = [1, 2, 3, 4]`, `a = [1, 2, 3, 4]; a[i] = x`}, - {"var/ref (nested)", `a = [1, 2, 3, 4]; a[b[i]] = x; b = [0, 0, 0, 0]`, `a = [1, 2, 3, 4]; b = [0, 0, 0, 0]; a[b[i]] = x`}, - {"negation", - `a = [true, false]; b = [true, false]; not a[i]; b[i]`, - `a = [true, false]; b = [true, false]; b[i]; not a[i]`}, - {"built-in", `x != 0; count([1, 2, 3], x)`, `count([1, 2, 3], x); x != 0`}, - {"var/var 1", `x = y; z = 1; y = z`, `z = 1; y = z; x = y`}, - {"var/var 2", `x = y; 1 = z; z = y`, `1 = z; z = y; x = y`}, - {"var/var 3", `x != 0; y = x; y = 1`, `y = 1; y = x; x != 0`}, - {"array compr/var", `x != 0; [y | y = 1] = x`, `[y | y = 1] = x; x != 0`}, - {"array compr/array", `[1] != [x]; [y | y = 1] = [x]`, `[y | y = 1] = [x]; [1] != [x]`}, - {"with", `data.a.b.d.t with input as x; x = 1`, `x = 1; data.a.b.d.t with input as x`}, - {"with-2", `data.a.b.d.t with input.x as x; x = 1`, `x = 1; data.a.b.d.t with input.x as x`}, - {"with-nop", "data.somedoc[x] with input as true", "data.somedoc[x] with input as true"}, - {"ref-head", `s = [["foo"], ["bar"]]; x = y[0]; y = s[_]; contains(x, "oo")`, ` - s = [["foo"], ["bar"]]; - y = s[_]; - x = y[0]; - contains(x, "oo") - `}, - {"userfunc", `split(y, ".", z); data.a.b.funcs.fn("...foo.bar..", y)`, `data.a.b.funcs.fn("...foo.bar..", y); split(y, ".", z)`}, - } + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) - for i, tc := range tests { - t.Run(tc.note, func(t *testing.T) { - c := NewCompiler() - c.Modules = getCompilerTestModules() - c.Modules["reordering"] = MustParseModule(fmt.Sprintf( - `package test - p { %s }`, tc.body)) + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[0].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := VarTerm("c"), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[1].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := VarTerm("d"), node.Values[1].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) - compileStages(c, c.checkSafetyRuleBodies) + t.Run("two 
multi-value rules, back compat with short style", func(t *testing.T) { + mod0 := `package a +b[1]` + mod1 := `package a +b[2]` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) - if c.Failed() { - t.Errorf("%v (#%d): Unexpected compilation error: %v", tc.note, i, c.Errors) - return - } + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[0].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[1].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := IntNumberTerm(2), node.Values[1].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) - expected := MustParseBody(tc.expected) - result := c.Modules["reordering"].Rules[0].Body + t.Run("two single-value rules, back compat with short style", func(t *testing.T) { + mod0 := `package a +b[1] = 1` + mod1 := `package a +b[2] = 2` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) - if !expected.Equal(result) { - t.Errorf("%v (#%d): Expected body to be ordered and equal to %v but got: %v", tc.note, i, expected, result) - } - }) - } -} + // branch point + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 0, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 2, len(node.Children); exp != act { + t.Fatalf("expected %d children, found %d", exp, act) + } -func TestCompilerCheckSafetyBodyReorderingClosures(t *testing.T) { - c := NewCompiler() - c.Modules = map[string]*Module{ - "mod": MustParseModule( - `package compr + // branch 1 + node = tree.Find(MustParseRef("data.a.b[1]")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := MustParseRef("b[1]"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Value; !exp.Equal(act) { + t.Errorf("expected rule value %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } -import data.b -import data.c + // branch 2 + node = tree.Find(MustParseRef("data.a.b[2]")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act 
:= 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := MustParseRef("b[2]"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(2), node.Values[0].(*Rule).Head.Value; !exp.Equal(act) { + t.Errorf("expected rule value %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(2), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) -fn(x) = y { - trim(x, ".", y) -} + // NOTE(sr): This test may seem obvious now, but it covers a bug that had snuck into + // the NewRuleTree code during development. + t.Run("root node and data node unhidden if there are no system nodes", func(t *testing.T) { + mod0 := `package a +p = 1` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) -p = true { v = [null | true]; xs = [x | a[i] = x; a = [y | y != 1; y = c[j]]]; xs[j] > 0; z = [true | data.a.b.d.t with input as i2; i2 = i]; b[i] = j } -q = true { _ = [x | x = b[i]]; _ = b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _ } -r = true { a = [x | split(y, ".", z); x = z[i]; fn("...foo.bar..", y)] }`, - ), - } - - compileStages(c, c.checkSafetyRuleBodies) - assertNotFailed(t, c) + if exp, act := false, tree.Hide; act != exp { + t.Errorf("expected tree.Hide=%v, got %v", exp, act) + } + dataNode := tree.Child(Var("data")) + if dataNode == nil { + t.Fatal("expected data node") + } + if exp, act := false, dataNode.Hide; act != exp { + t.Errorf("expected dataNode.Hide=%v, got %v", exp, act) + } + }) +} - result1 := c.Modules["mod"].Rules[1].Body - expected1 := MustParseBody(`v = [null | true]; data.b[i] = j; xs = [x | a = [y | y = data.c[j]; y != 1]; a[i] = x]; z = [true | i2 = i; data.a.b.d.t with input as i2]; xs[j] > 0`) - if !result1.Equal(expected1) { - t.Errorf("Expected reordered body to be equal to:\n%v\nBut got:\n%v", expected1, result1) +func depth(n *TreeNode) int { + d := -1 + for _, m := range n.Children { + if d0 := depth(m); d0 > d { + d = d0 + } } + return d + 1 +} - result2 := c.Modules["mod"].Rules[2].Body - expected2 := MustParseBody(`_ = [x | x = data.b[i]]; _ = data.b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _`) - if !result2.Equal(expected2) { - t.Errorf("Expected pre-ordered body to equal:\n%v\nBut got:\n%v", expected2, result2) +func TestModuleTreeFilenameOrder(t *testing.T) { + // NOTE(sr): It doesn't matter that these rules conflict; the conflict just makes the + // problem very apparent: before this change, the rule reported as "conflicting" + // came from either one of the input files, at random.
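+ // The modules below are keyed by filename; the tree construction is expected to
+ // visit them in sorted filename order, which the value assertions at the end of
+ // this test pin down ("0.rego" first, then "1.rego").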
+ mods := map[string]*Module{ + "0.rego": MustParseModule("package p\nr = 1 { true }"), + "1.rego": MustParseModule("package p\nr = 2 { true }"), } - - result3 := c.Modules["mod"].Rules[3].Body - expected3 := MustParseBody(`a = [x | data.compr.fn("...foo.bar..", y); split(y, ".", z); x = z[i]]`) - if !result3.Equal(expected3) { - t.Errorf("Expected pre-ordered body to equal:\n%v\nBut got:\n%v", expected3, result3) + tree := NewModuleTree(mods) + vals := tree.Children[Var("data")].Children[String("p")].Modules + if exp, act := 2, len(vals); exp != act { + t.Fatalf("expected %d rules, found %d", exp, act) } -} - -func TestCompilerCheckSafetyBodyErrors(t *testing.T) { - - moduleBegin := ` - package a.b - - import input.aref.b.c as foo - import input.avar as bar - import data.m.n as baz - ` - - tests := []struct { - note string - moduleContent string - expected string - }{ - {"ref-head", `p { a.b.c = "foo" }`, `{a,}`}, - {"ref-head-2", `p { {"foo": [{"bar": a.b.c}]} = {"foo": [{"bar": "baz"}]} }`, `{a,}`}, - {"negation", `p { a = [1, 2, 3, 4]; not a[i] = x }`, `{i, x}`}, - {"negation-head", `p[x] { a = [1, 2, 3, 4]; not a[i] = x }`, `{i,x}`}, - {"negation-multiple", `p { a = [1, 2, 3, 4]; b = [1, 2, 3, 4]; not a[i] = x; not b[j] = x }`, `{i, x, j}`}, - {"negation-nested", `p { a = [{"foo": ["bar", "baz"]}]; not a[0].foo = [a[0].foo[i], a[0].foo[j]] } `, `{i, j}`}, - {"builtin-input", `p { count([1, 2, x], x) }`, `{x,}`}, - {"builtin-input-name", `p { count(eq, 1) }`, `{eq,}`}, - {"builtin-multiple", `p { x > 0; x <= 3; x != 2 }`, `{x,}`}, - {"unordered-object-keys", `p { x = "a"; [{x: y, z: a}] = [{"a": 1, "b": 2}]}`, `{a,y,z}`}, - {"unordered-sets", `p { x = "a"; [{x, y}] = [{1, 2}]}`, `{y,}`}, - {"array-compr", `p { _ = [x | x = data.a[_]; y > 1] }`, `{y,}`}, - {"array-compr-nested", `p { _ = [x | x = a[_]; a = [y | y = data.a[_]; z > 1]] }`, `{z,}`}, - {"array-compr-closure", `p { _ = [v | v = [x | x = data.a[_]]; x > 1] }`, `{x,}`}, - {"array-compr-term", `p { _ = [u | true] }`, `{u,}`}, - {"array-compr-term-nested", `p { _ = [v | v = [w | w != 0]] }`, `{w,}`}, - {"array-compr-mixed", `p { _ = [x | y = [a | a = z[i]]] }`, `{a, x, z, i}`}, - {"array-compr-builtin", `p { [true | eq != 2] }`, `{eq,}`}, - {"closure-self", `p { x = [x | x = 1] }`, `{x,}`}, - {"closure-transitive", `p { x = y; x = [y | y = 1] }`, `{y,}`}, - {"nested", `p { count(baz[i].attr[bar[dead.beef]], n) }`, `{dead,}`}, - {"negated-import", `p { not foo; not bar; not baz }`, `set()`}, - {"rewritten", `p[{"foo": dead[i]}] { true }`, `{dead, i}`}, - {"with-value", `p { data.a.b.d.t with input as x }`, `{x,}`}, - {"with-value-2", `p { x = data.a.b.d.t with input as x }`, `{x,}`}, - {"else-kw", "p { false } else { count(x, 1) }", `{x,}`}, - {"function", "foo(x) = [y, z] { split(x, y, z) }", `{y,z}`}, - {"call-vars-input", "p { f(x, x) } f(x) = x { true }", `{x,}`}, - {"call-no-output", "p { f(x) } f(x) = x { true }", `{x,}`}, - {"call-too-few", "p { f(1,x) } f(x,y) { true }", "{x,}"}, - {"object-key-comprehension", "p { { {p|x}: 0 } }", "{x,}"}, - {"set-value-comprehension", "p { {1, {p|x}} }", "{x,}"}, + mod0 := vals[0] + mod1 := vals[1] + if exp, act := IntNumberTerm(1), mod0.Rules[0].Head.Value; !exp.Equal(act) { + t.Errorf("expected value %v, got %v", exp, act) } - - makeErrMsg := func(varName string) string { - return fmt.Sprintf("rego_unsafe_var_error: var %v is unsafe", varName) + if exp, act := IntNumberTerm(2), mod1.Rules[0].Head.Value; !exp.Equal(act) { + t.Errorf("expected value %v, got %v", exp, act) } +} +func 
TestRuleTree(t *testing.T) { - for _, tc := range tests { - t.Run(tc.note, func(t *testing.T) { - - // Build slice of expected error messages. - expected := []string{} - - _ = MustParseTerm(tc.expected).Value.(Set).Iter(func(x *Term) error { - expected = append(expected, makeErrMsg(string(x.Value.(Var)))) - return nil - }) // cannot return error - - sort.Strings(expected) + mods := getCompilerTestModules() + mods["system-mod"] = MustParseModule(` + package system.foo - // Compile test module. - c := NewCompiler() - c.Modules = map[string]*Module{ - "newMod": MustParseModule(fmt.Sprintf(` + p = 1 + `) + mods["non-system-mod"] = MustParseModule(` + package user.system - %v + p = 1`) + mods["mod-incr"] = MustParseModule(` + package a.b.c - %v + s[1] { true } + s[2] { true }`, + ) - `, moduleBegin, tc.moduleContent)), - } + mods["dots-in-heads"] = MustParseModule(` + package dots - compileStages(c, c.checkSafetyRuleBodies) + a.b.c = 12 + d.e.f.g = 34 + `) - // Get errors. - result := compilerErrsToStringSlice(c.Errors) + tree := NewRuleTree(NewModuleTree(mods)) + expectedNumRules := 25 - // Check against expected. - if len(result) != len(expected) { - t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n")) - } + if tree.Size() != expectedNumRules { + t.Errorf("Expected %v but got %v rules", expectedNumRules, tree.Size()) + } - for i := range result { - if expected[i] != result[i] { - t.Errorf("Expected %v but got: %v", expected[i], result[i]) - } - } + // Check that empty packages are represented as leaves with no rules. + node := tree.Children[Var("data")].Children[String("a")].Children[String("b")].Children[String("empty")] + if node == nil || len(node.Children) != 0 || len(node.Values) != 0 { + t.Fatalf("Unexpected nil value or non-empty leaf of non-leaf node: %v", node) + } - }) + // Check that root node is not hidden + if exp, act := false, tree.Hide; act != exp { + t.Errorf("expected tree.Hide=%v, got %v", exp, act) } -} -func TestCompilerCheckSafetyVarLoc(t *testing.T) { + system := tree.Child(Var("data")).Child(String("system")) + if !system.Hide { + t.Fatalf("Expected system node to be hidden: %v", system) + } - _, err := CompileModules(map[string]string{"test.rego": `package test + if system.Child(String("foo")).Hide { + t.Fatalf("Expected system.foo node to be visible") + } -p { - not x - x > y -}`}) + user := tree.Child(Var("data")).Child(String("user")).Child(String("system")) + if user.Hide { + t.Fatalf("Expected user.system node to be visible") + } - if err == nil { - t.Fatal("expected error") + if !isVirtual(tree, MustParseRef("data.a.b.empty")) { + t.Fatal("Expected data.a.b.empty to be virtual") } - errs := err.(Errors) + abc := tree.Children[Var("data")].Children[String("a")].Children[String("b")].Children[String("c")] + exp := []Value{String("p"), String("q"), String("r"), String("s"), String("z")} - if !strings.Contains(errs[0].Message, "var x is unsafe") || errs[0].Location.Row != 4 { - t.Fatal("expected error on row 4 but got:", err) + if len(abc.Sorted) != len(exp) { + t.Fatal("expected", exp, "but got", abc) } - if !strings.Contains(errs[1].Message, "var y is unsafe") || errs[1].Location.Row != 5 { - t.Fatal("expected y is unsafe on row 5 but got:", err) + for i := range exp { + if exp[i].Compare(abc.Sorted[i]) != 0 { + t.Fatal("expected", exp, "but got", abc) + } } } -func TestCompilerCheckTypes(t *testing.T) { +func TestCompilerEmpty(t *testing.T) { c := NewCompiler() - modules := 
getCompilerTestModules() - c.Modules = map[string]*Module{"mod6": modules["mod6"], "mod7": modules["mod7"]} - compileStages(c, c.checkTypes) + c.Compile(nil) assertNotFailed(t, c) } -func TestCompilerCheckTypesWithSchema(t *testing.T) { +func TestCompilerExample(t *testing.T) { c := NewCompiler() - var schema interface{} - err := util.Unmarshal([]byte(objectSchema), &schema) - if err != nil { - t.Fatal("Unexpected error:", err) - } - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, schema) - c.WithSchemas(schemaSet) - compileStages(c, c.checkTypes) + m := MustParseModule(testModule) + c.Compile(map[string]*Module{"testMod": m}) assertNotFailed(t, c) } -func TestCompilerCheckTypesWithAllOfSchema(t *testing.T) { +func TestCompilerWithStageAfter(t *testing.T) { + t.Run("after failing means overall failure", func(t *testing.T) { + c := NewCompiler().WithStageAfter( + "CheckRecursion", + CompilerStageDefinition{"MockStage", "mock_stage", + func(*Compiler) *Error { return NewError(CompileErr, &Location{}, "mock stage error") }}, + ) + m := MustParseModule(testModule) + c.Compile(map[string]*Module{"testMod": m}) + + if !c.Failed() { + t.Errorf("Expected compilation error") + } + }) + + t.Run("first 'after' failure inhibits other 'after' stages", func(t *testing.T) { + c := NewCompiler(). + WithStageAfter("CheckRecursion", + CompilerStageDefinition{"MockStage", "mock_stage", + func(*Compiler) *Error { return NewError(CompileErr, &Location{}, "mock stage error") }}). + WithStageAfter("CheckRecursion", + CompilerStageDefinition{"MockStage2", "mock_stage2", + func(*Compiler) *Error { return NewError(CompileErr, &Location{}, "mock stage error two") }}, + ) + m := MustParseModule(`package p +q := true`) + + c.Compile(map[string]*Module{"testMod": m}) + + if !c.Failed() { + t.Errorf("Expected compilation error") + } + if exp, act := 1, len(c.Errors); exp != act { + t.Errorf("expected %d errors, got %d: %v", exp, act, c.Errors) + } + }) + + t.Run("'after' failure inhibits other ordinary stages", func(t *testing.T) { + c := NewCompiler(). 
+ WithStageAfter("CheckRecursion", + CompilerStageDefinition{"MockStage", "mock_stage", + func(*Compiler) *Error { return NewError(CompileErr, &Location{}, "mock stage error") }}) + m := MustParseModule(`package p +q { + 1 == "a" # would fail "CheckTypes", the next stage +} +`) + c.Compile(map[string]*Module{"testMod": m}) + + if !c.Failed() { + t.Errorf("Expected compilation error") + } + if exp, act := 1, len(c.Errors); exp != act { + t.Errorf("expected %d errors, got %d: %v", exp, act, c.Errors) + } + }) +} +func TestCompilerFunctions(t *testing.T) { tests := []struct { - note string - schema string - expectedError error + note string + modules []string + wantErr bool }{ { - note: "allOf with mergeable Object types in schema", - schema: allOfObjectSchema, - expectedError: nil, - }, - { - note: "allOf with mergeable Array types in schema", - schema: allOfArraySchema, - expectedError: nil, - }, - { - note: "allOf without a parent schema", - schema: allOfSchemaParentVariation, - expectedError: nil, - }, - { - note: "allOf with empty schema", - schema: emptySchema, - expectedError: nil, + note: "multiple input types", + modules: []string{`package x + + f([x]) = y { + y = x + } + + f({"foo": x}) = y { + y = x + }`}, }, { - note: "allOf with mergeable Array of Object types in schema", - schema: allOfArrayOfObjects, - expectedError: nil, + note: "multiple input types", + modules: []string{`package x + + f([x]) = y { + y = x + } + + f([[x]]) = y { + y = x + }`}, }, { - note: "allOf with mergeable Object types in schema with type declaration missing", - schema: allOfObjectMissing, - expectedError: nil, + note: "constant input", + modules: []string{`package x + + f(1) = y { + y = "foo" + } + + f(2) = y { + y = "bar" + }`}, }, { - note: "allOf with Array of mergeable different types in schema", - schema: allOfArrayDifTypes, - expectedError: nil, + note: "constant input", + modules: []string{`package x + + f(1, x) = y { + y = x + } + + f(x, y) = z { + z = x+y + }`}, }, { - note: "allOf with mergeable Object containing Array types in schema", - schema: allOfArrayInsideObject, - expectedError: nil, + note: "constant input", + modules: []string{`package x + + f(x, 1) = y { + y = x + } + + f(x, [y]) = z { + z = x+y + }`}, }, { - note: "allOf with mergeable Array types in schema with type declaration missing", - schema: allOfArrayMissing, - expectedError: nil, + note: "multiple input types (nested)", + modules: []string{`package x + + f({"foo": {"bar": x}}) = y { + y = x + } + + f({"foo": [x]}) = y { + y = x + }`}, }, { - note: "allOf with mergeable types inside of core schema", - schema: allOfInsideCoreSchema, - expectedError: nil, - }, - { - note: "allOf with mergeable String types in schema", - schema: allOfStringSchema, - expectedError: nil, - }, - { - note: "allOf with mergeable Integer types in schema", - schema: allOfIntegerSchema, - expectedError: nil, - }, - { - note: "allOf with mergeable Boolean types in schema", - schema: allOfBooleanSchema, - expectedError: nil, - }, - { - note: "allOf with mergeable Array types with uneven numbers of items", - schema: allOfSchemaWithUnevenArray, - expectedError: nil, + note: "multiple output types", + modules: []string{`package x + + f(1) = y { + y = "foo" + } + + f(2) = y { + y = 2 + }`}, }, { - note: "allOf schema with unmergeable Array of Arrays", - schema: allOfArrayOfArrays, - expectedError: fmt.Errorf("unable to merge these schemas"), + note: "namespacing", + modules: []string{ + `package x + + f(x) = y { + data.y.f[x] = y + }`, + `package y + + f[x] 
= y { + y = "bar" + x = "foo" + }`, + }, }, { - note: "allOf schema with Array and Object types as siblings", - schema: allOfObjectAndArray, - expectedError: fmt.Errorf("unable to merge these schemas"), + note: "implicit value", + modules: []string{ + `package x + + f(x) { + x = "foo" + }`}, }, { - note: "allOf schema with Array type that contains different unmergeable types", - schema: allOfArrayDifTypesWithError, - expectedError: fmt.Errorf("unable to merge these schemas"), + note: "resolving", + modules: []string{ + `package x + + f(x) = x { true }`, + + `package y + + import data.x + import data.x.f as g + + p { g(1, a) } + p { x.f(1, b) } + p { data.x.f(1, c) } + `, + }, }, { - note: "allOf schema with different unmergeable types", - schema: allOfTypeErrorSchema, - expectedError: fmt.Errorf("unable to merge these schemas"), + note: "undefined", + modules: []string{ + `package x + + p { + f(1) + }`, + }, + wantErr: true, }, { - note: "allOf unmergeable schema with different parent and items types", - schema: allOfSchemaWithParentError, - expectedError: fmt.Errorf("unable to merge these schemas"), + note: "must apply", + modules: []string{ + `package x + + f(1) + + p { + f + } + `, + }, + wantErr: true, }, { - note: "allOf schema of Array type with uneven numbers of items to merge", - schema: allOfSchemaWithUnevenArray, - expectedError: nil, + note: "must apply", + modules: []string{ + `package x + f(1) + p { f.x }`, + }, + wantErr: true, }, { - note: "allOf schema with unmergeable types String and Boolean", - schema: allOfStringSchemaWithError, - expectedError: fmt.Errorf("unable to merge these schemas"), + note: "call argument ref output vars", + modules: []string{ + `package x + + f(x) + + p { f(data.foo[i]) }`, + }, + wantErr: false, }, } - for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - c := NewCompiler() - var schema interface{} - err := util.Unmarshal([]byte(tc.schema), &schema) - if err != nil { - t.Fatal("Unexpected error:", err) - } - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, schema) - c.WithSchemas(schemaSet) - compileStages(c, c.checkTypes) - if tc.expectedError != nil { - if errors.Is(c.Errors, tc.expectedError) { - t.Fatal("Unexpected error:", err) + var err error + modules := map[string]*Module{} + for i, module := range tc.modules { + name := fmt.Sprintf("mod%d", i) + modules[name], err = ParseModule(name, module) + if err != nil { + panic(err) } - } else { - assertNotFailed(t, c) + } + c := NewCompiler() + c.Compile(modules) + if tc.wantErr && !c.Failed() { + t.Errorf("Expected compilation error") + } else if !tc.wantErr && c.Failed() { + t.Errorf("Unexpected compilation error(s): %v", c.Errors) } }) } } -func TestCompilerCheckRuleConflicts(t *testing.T) { - - c := getCompilerWithParsedModules(map[string]string{ - "mod1.rego": `package badrules +func TestCompilerErrorLimit(t *testing.T) { + modules := map[string]*Module{ + "test": MustParseModule(`package test + r = y { y = true; x = z } -p[x] { x = 1 } -p[x] = y { x = y; x = "a" } -q[1] { true } -q = {1, 2, 3} { true } -r[x] = y { x = y; x = "a" } -r[x] = y { x = y; x = "a" }`, + s[x] = y { + z = y + x + } - "mod2.rego": `package badrules.r + t[x] { split(x, y, z) } + `), + } -q[1] { true }`, + c := NewCompiler().SetErrorLimit(2) + c.Compile(modules) - "mod3.rego": `package badrules.defkw + errs := c.Errors + exp := []string{ + "2:20: rego_unsafe_var_error: var x is unsafe", + "2:20: rego_unsafe_var_error: var z is unsafe", + "rego_compile_error: error limit reached", + } -default foo 
= 1 -default foo = 2 -foo = 3 { true }`, - "mod4.rego": `package adrules.arity + result := make([]string, 0, len(errs)) + for _, err := range errs { + result = append(result, err.Error()) + } -f(1) { true } -f { true } + sort.Strings(exp) + sort.Strings(result) + if !reflect.DeepEqual(exp, result) { + t.Errorf("Expected errors %v, got %v", exp, result) + } +} -g(1) { true } -g(1,2) { true }`, - "mod5.rego": `package badrules.dataoverlap +func TestCompilerCheckSafetyHead(t *testing.T) { + c := NewCompiler() + c.Modules = getCompilerTestModules() + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c.Modules["newMod"] = MustParseModuleWithOpts(`package a.b + +unboundKey[x1] = y { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundVal[y] = x2 { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundCompositeVal[y] = [{"foo": x3, "bar": y}] { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundCompositeKey[[{"x": x4}]] { q[y] } +unboundBuiltinOperator = eq { 4 = 1 } +unboundElse { false } else = else_var { true } +c.d.e[x5] if true +f.g.h[y] = x6 if y := "y" +i.j.k contains x7 if true +`, popts) + compileStages(c, c.checkSafetyRuleHeads) -p { true }`, - "mod6.rego": `package badrules.existserr + makeErrMsg := func(v string) string { + return fmt.Sprintf("rego_unsafe_var_error: var %v is unsafe", v) + } -p { true }`, - "mod7.rego": `package badrules.redeclaration + expected := []string{ + makeErrMsg("x1"), + makeErrMsg("x2"), + makeErrMsg("x3"), + makeErrMsg("x4"), + makeErrMsg("x5"), + makeErrMsg("x6"), + makeErrMsg("x7"), + makeErrMsg("eq"), + makeErrMsg("else_var"), + } -p1 := 1 -p1 := 2 + result := compilerErrsToStringSlice(c.Errors) + sort.Strings(expected) -p2 = 1 -p2 := 2`}) + if len(result) != len(expected) { + t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n")) + } - c.WithPathConflictsCheck(func(path []string) (bool, error) { - if reflect.DeepEqual(path, []string{"badrules", "dataoverlap", "p"}) { - return true, nil - } else if reflect.DeepEqual(path, []string{"badrules", "existserr", "p"}) { - return false, fmt.Errorf("unexpected error") + for i := range result { + if expected[i] != result[i] { + t.Errorf("Expected %v but got: %v", expected[i], result[i]) } - return false, nil - }) + } - compileStages(c, c.checkRuleConflicts) +} - expected := []string{ - "rego_compile_error: conflict check for data path badrules/existserr/p: unexpected error", - "rego_compile_error: conflicting rule for data path badrules/dataoverlap/p found", - "rego_type_error: conflicting rules named f found", - "rego_type_error: conflicting rules named g found", - "rego_type_error: conflicting rules named p found", - "rego_type_error: conflicting rules named q found", - "rego_type_error: multiple default rules named foo found", - "rego_type_error: package badrules.r conflicts with rule defined at mod1.rego:7", - "rego_type_error: package badrules.r conflicts with rule defined at mod1.rego:8", - "rego_type_error: rule named p1 redeclared at mod7.rego:4", - "rego_type_error: rule named p2 redeclared at mod7.rego:7", +func TestCompilerCheckSafetyBodyReordering(t *testing.T) { + tests := []struct { + note string + body string + expected string + }{ + {"noop", `x = 1; x != 0`, `x = 1; x != 0`}, + {"var/ref", `a[i] = x; a = [1, 2, 3, 4]`, `a = [1, 2, 3, 4]; a[i] = x`}, + {"var/ref (nested)", `a = [1, 2, 3, 4]; a[b[i]] = x; b = [0, 0, 0, 0]`, `a = [1, 2, 3, 4]; b = [0, 0, 0, 0]; a[b[i]] = x`}, + {"negation", + `a = [true, false]; b = 
[true, false]; not a[i]; b[i]`, + `a = [true, false]; b = [true, false]; b[i]; not a[i]`}, + {"built-in", `x != 0; count([1, 2, 3], x)`, `count([1, 2, 3], x); x != 0`}, + {"var/var 1", `x = y; z = 1; y = z`, `z = 1; y = z; x = y`}, + {"var/var 2", `x = y; 1 = z; z = y`, `1 = z; z = y; x = y`}, + {"var/var 3", `x != 0; y = x; y = 1`, `y = 1; y = x; x != 0`}, + {"array compr/var", `x != 0; [y | y = 1] = x`, `[y | y = 1] = x; x != 0`}, + {"array compr/array", `[1] != [x]; [y | y = 1] = [x]`, `[y | y = 1] = [x]; [1] != [x]`}, + {"with", `data.a.b.d.t with input as x; x = 1`, `x = 1; data.a.b.d.t with input as x`}, + {"with-2", `data.a.b.d.t with input.x as x; x = 1`, `x = 1; data.a.b.d.t with input.x as x`}, + {"with-nop", "data.somedoc[x] with input as true", "data.somedoc[x] with input as true"}, + {"ref-head", `s = [["foo"], ["bar"]]; x = y[0]; y = s[_]; contains(x, "oo")`, ` + s = [["foo"], ["bar"]]; + y = s[_]; + x = y[0]; + contains(x, "oo") + `}, + {"userfunc", `split(y, ".", z); data.a.b.funcs.fn("...foo.bar..", y)`, `data.a.b.funcs.fn("...foo.bar..", y); split(y, ".", z)`}, + {"every", `every _ in [] { x != 1 }; x = 1`, `__local4__ = []; x = 1; every __local3__, _ in __local4__ { x != 1}`}, + {"every-domain", `every _ in xs { true }; xs = [1]`, `xs = [1]; __local4__ = xs; every __local3__, _ in __local4__ { true }`}, } - assertCompilerErrorStrings(t, c, expected) -} + for i, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c := NewCompiler() + c.Modules = getCompilerTestModules() + c.Modules["reordering"] = MustParseModuleWithOpts(fmt.Sprintf( + `package test + p { %s }`, tc.body), opts) -func TestCompilerCheckUndefinedFuncs(t *testing.T) { + compileStages(c, c.checkSafetyRuleBodies) - module := ` - package test + if c.Failed() { + t.Errorf("%v (#%d): Unexpected compilation error: %v", tc.note, i, c.Errors) + return + } - undefined_function { - data.deadbeef(x) - } + expected := MustParseBodyWithOpts(tc.expected, opts) + result := c.Modules["reordering"].Rules[0].Body - undefined_global { - deadbeef(x) - } + if !expected.Equal(result) { + t.Errorf("%v (#%d): Expected body to be ordered and equal to %v but got: %v", tc.note, i, expected, result) + } + }) + } +} - # NOTE: all the dynamic dispatch examples here are not supported, - # we're checking assertions about the error returned. 
- undefined_dynamic_dispatch { - x = "f"; data.test2[x](1) - } +func TestCompilerCheckSafetyBodyReorderingClosures(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} - undefined_dynamic_dispatch_declared_var { - y := "f"; data.test2[y](1) - } - - undefined_dynamic_dispatch_declared_var_in_array { - z := "f"; data.test2[[z]](1) - } - - arity_mismatch_1 { - data.test2.f(1,2,3) - } + tests := []struct { + note string + mod *Module + exp Body + }{ + { + note: "comprehensions-1", + mod: MustParseModule(`package compr - arity_mismatch_2 { - data.test2.f() - } +import data.b +import data.c +p = true { v = [null | true]; xs = [x | a[i] = x; a = [y | y != 1; y = c[j]]]; xs[j] > 0; z = [true | data.a.b.d.t with input as i2; i2 = i]; b[i] = j } +`), + exp: MustParseBody(`v = [null | true]; data.b[i] = j; xs = [x | a = [y | y = data.c[j]; y != 1]; a[i] = x]; xs[j] > 0; z = [true | i2 = i; data.a.b.d.t with input as i2]`), + }, + { + note: "comprehensions-2", + mod: MustParseModule(`package compr - arity_mismatch_3 { - x:= data.test2.f() - } - ` +import data.b +import data.c +q = true { _ = [x | x = b[i]]; _ = b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _ } +`), + exp: MustParseBody(`_ = [x | x = data.b[i]]; _ = data.b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _`), + }, - module2 := ` - package test2 + { + note: "comprehensions-3", + mod: MustParseModule(`package compr - f(x) = x - ` +import data.b +import data.c +fn(x) = y { + trim(x, ".", y) +} +r = true { a = [x | split(y, ".", z); x = z[i]; fn("...foo.bar..", y)] } +`), + exp: MustParseBody(`a = [x | data.compr.fn("...foo.bar..", y); split(y, ".", z); x = z[i]]`), + }, + { + note: "closure over function output", + mod: MustParseModule(`package test +import future.keywords - _, err := CompileModules(map[string]string{ - "test.rego": module, - "test2.rego": module2, - }) - if err == nil { - t.Fatal("expected errors") +p { + object.get(input.subject.roles[_], comp, [""], output) + comp = [ 1 | true ] + every y in [2] { + y in output } - - result := err.Error() - want := []string{ - "rego_type_error: undefined function data.deadbeef", - "rego_type_error: undefined function deadbeef", - "rego_type_error: undefined function data.test2[x]", - "rego_type_error: undefined function data.test2[y]", - "rego_type_error: undefined function data.test2[[z]]", - "rego_type_error: function data.test2.f has arity 1, got 3 arguments", - "test.rego:31: rego_type_error: function data.test2.f has arity 1, got 0 arguments", - "test.rego:35: rego_type_error: function data.test2.f has arity 1, got 0 arguments", +}`), + exp: MustParseBodyWithOpts(`comp = [1 | true] + __local2__ = [2] + object.get(input.subject.roles[_], comp, [""], output) + every __local0__, __local1__ in __local2__ { internal.member_2(__local1__, output) }`, opts), + }, } - for _, w := range want { - if !strings.Contains(result, w) { - t.Fatalf("Expected %q in result but got: %v", w, result) - } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + c.Modules = map[string]*Module{"mod": tc.mod} + compileStages(c, c.checkSafetyRuleBodies) + assertNotFailed(t, c) + last := len(c.Modules["mod"].Rules) - 1 + actual := c.Modules["mod"].Rules[last].Body + if !actual.Equal(tc.exp) { + t.Errorf("Expected reordered body to be equal to:\n%v\nBut got:\n%v", tc.exp, actual) + } + }) } } -func TestCompilerQueryCompilerCheckUndefinedFuncs(t 
*testing.T) { - compiler := NewCompiler() +func TestCompilerCheckSafetyBodyErrors(t *testing.T) { - for _, tc := range []struct { - note, query, err string + moduleBegin := ` + package a.b + + import input.aref.b.c as foo + import input.avar as bar + import data.m.n as baz + ` + + tests := []struct { + note string + moduleContent string + expected string }{ + {"ref-head", `p { a.b.c = "foo" }`, `{a,}`}, + {"ref-head-2", `p { {"foo": [{"bar": a.b.c}]} = {"foo": [{"bar": "baz"}]} }`, `{a,}`}, + {"negation", `p { a = [1, 2, 3, 4]; not a[i] = x }`, `{i, x}`}, + {"negation-head", `p[x] { a = [1, 2, 3, 4]; not a[i] = x }`, `{i,x}`}, + {"negation-multiple", `p { a = [1, 2, 3, 4]; b = [1, 2, 3, 4]; not a[i] = x; not b[j] = x }`, `{i, x, j}`}, + {"negation-nested", `p { a = [{"foo": ["bar", "baz"]}]; not a[0].foo = [a[0].foo[i], a[0].foo[j]] } `, `{i, j}`}, + {"builtin-input", `p { count([1, 2, x], x) }`, `{x,}`}, + {"builtin-input-name", `p { count(eq, 1) }`, `{eq,}`}, + {"builtin-multiple", `p { x > 0; x <= 3; x != 2 }`, `{x,}`}, + {"unordered-object-keys", `p { x = "a"; [{x: y, z: a}] = [{"a": 1, "b": 2}]}`, `{a,y,z}`}, + {"unordered-sets", `p { x = "a"; [{x, y}] = [{1, 2}]}`, `{y,}`}, + {"array-compr", `p { _ = [x | x = data.a[_]; y > 1] }`, `{y,}`}, + {"array-compr-nested", `p { _ = [x | x = a[_]; a = [y | y = data.a[_]; z > 1]] }`, `{z,}`}, + {"array-compr-closure", `p { _ = [v | v = [x | x = data.a[_]]; x > 1] }`, `{x,}`}, + {"array-compr-term", `p { _ = [u | true] }`, `{u,}`}, + {"array-compr-term-nested", `p { _ = [v | v = [w | w != 0]] }`, `{w,}`}, + {"array-compr-mixed", `p { _ = [x | y = [a | a = z[i]]] }`, `{a, x, z, i}`}, + {"array-compr-builtin", `p { [true | eq != 2] }`, `{eq,}`}, + {"closure-self", `p { x = [x | x = 1] }`, `{x,}`}, + {"closure-transitive", `p { x = y; x = [y | y = 1] }`, `{x,y}`}, + {"nested", `p { count(baz[i].attr[bar[dead.beef]], n) }`, `{dead,}`}, + {"negated-import", `p { not foo; not bar; not baz }`, `set()`}, + {"rewritten", `p[{"foo": dead[i]}] { true }`, `{dead, i}`}, + {"with-value", `p { data.a.b.d.t with input as x }`, `{x,}`}, + {"with-value-2", `p { x = data.a.b.d.t with input as x }`, `{x,}`}, + {"else-kw", "p { false } else { count(x, 1) }", `{x,}`}, + {"function", "foo(x) = [y, z] { split(x, y, z) }", `{y,z}`}, + {"call-vars-input", "p { f(x, x) } f(x) = x { true }", `{x,}`}, + {"call-no-output", "p { f(x) } f(x) = x { true }", `{x,}`}, + {"call-too-few", "p { f(1,x) } f(x,y) { true }", "{x,}"}, + {"object-key-comprehension", "p { { {p|x}: 0 } }", "{x,}"}, + {"set-value-comprehension", "p { {1, {p|x}} }", "{x,}"}, + {"every", "p { every y in [10] { x > y } }", "{x,}"}, + } - {note: "undefined function", query: `data.foo(1)`, err: "undefined function data.foo"}, - {note: "undefined global function", query: `foo(1)`, err: "undefined function foo"}, - {note: "var", query: `x = "f"; data[x](1)`, err: "undefined function data[x]"}, - {note: "declared var", query: `x := "f"; data[x](1)`, err: "undefined function data[x]"}, - {note: "declared var in array", query: `x := "f"; data[[x]](1)`, err: "undefined function data[[x]]"}, - } { + makeErrMsg := func(varName string) string { + return fmt.Sprintf("rego_unsafe_var_error: var %v is unsafe", varName) + } + + for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - _, err := compiler.QueryCompiler().Compile(MustParseBody(tc.query)) - if !strings.Contains(err.Error(), tc.err) { - t.Errorf("Unexpected compilation error: %v (want %s)", err, tc.err) + + // Build slice of expected error messages. 
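+ // tc.expected is a set literal such as {i, x}: each member names a variable
+ // that should be reported unsafe, one rego_unsafe_var_error per variable.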
+ expected := []string{} + + _ = MustParseTerm(tc.expected).Value.(Set).Iter(func(x *Term) error { + expected = append(expected, makeErrMsg(string(x.Value.(Var)))) + return nil + }) // cannot return error + + sort.Strings(expected) + + // Compile test module. + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c := NewCompiler() + c.Modules = map[string]*Module{ + "newMod": MustParseModuleWithOpts(fmt.Sprintf(` + + %v + + %v + + `, moduleBegin, tc.moduleContent), popts), + } + + compileStages(c, c.checkSafetyRuleBodies) + + // Get errors. + result := compilerErrsToStringSlice(c.Errors) + + // Check against expected. + if len(result) != len(expected) { + t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n")) } + + for i := range result { + if expected[i] != result[i] { + t.Errorf("Expected %v but got: %v", expected[i], result[i]) + } + } + }) } } -func TestCompilerImportsResolved(t *testing.T) { +func TestCompilerCheckSafetyVarLoc(t *testing.T) { - modules := map[string]*Module{ - "mod1": MustParseModule(`package ex + _, err := CompileModules(map[string]string{"test.rego": `package test -import data -import input -import data.foo -import input.bar -import data.abc as baz -import input.abc as qux`, - ), +p { + not x + x > y +}`}) + + if err == nil { + t.Fatal("expected error") } - c := NewCompiler() - c.Compile(modules) + errs := err.(Errors) - assertNotFailed(t, c) + if !strings.Contains(errs[0].Message, "var x is unsafe") || errs[0].Location.Row != 4 { + t.Fatal("expected error on row 4 but got:", err) + } - if len(c.Modules["mod1"].Imports) != 0 { - t.Fatalf("Expected imports to be empty after compile but got: %v", c.Modules) + if !strings.Contains(errs[1].Message, "var y is unsafe") || errs[1].Location.Row != 5 { + t.Fatal("expected y is unsafe on row 5 but got:", err) + } +} + +func TestCompilerCheckSafetyFunctionAndContainsKeyword(t *testing.T) { + _, err := CompileModules(map[string]string{"test.rego": `package play + + import future.keywords.contains + + p(id) contains x { + x := id + }`}) + if err == nil { + t.Fatal("expected error") } + errs := err.(Errors) + if !strings.Contains(errs[0].Message, "the contains keyword can only be used with multi-value rule definitions (e.g., p contains { ... 
})") { + t.Fatal("wrong error message:", err) + } + if errs[0].Location.Row != 5 { + t.Fatal("expected error on line 5 but got:", errs[0].Location.Row) + } } -func TestCompilerExprExpansion(t *testing.T) { +func TestCompilerCheckTypes(t *testing.T) { + c := NewCompiler() + modules := getCompilerTestModules() + c.Modules = map[string]*Module{"mod6": modules["mod6"], "mod7": modules["mod7"]} + compileStages(c, c.checkTypes) + assertNotFailed(t, c) +} - tests := []struct { - note string - input string - expected []*Expr - }{ - { - note: "identity", - input: "x", - expected: []*Expr{ - MustParseExpr("x"), - }, - }, - { - note: "single", - input: "x+y", - expected: []*Expr{ - MustParseExpr("x+y"), - }, - }, - { - note: "chained", - input: "x+y+z+w", - expected: []*Expr{ - MustParseExpr("plus(x, y, __local0__)"), - MustParseExpr("plus(__local0__, z, __local1__)"), - MustParseExpr("plus(__local1__, w)"), - }, - }, - { - note: "assoc", - input: "x+y*z", - expected: []*Expr{ - MustParseExpr("mul(y, z, __local0__)"), +func TestCompilerCheckRuleConflicts(t *testing.T) { + + c := getCompilerWithParsedModules(map[string]string{ + "mod1.rego": `package badrules + +p[x] { x = 1 } +p[x] = y { x = y; x = "a" } +q[1] { true } +q = {1, 2, 3} { true } +r[x] = y { x = y; x = "a" } +r[x] = y { x = y; x = "a" }`, + + "mod2.rego": `package badrules.r + +q[1] { true }`, + + "mod3.rego": `package badrules.defkw + +default foo = 1 +default foo = 2 +foo = 3 { true } + +default p.q.bar = 1 +default p.q.bar = 2 +p.q.bar = 3 { true } +`, + "mod4.rego": `package badrules.arity + +f(1) { true } +f { true } + +g(1) { true } +g(1,2) { true } + +p.q.h(1) { true } +p.q.h { true } + +p.q.i(1) { true } +p.q.i(1,2) { true }`, + "mod5.rego": `package badrules.dataoverlap + +p { true }`, + "mod6.rego": `package badrules.existserr + +p { true }`, + + "mod7.rego": `package badrules.foo +import future.keywords + +bar.baz contains "quz" if true`, + }) + + c.WithPathConflictsCheck(func(path []string) (bool, error) { + if reflect.DeepEqual(path, []string{"badrules", "dataoverlap", "p"}) { + return true, nil + } else if reflect.DeepEqual(path, []string{"badrules", "existserr", "p"}) { + return false, fmt.Errorf("unexpected error") + } + return false, nil + }) + + compileStages(c, c.checkRuleConflicts) + + expected := []string{ + "rego_compile_error: conflict check for data path badrules/existserr/p: unexpected error", + "rego_compile_error: conflicting rule for data path badrules/dataoverlap/p found", + "rego_type_error: conflicting rules data.badrules.arity.f found", + "rego_type_error: conflicting rules data.badrules.arity.g found", + "rego_type_error: conflicting rules data.badrules.arity.p.q.h found", + "rego_type_error: conflicting rules data.badrules.arity.p.q.i found", + "rego_type_error: conflicting rules data.badrules.p[x] found", + "rego_type_error: conflicting rules data.badrules.q found", + "rego_type_error: multiple default rules data.badrules.defkw.foo found", + "rego_type_error: multiple default rules data.badrules.defkw.p.q.bar found", + "rego_type_error: package badrules.r conflicts with rule r[x] defined at mod1.rego:7", + "rego_type_error: package badrules.r conflicts with rule r[x] defined at mod1.rego:8", + } + + assertCompilerErrorStrings(t, c, expected) +} + +func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { + + tests := []struct { + note string + modules []*Module + err string + }{ + { + note: "arity mismatch, ref and non-ref rule", + modules: modules( + `package pkg + p.q.r { true }`, + `package 
pkg.p.q + r(_) = 2`), + err: "rego_type_error: conflicting rules data.pkg.p.q.r found", + }, + { + note: "two default rules, ref and non-ref rule", + modules: modules( + `package pkg + default p.q.r = 3 + p.q.r { true }`, + `package pkg.p.q + default r = 4 + r = 2`), + err: "rego_type_error: multiple default rules data.pkg.p.q.r found", + }, + { + note: "arity mismatch, ref and ref rule", + modules: modules( + `package pkg.a.b + p.q.r { true }`, + `package pkg.a + b.p.q.r(_) = 2`), + err: "rego_type_error: conflicting rules data.pkg.a.b.p.q.r found", + }, + { + note: "two default rules, ref and ref rule", + modules: modules( + `package pkg + default p.q.w.r = 3 + p.q.w.r { true }`, + `package pkg.p + default q.w.r = 4 + q.w.r = 2`), + err: "rego_type_error: multiple default rules data.pkg.p.q.w.r found", + }, + { + note: "multi-value + single-value rules, both with same ref prefix", + modules: modules( + `package pkg + p.q.w[x] = 1 if x := "foo"`, + `package pkg + p.q.w contains "bar"`), + err: "rego_type_error: conflicting rules data.pkg.p.q.w found", + }, + { + note: "two multi-value rules, both with same ref", + modules: modules( + `package pkg + p.q.w contains "baz"`, + `package pkg + p.q.w contains "bar"`), + }, + { + note: "module conflict: non-ref rule", + modules: modules( + `package pkg.q + r { true }`, + `package pkg.q.r`), + err: "rego_type_error: package pkg.q.r conflicts with rule r defined at mod0.rego:2", + }, + { + note: "module conflict: ref rule", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg.p.q.r`), + err: "rego_type_error: package pkg.p.q.r conflicts with rule p.q.r defined at mod0.rego:2", + }, + { + note: "single-value with other rule overlap", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg + p.q.r.s { true }`), + err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s]", + }, + { + note: "single-value with other rule overlap", + modules: modules( + `package pkg + p.q.r { true } + p.q.r.s { true } + p.q.r.t { true }`), + err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s data.pkg.p.q.r.t]", + }, + { + note: "single-value with other rule overlap, unknown key", + modules: modules( + `package pkg + p.q[r] = x { r = input.key; x = input.foo } + p.q.r.s = x { true } + `), + err: "rego_type_error: single-value rule data.pkg.p.q[r] conflicts with [data.pkg.p.q.r.s]", + }, + { + note: "single-value partial object with other partial object rule overlap, unknown keys (regression test for #5855)", + modules: modules( + `package pkg + p[r] := x { r = input.key; x = input.bar } + p.q[r] := x { r = input.key; x = input.bar } + `), + err: "rego_type_error: single-value rule data.pkg.p[r] conflicts with [data.pkg.p.q[r]]", + }, + { + note: "single-value partial object with other partial object (implicit 'true' value) rule overlap, unknown keys", + modules: modules( + `package pkg + p[r] := x { r = input.key; x = input.bar } + p.q[r] { r = input.key } + `), + err: "rego_type_error: single-value rule data.pkg.p[r] conflicts with [data.pkg.p.q[r]]", + }, + { + note: "single-value partial object with multi-value rule (ref head) overlap, unknown key", + modules: modules( + `package pkg + import future.keywords + p[r] := x { r = input.key; x = input.bar } + p.q contains r { r = input.key } + `), + }, + { + note: "single-value partial object with multi-value rule overlap, unknown key", + modules: modules( + `package pkg + p[r] := x { r = input.key; x = input.bar } + p.q { 
true } + `), + err: "rego_type_error: conflicting rules data.pkg.p found", + }, + { + note: "single-value rule with known and unknown key", + modules: modules( + `package pkg + p.q[r] = x { r = input.key; x = input.foo } + p.q.s = "x" { true } + `), + }, + { + note: "multi-value rule with other rule overlap", + modules: modules( + `package pkg + p[v] { v := ["a", "b"][_] } + p.q := 42 + `), + err: "rego_type_error: multi-value rule data.pkg.p conflicts with [data.pkg.p.q]", + }, + { + note: "multi-value rule with other rule (ref) overlap", + modules: modules( + `package pkg + p[v] { v := ["a", "b"][_] } + p.q.r { true } + `), + err: "rego_type_error: multi-value rule data.pkg.p conflicts with [data.pkg.p.q.r]", + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.checkRuleConflicts) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + assertCompilerErrorStrings(t, c, []string{}) + } + }) + } +} + +func TestCompilerCheckUndefinedFuncs(t *testing.T) { + + module := ` + package test + + undefined_function { + data.deadbeef(x) + } + + undefined_global { + deadbeef(x) + } + + # NOTE: all the dynamic dispatch examples here are not supported, + # we're checking assertions about the error returned. + undefined_dynamic_dispatch { + x = "f"; data.test2[x](1) + } + + undefined_dynamic_dispatch_declared_var { + y := "f"; data.test2[y](1) + } + + undefined_dynamic_dispatch_declared_var_in_array { + z := "f"; data.test2[[z]](1) + } + + arity_mismatch_1 { + data.test2.f(1,2,3) + } + + arity_mismatch_2 { + data.test2.f() + } + + arity_mismatch_3 { + x:= data.test2.f() + } + ` + + module2 := ` + package test2 + + f(x) = x + ` + + _, err := CompileModules(map[string]string{ + "test.rego": module, + "test2.rego": module2, + }) + if err == nil { + t.Fatal("expected errors") + } + + result := err.Error() + want := []string{ + "rego_type_error: undefined function data.deadbeef", + "rego_type_error: undefined function deadbeef", + "rego_type_error: undefined function data.test2[x]", + "rego_type_error: undefined function data.test2[y]", + "rego_type_error: undefined function data.test2[[z]]", + "rego_type_error: function data.test2.f has arity 1, got 3 arguments", + "test.rego:31: rego_type_error: function data.test2.f has arity 1, got 0 arguments", + "test.rego:35: rego_type_error: function data.test2.f has arity 1, got 0 arguments", + } + for _, w := range want { + if !strings.Contains(result, w) { + t.Fatalf("Expected %q in result but got: %v", w, result) + } + } +} + +func TestCompilerQueryCompilerCheckUndefinedFuncs(t *testing.T) { + compiler := NewCompiler() + + for _, tc := range []struct { + note, query, err string + }{ + + {note: "undefined function", query: `data.foo(1)`, err: "undefined function data.foo"}, + {note: "undefined global function", query: `foo(1)`, err: "undefined function foo"}, + {note: "var", query: `x = "f"; data[x](1)`, err: "undefined function data[x]"}, + {note: "declared var", query: `x := "f"; data[x](1)`, err: "undefined function data[x]"}, + {note: "declared var in array", query: `x := "f"; data[[x]](1)`, err: "undefined function data[[x]]"}, + } { + t.Run(tc.note, func(t *testing.T) { + _, err := compiler.QueryCompiler().Compile(MustParseBody(tc.query)) + if !strings.Contains(err.Error(), tc.err) { + t.Errorf("Unexpected compilation error: %v (want 
%s)", err, tc.err) + } + }) + } +} + +func TestCompilerImportsResolved(t *testing.T) { + + modules := map[string]*Module{ + "mod1": MustParseModule(`package ex + +import data +import input +import data.foo +import input.bar +import data.abc as baz +import input.abc as qux`, + ), + } + + c := NewCompiler() + c.Compile(modules) + + assertNotFailed(t, c) + + if len(c.Modules["mod1"].Imports) != 0 { + t.Fatalf("Expected imports to be empty after compile but got: %v", c.Modules) + } + +} + +func TestCompilerExprExpansion(t *testing.T) { + + tests := []struct { + note string + input string + expected []*Expr + }{ + { + note: "identity", + input: "x", + expected: []*Expr{ + MustParseExpr("x"), + }, + }, + { + note: "single", + input: "x+y", + expected: []*Expr{ + MustParseExpr("x+y"), + }, + }, + { + note: "chained", + input: "x+y+z+w", + expected: []*Expr{ + MustParseExpr("plus(x, y, __local0__)"), + MustParseExpr("plus(__local0__, z, __local1__)"), + MustParseExpr("plus(__local1__, w)"), + }, + }, + { + note: "assoc", + input: "x+y*z", + expected: []*Expr{ + MustParseExpr("mul(y, z, __local0__)"), MustParseExpr("plus(x, __local0__)"), }, }, @@ -1434,13 +2374,26 @@ func TestCompilerRewriteExprTerms(t *testing.T) { f(__local0__[0]) { true; __local0__ = [1] }`, }, + { + note: "every: domain", + module: ` + package test + + p { every x in [1,2] { x } }`, + expected: ` + package test + + p { __local2__ = [1, 2]; every __local0__, __local1__ in __local2__ { __local1__ } }`, + }, } for _, tc := range cases { t.Run(tc.note, func(t *testing.T) { compiler := NewCompiler() + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + compiler.Modules = map[string]*Module{ - "test": MustParseModule(tc.module), + "test": MustParseModuleWithOpts(tc.module, opts), } compileStages(compiler, compiler.rewriteExprTerms) @@ -1448,31 +2401,13 @@ func TestCompilerRewriteExprTerms(t *testing.T) { case string: assertNotFailed(t, compiler) - expected := MustParseModule(exp) + expected := MustParseModuleWithOpts(exp, opts) if !expected.Equal(compiler.Modules["test"]) { t.Fatalf("Expected modules to be equal. 
Expected:\n\n%v\n\nGot:\n\n%v", expected, compiler.Modules["test"]) } case Errors: - if len(exp) != len(compiler.Errors) { - t.Fatalf("Expected %d errors, got %d:\n\n%s\n", len(exp), len(compiler.Errors), compiler.Errors.Error()) - } - incorrectErrs := false - for _, e := range exp { - found := false - for _, actual := range compiler.Errors { - if e.Message == actual.Message { - found = true - break - } - } - if !found { - incorrectErrs = true - } - } - if incorrectErrs { - t.Fatalf("Expected errors:\n\n%s\n\nGot:\n\n%s\n", exp.Error(), compiler.Errors.Error()) - } + assertErrors(t, compiler.Errors, exp, false) default: t.Fatalf("Unsupported value type for test case 'expected' field: %v", exp) } @@ -1481,1769 +2416,4660 @@ func TestCompilerRewriteExprTerms(t *testing.T) { } } -func TestCompilerResolveAllRefs(t *testing.T) { - c := NewCompiler() - c.Modules = getCompilerTestModules() - c.Modules["head"] = MustParseModule(`package head - -import data.doc1 as bar -import input.x.y.foo -import input.qux as baz - -p[foo[bar[i]]] = {"baz": baz} { true }`) - - c.Modules["elsekw"] = MustParseModule(`package elsekw - - import input.x.y.foo - import data.doc1 as bar - import input.baz - - p { - false - } else = foo { - bar - } else = baz { - true +func TestIllegalFunctionCallRewrite(t *testing.T) { + cases := []struct { + note string + module string + expectedErrors []string + }{ + /*{ + note: "function call override in function value", + module: `package test + foo(x) := x + + p := foo(bar) { + #foo := 1 + bar := 2 + }`, + expectedErrors: []string{ + "undefined function foo", + }, + },*/ + { + note: "function call override in array comprehension value", + module: `package test +p := [foo(bar) | foo := 1; bar := 2]`, + expectedErrors: []string{ + "called function foo shadowed", + }, + }, + { + note: "function call override in set comprehension value", + module: `package test +p := {foo(bar) | foo := 1; bar := 2}`, + expectedErrors: []string{ + "called function foo shadowed", + }, + }, + { + note: "function call override in object comprehension value", + module: `package test +p := {foo(bar): bar(foo) | foo := 1; bar := 2}`, + expectedErrors: []string{ + "called function bar shadowed", + "called function foo shadowed", + }, + }, + { + note: "function call override in array comprehension value", + module: `package test +p := [foo.bar(baz) | foo := 1; bar := 2; baz := 3]`, + expectedErrors: []string{ + "called function foo.bar shadowed", + }, + }, + { + note: "nested function call override in array comprehension value", + module: `package test +p := [baz(foo(bar)) | foo := 1; bar := 2]`, + expectedErrors: []string{ + "called function foo shadowed", + }, + }, + { + note: "function call override of 'input' root document", + module: `package test +p := [input() | input := 1]`, + expectedErrors: []string{ + "called function input shadowed", + }, + }, + { + note: "function call override of 'data' root document", + module: `package test +p := [data() | data := 1]`, + expectedErrors: []string{ + "called function data shadowed", + }, + }, } - `) - - c.Modules["nestedexprs"] = MustParseModule(`package nestedexprs - - x = 1 - - p { - f(g(x)) - }`) - - c.Modules["assign"] = MustParseModule(`package assign - - x = 1 - y = 1 - - p { - x := y - [true | x := y] - }`) - - c.Modules["someinassign"] = MustParseModule(`package someinassign - import future.keywords.in - x = 1 - y = 1 - p[x] { - some x in [1, 2, y] - }`) - - c.Modules["someinassignwithkey"] = MustParseModule(`package someinassignwithkey - import 
future.keywords.in - x = 1 - y = 1 - - p[x] { - some k, v in [1, 2, y] - }`) - - c.Modules["donotresolve"] = MustParseModule(`package donotresolve - - x = 1 - - f(x) { - x = 2 - } - `) - - c.Modules["indirectrefs"] = MustParseModule(`package indirectrefs + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + compiler := NewCompiler() + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} - f(x) = [x] {true} + compiler.Modules = map[string]*Module{ + "test": MustParseModuleWithOpts(tc.module, opts), + } + compileStages(compiler, compiler.rewriteLocalVars) - p { - f(1)[0] - } - `) + result := make([]string, 0, len(compiler.Errors)) + for i := range compiler.Errors { + result = append(result, compiler.Errors[i].Message) + } - c.Modules["comprehensions"] = MustParseModule(`package comprehensions + sort.Strings(tc.expectedErrors) + sort.Strings(result) - nums = [1, 2, 3] + if len(tc.expectedErrors) != len(result) { + t.Fatalf("Expected %d errors but got %d:\n\n%v\n\nGot:\n\n%v", + len(tc.expectedErrors), len(result), + strings.Join(tc.expectedErrors, "\n"), strings.Join(result, "\n")) + } - f(x) = [x] {true} + for i := range result { + if result[i] != tc.expectedErrors[i] { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", + strings.Join(tc.expectedErrors, "\n"), strings.Join(result, "\n")) + } + } + }) + } +} - p[[1]] {true} +func TestCompilerCheckUnusedImports(t *testing.T) { + cases := []strictnessTestCase{ + { + note: "simple unused: input ref with same name", + module: `package p + import data.foo.bar as bar + r { + input.bar == 11 + } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 2, 4), + Message: "import data.foo.bar as bar unused", + }, + }, + }, + { + note: "unused import, but imported ref used", + module: `package p + import data.foo # unused + r { data.foo == 10 } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 2, 4), + Message: "import data.foo unused", + }, + }, + }, + { + note: "one of two unused", + module: `package p + import data.foo + import data.x.power #unused + r { foo == 10 } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 3, 4), + Message: "import data.x.power unused", + }, + }, + }, + { + note: "multiple unused: with input ref of same name", + module: `package p + import data.foo + import data.x.power + r { input.foo == 10 } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 2, 4), + Message: "import data.foo unused", + }, + &Error{ + Location: NewLocation([]byte("import"), "", 3, 4), + Message: "import data.x.power unused", + }, + }, + }, + { + note: "import used in comparison", + module: `package p + import data.foo.x + r { x == 10 } + `, + }, + { + note: "multiple used imports in one rule", + module: `package p + import data.foo.x + import data.power.ranger + r { ranger == x } + `, + }, + { + note: "multiple used imports in separate rules", + module: `package p + import data.foo.x + import data.power.ranger + r { ranger == 23 } + t { x == 1 } + `, + }, + { + note: "import used as function operand", + module: `package p + import data.foo + r = count(foo) > 1 # only one operand + `, + }, + { + note: "import used as function operand, compound term", + module: `package p + import data.foo + r = sprintf("%v %d", [foo, 0]) + `, + }, + { + note: "import used as plain term", + module: `package p + import data.foo + r { + foo + } + `, + }, + { + note: "import used in 'every' domain", + module: 
`package p + import future.keywords.every + import data.foo + r { + every x in foo { x > 1 } + } + `, + }, + { + note: "import used in 'every' body", + module: `package p + import future.keywords.every + import data.foo + r { + every x in [1,2,3] { x > foo } + } + `, + }, + { + note: "future import kept even if unused", + module: `package p + import future.keywords - q { - p[[x | x = nums[_]]] - } + r { true } + `, + }, + { + note: "shadowed var name in function arg", + module: `package p + import data.foo # unused - r = [y | y = f(1)[0]] - `) + r { f(1) } + f(foo) = foo == 1 + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 2, 4), + Message: "import data.foo unused", + }, + }, + }, + { + note: "shadowed assigned var name", + module: `package p + import data.foo # unused - compileStages(c, c.resolveAllRefs) - assertNotFailed(t, c) + r { foo := true; foo } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 2, 4), + Message: "import data.foo unused", + }, + }, + }, + { + note: "used as rule value", + module: `package p + import data.bar # unused + import data.foo - // Basic test cases. - mod1 := c.Modules["mod1"] - p := mod1.Rules[0] - expr1 := p.Body[0] - term := expr1.Terms.(*Term) - e := MustParseTerm("data.a.b.c.q[x]") - if !term.Equal(e) { - t.Errorf("Wrong term (global in same module): expected %v but got: %v", e, term) - } + r = foo { true } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 2, 4), + Message: "import data.bar unused", + }, + }, + }, + { + note: "unused as rule value (but same data ref)", + module: `package p + import data.bar # unused + import data.foo # unused - expr2 := p.Body[1] - term = expr2.Terms.(*Term) - e = MustParseTerm("data.a.b.c.r[x]") - if !term.Equal(e) { - t.Errorf("Wrong term (global in same package/diff module): expected %v but got: %v", e, term) + r = data.foo { true } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 2, 4), + Message: "import data.bar unused", + }, + &Error{ + Location: NewLocation([]byte("import"), "", 3, 4), + Message: "import data.foo unused", + }, + }, + }, } - mod2 := c.Modules["mod2"] - r := mod2.Rules[0] - expr3 := r.Body[1] - term = expr3.Terms.([]*Term)[1] - e = MustParseTerm("data.x.y.p") - if !term.Equal(e) { - t.Errorf("Wrong term (var import): expected %v but got: %v", e, term) - } + runStrictnessTestCase(t, cases, true) +} - mod3 := c.Modules["mod3"] - expr4 := mod3.Rules[0].Body[0] - term = expr4.Terms.([]*Term)[2] - e = MustParseTerm("{input.x.secret: [{input.x.keyid}]}") - if !term.Equal(e) { - t.Errorf("Wrong term (nested refs): expected %v but got: %v", e, term) - } +func TestCompilerCheckDuplicateImports(t *testing.T) { + cases := []strictnessTestCase{ + { + note: "shadow", + module: `package test + import input.noconflict + import input.foo + import data.foo + import data.bar.foo - // Array comprehensions. 
- mod5 := c.Modules["mod5"] + p := noconflict + q := foo + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 4, 5), + Message: "import must not shadow import input.foo", + }, + &Error{ + Location: NewLocation([]byte("import"), "", 5, 5), + Message: "import must not shadow import input.foo", + }, + }, + }, { + note: "alias shadow", + module: `package test + import input.noconflict + import input.foo + import input.bar as foo - ac := func(r *Rule) *ArrayComprehension { - return r.Body[0].Terms.(*Term).Value.(*ArrayComprehension) + p := noconflict + q := foo + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("import"), "", 4, 5), + Message: "import must not shadow import input.foo", + }, + }, + }, } - acTerm1 := ac(mod5.Rules[0]) - assertTermEqual(t, acTerm1.Term, MustParseTerm("input.x.a")) - acTerm2 := ac(mod5.Rules[1]) - assertTermEqual(t, acTerm2.Term, MustParseTerm("data.a.b.c.q.a")) - acTerm3 := ac(mod5.Rules[2]) - assertTermEqual(t, acTerm3.Body[0].Terms.([]*Term)[1], MustParseTerm("input.x.a")) - acTerm4 := ac(mod5.Rules[3]) - assertTermEqual(t, acTerm4.Body[0].Terms.([]*Term)[1], MustParseTerm("data.a.b.c.q[i]")) - acTerm5 := ac(mod5.Rules[4]) - assertTermEqual(t, acTerm5.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Term, MustParseTerm("input.x.a")) - acTerm6 := ac(mod5.Rules[5]) - assertTermEqual(t, acTerm6.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Body[0].Terms.([]*Term)[1], MustParseTerm("data.a.b.c.q[i]")) - - // Nested references. - mod6 := c.Modules["mod6"] - nested1 := mod6.Rules[0].Body[0].Terms.(*Term) - assertTermEqual(t, nested1, MustParseTerm("data.x[input.x[i].a[data.z.b[j]]]")) - - nested2 := mod6.Rules[1].Body[1].Terms.(*Term) - assertTermEqual(t, nested2, MustParseTerm("v[input.x[i]]")) + runStrictnessTestCase(t, cases, true) +} - nested3 := mod6.Rules[3].Body[0].Terms.(*Term) - assertTermEqual(t, nested3, MustParseTerm("data.x[data.a.b.nested.r]")) +func TestCompilerCheckKeywordOverrides(t *testing.T) { + cases := []strictnessTestCase{ + { + note: "rule names", + module: `package test + input { true } + p { true } + data { true } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("input { true }"), "", 2, 5), + Message: "rules must not shadow input (use a different rule name)", + }, + &Error{ + Location: NewLocation([]byte("data { true }"), "", 4, 5), + Message: "rules must not shadow data (use a different rule name)", + }, + }, + }, + { + note: "global assignments", + module: `package test + input = 1 + p := 2 + data := 3 + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("input = 1"), "", 2, 5), + Message: "rules must not shadow input (use a different rule name)", + }, + &Error{ + Location: NewLocation([]byte("data := 3"), "", 4, 5), + Message: "rules must not shadow data (use a different rule name)", + }, + }, + }, + { + note: "rule-local assignments", + module: `package test + p { + input := 1 + x := 2 + } else { + data := 3 + } + q { + input := 4 + } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("input := 1"), "", 3, 6), + Message: "variables must not shadow input (use a different variable name)", + }, + &Error{ + Location: NewLocation([]byte("data := 3"), "", 6, 6), + Message: "variables must not shadow data (use a different variable name)", + }, + &Error{ + Location: NewLocation([]byte("input := 4"), "", 9, 6), + Message: "variables must not shadow input (use a different variable name)", + }, + }, + }, + { + note: 
"array comprehension-local assignments", + module: `package test + p = [ x | + input := 1 + x := 2 + data := 3 + ] + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("input := 1"), "", 3, 6), + Message: "variables must not shadow input (use a different variable name)", + }, + &Error{ + Location: NewLocation([]byte("data := 3"), "", 5, 6), + Message: "variables must not shadow data (use a different variable name)", + }, + }, + }, + { + note: "set comprehension-local assignments", + module: `package test + p = { x | + input := 1 + x := 2 + data := 3 + } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("input := 1"), "", 3, 6), + Message: "variables must not shadow input (use a different variable name)", + }, + &Error{ + Location: NewLocation([]byte("data := 3"), "", 5, 6), + Message: "variables must not shadow data (use a different variable name)", + }, + }, + }, + { + note: "object comprehension-local assignments", + module: `package test + p = { x: 1 | + input := 1 + x := 2 + data := 3 + } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("input := 1"), "", 3, 6), + Message: "variables must not shadow input (use a different variable name)", + }, + &Error{ + Location: NewLocation([]byte("data := 3"), "", 5, 6), + Message: "variables must not shadow data (use a different variable name)", + }, + }, + }, + { + note: "nested override", + module: `package test + p { + [ x | + input := 1 + x := 2 + data := 3 + ] + } + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("input := 1"), "", 4, 7), + Message: "variables must not shadow input (use a different variable name)", + }, + &Error{ + Location: NewLocation([]byte("data := 3"), "", 6, 7), + Message: "variables must not shadow data (use a different variable name)", + }, + }, + }, + } - // Refs in head. - mod7 := c.Modules["head"] - assertTermEqual(t, mod7.Rules[0].Head.Key, MustParseTerm("input.x.y.foo[data.doc1[i]]")) - assertTermEqual(t, mod7.Rules[0].Head.Value, MustParseTerm(`{"baz": input.qux}`)) + runStrictnessTestCase(t, cases, true) +} - // Refs in else. 
- mod8 := c.Modules["elsekw"] - assertTermEqual(t, mod8.Rules[0].Else.Head.Value, MustParseTerm("input.x.y.foo")) - assertTermEqual(t, mod8.Rules[0].Else.Body[0].Terms.(*Term), MustParseTerm("data.doc1")) - assertTermEqual(t, mod8.Rules[0].Else.Else.Head.Value, MustParseTerm("input.baz")) +func TestCompilerCheckDeprecatedMethods(t *testing.T) { + cases := []strictnessTestCase{ + { + note: "all() built-in", + module: `package test + p := all([true, false]) + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("all([true, false])"), "", 2, 10), + Message: "deprecated built-in function calls in expression: all", + }, + }, + }, + { + note: "user-defined all()", + module: `package test + import future.keywords.in + all(arr) = {x | some x in arr} == {true} + p := all([true, false]) + `, + }, + { + note: "any() built-in", + module: `package test + p := any([true, false]) + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte("any([true, false])"), "", 2, 10), + Message: "deprecated built-in function calls in expression: any", + }, + }, + }, + { + note: "user-defined any()", + module: `package test + import future.keywords.in + any(arr) := true in arr + p := any([true, false]) + `, + }, + { + note: "re_match built-in", + module: `package test + p := re_match("[a]", "a") + `, + expectedErrors: Errors{ + &Error{ + Location: NewLocation([]byte(`re_match("[a]", "a")`), "", 2, 10), + Message: "deprecated built-in function calls in expression: re_match", + }, + }, + }, + } - // Refs in calls. - mod9 := c.Modules["nestedexprs"] - assertTermEqual(t, mod9.Rules[1].Body[0].Terms.([]*Term)[1], CallTerm(RefTerm(VarTerm("g")), MustParseTerm("data.nestedexprs.x"))) + runStrictnessTestCase(t, cases, true) +} - // Ignore assigned vars. - mod10 := c.Modules["assign"] - assertTermEqual(t, mod10.Rules[2].Body[0].Terms.([]*Term)[1], VarTerm("x")) - assertTermEqual(t, mod10.Rules[2].Body[0].Terms.([]*Term)[2], MustParseTerm("data.assign.y")) - assignCompr := mod10.Rules[2].Body[1].Terms.(*Term).Value.(*ArrayComprehension) - assertTermEqual(t, assignCompr.Body[0].Terms.([]*Term)[1], VarTerm("x")) - assertTermEqual(t, assignCompr.Body[0].Terms.([]*Term)[2], MustParseTerm("data.assign.y")) +type strictnessTestCase struct { + note string + module string + expectedErrors Errors +} - // Args - mod11 := c.Modules["donotresolve"] - assertTermEqual(t, mod11.Rules[1].Head.Args[0], VarTerm("x")) - assertExprEqual(t, mod11.Rules[1].Body[0], MustParseExpr("x = 2")) +func runStrictnessTestCase(t *testing.T, cases []strictnessTestCase, assertLocation bool) { + t.Helper() + makeTestRunner := func(tc strictnessTestCase, strict bool) func(t *testing.T) { + return func(t *testing.T) { + compiler := NewCompiler().WithStrict(strict) + compiler.Modules = map[string]*Module{ + "test": MustParseModule(tc.module), + } + compileStages(compiler, nil) - // Locations. - parsedLoc := getCompilerTestModules()["mod1"].Rules[0].Body[0].Terms.(*Term).Value.(Ref)[0].Location - compiledLoc := c.Modules["mod1"].Rules[0].Body[0].Terms.(*Term).Value.(Ref)[0].Location - if parsedLoc.Row != compiledLoc.Row { - t.Fatalf("Expected parsed location (%v) and compiled location (%v) to be equal", parsedLoc.Row, compiledLoc.Row) + if strict { + assertErrors(t, compiler.Errors, tc.expectedErrors, assertLocation) + } else { + assertNotFailed(t, compiler) + } + } } - // Indirect references. 
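
(Editor's aside, not part of the patch: the deprecated-builtin check follows the same strict/non-strict split that runStrictnessTestCase encodes above — each case must fail under WithStrict(true) and compile cleanly otherwise. A sketch under the same assumptions as the previous example:)

```go
package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	c := ast.NewCompiler().WithStrict(true)
	c.Compile(map[string]*ast.Module{
		"demo.rego": ast.MustParseModule(`package demo

p := all([true, false]) # all() is deprecated
`),
	})
	for _, err := range c.Errors {
		// expect: deprecated built-in function calls in expression: all
		fmt.Println(err.Message)
	}
}
```
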
- mod12 := c.Modules["indirectrefs"] - assertExprEqual(t, mod12.Rules[1].Body[0], MustParseExpr("data.indirectrefs.f(1)[0]")) + for _, tc := range cases { + t.Run(tc.note+"_strict", makeTestRunner(tc, true)) + t.Run(tc.note+"_non-strict", makeTestRunner(tc, false)) + } +} - // Comprehensions - mod13 := c.Modules["comprehensions"] - assertExprEqual(t, mod13.Rules[3].Body[0].Terms.(*Term).Value.(Ref)[3].Value.(*ArrayComprehension).Body[0], MustParseExpr("x = data.comprehensions.nums[_]")) - assertExprEqual(t, mod13.Rules[4].Head.Value.Value.(*ArrayComprehension).Body[0], MustParseExpr("y = data.comprehensions.f(1)[0]")) +func assertErrors(t *testing.T, actual Errors, expected Errors, assertLocation bool) { + t.Helper() + if len(expected) != len(actual) { + t.Fatalf("Expected %d errors, got %d:\n\n%s\n", len(expected), len(actual), actual.Error()) + } + incorrectErrs := false + for _, e := range expected { + found := false + for _, actual := range actual { + if e.Message == actual.Message { + if !assertLocation || e.Location.Equal(actual.Location) { + found = true + break + } + } + } + if !found { + incorrectErrs = true + } + } + if incorrectErrs { + t.Fatalf("Expected errors:\n\n%s\n\nGot:\n\n%s\n", expected.Error(), actual.Error()) + } +} - // Ignore vars assigned via `some x in xs`. - mod14 := c.Modules["someinassign"] - someInAssignCall := mod14.Rules[2].Body[0].Terms.(*SomeDecl).Symbols[0].Value.(Call) - assertTermEqual(t, someInAssignCall[1], VarTerm("x")) - collectionLastElem := someInAssignCall[2].Value.(*Array).Get(IntNumberTerm(2)) - assertTermEqual(t, collectionLastElem, MustParseTerm("data.someinassign.y")) +// NOTE(sr): the tests below this function are unwieldy, let's keep adding new ones to this one +func TestCompilerResolveAllRefsNewTests(t *testing.T) { + tests := []struct { + note string + mod string + exp string + extra string + }{ + { + note: "ref-rules referenced in body", + mod: `package test +a.b.c = 1 +q if a.b.c == 1 +`, + exp: `package test +a.b.c = 1 { true } +q if data.test.a.b.c = 1 +`, + }, + { + // NOTE(sr): This is a conservative extension of how it worked before: + // we will not automatically extend references to other parts of the rule tree, + // only to ref rules defined on the same level. 
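+ // (Editor's gloss, not part of the original note: "the same level" means
+ // the same package. In the case below, a.b.c is defined in a second module
+ // of package test, and q's body is still rewritten to data.test.a.b.c.)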
+ note: "ref-rules from other module referenced in body", + mod: `package test +q if a.b.c == 1 +`, + extra: `package test +a.b.c = 1 +`, + exp: `package test +q if data.test.a.b.c = 1 +`, + }, + { + note: "single-value rule in comprehension in call", // NOTE(sr): this is TestRego/partialiter/objects_conflict + mod: `package test +p := count([x | q[x]]) +q[1] = 1 +`, + exp: `package test +p := __local0__ { true; __local1__ = [x | data.test.q[x]]; count(__local1__, __local0__) } +q[1] = 1 +`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c := NewCompiler() + mod, err := ParseModuleWithOpts("test.rego", tc.mod, opts) + if err != nil { + t.Fatal(err) + } + exp, err := ParseModuleWithOpts("test.rego", tc.exp, opts) + if err != nil { + t.Fatal(err) + } + mods := map[string]*Module{"test": mod} + if tc.extra != "" { + extra, err := ParseModuleWithOpts("test.rego", tc.extra, opts) + if err != nil { + t.Fatal(err) + } + mods["extra"] = extra + } + c.Compile(mods) + if err := c.Errors; len(err) > 0 { + t.Errorf("compile module: %v", err) + } + if act := c.Modules["test"]; !exp.Equal(act) { + t.Errorf("compiled: expected %v, got %v", exp, act) + } + }) + } +} + +func TestCompilerResolveAllRefs(t *testing.T) { + c := NewCompiler() + c.Modules = getCompilerTestModules() + c.Modules["head"] = MustParseModule(`package head + +import data.doc1 as bar +import input.x.y.foo +import input.qux as baz + +p[foo[bar[i]]] = {"baz": baz} { true }`) + + c.Modules["elsekw"] = MustParseModule(`package elsekw + + import input.x.y.foo + import data.doc1 as bar + import input.baz + + p { + false + } else = foo { + bar + } else = baz { + true + } + `) + + c.Modules["nestedexprs"] = MustParseModule(`package nestedexprs + + x = 1 + + p { + f(g(x)) + }`) + + c.Modules["assign"] = MustParseModule(`package assign + + x = 1 + y = 1 + + p { + x := y + [true | x := y] + }`) + + c.Modules["someinassign"] = MustParseModule(`package someinassign + import future.keywords.in + x = 1 + y = 1 + + p[x] { + some x in [1, 2, y] + }`) + + c.Modules["someinassignwithkey"] = MustParseModule(`package someinassignwithkey + import future.keywords.in + x = 1 + y = 1 + + p[x] { + some k, v in [1, 2, y] + }`) + + c.Modules["donotresolve"] = MustParseModule(`package donotresolve + + x = 1 + + f(x) { + x = 2 + } + `) + + c.Modules["indirectrefs"] = MustParseModule(`package indirectrefs + + f(x) = [x] {true} + + p { + f(1)[0] + } + `) + + c.Modules["comprehensions"] = MustParseModule(`package comprehensions + + nums = [1, 2, 3] + + f(x) = [x] {true} + + p[[1]] {true} + + q { + p[[x | x = nums[_]]] + } + + r = [y | y = f(1)[0]] + `) + + c.Modules["everykw"] = MustParseModuleWithOpts(`package everykw + + nums = {1, 2, 3} + f(_) = true + x = 100 + xs = [1, 2, 3] + p { + every x in xs { + nums[x] + x > 10 + } + }`, ParserOptions{unreleasedKeywords: true, FutureKeywords: []string{"every", "in"}}) + + c.Modules["heads_with_dots"] = MustParseModule(`package heads_with_dots + + this_is_not = true + this.is.dotted { this_is_not } + `) + + compileStages(c, c.resolveAllRefs) + assertNotFailed(t, c) + + // Basic test cases. 
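+ // (Editor's note: these assertions check that bare rule references in
+ // bodies are expanded to fully-qualified data paths -- e.g. q[x] becomes
+ // data.a.b.c.q[x] -- whether the referenced rule lives in the same module
+ // or in a sibling module of the same package.)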
+ mod1 := c.Modules["mod1"] + p := mod1.Rules[0] + expr1 := p.Body[0] + term := expr1.Terms.(*Term) + e := MustParseTerm("data.a.b.c.q[x]") + if !term.Equal(e) { + t.Errorf("Wrong term (global in same module): expected %v but got: %v", e, term) + } + + expr2 := p.Body[1] + term = expr2.Terms.(*Term) + e = MustParseTerm("data.a.b.c.r[x]") + if !term.Equal(e) { + t.Errorf("Wrong term (global in same package/diff module): expected %v but got: %v", e, term) + } + + mod2 := c.Modules["mod2"] + r := mod2.Rules[0] + expr3 := r.Body[1] + term = expr3.Terms.([]*Term)[1] + e = MustParseTerm("data.x.y.p") + if !term.Equal(e) { + t.Errorf("Wrong term (var import): expected %v but got: %v", e, term) + } + + mod3 := c.Modules["mod3"] + expr4 := mod3.Rules[0].Body[0] + term = expr4.Terms.([]*Term)[2] + e = MustParseTerm("{input.x.secret: [{input.x.keyid}]}") + if !term.Equal(e) { + t.Errorf("Wrong term (nested refs): expected %v but got: %v", e, term) + } + + // Array comprehensions. + mod5 := c.Modules["mod5"] + + ac := func(r *Rule) *ArrayComprehension { + return r.Body[0].Terms.(*Term).Value.(*ArrayComprehension) + } + + acTerm1 := ac(mod5.Rules[0]) + assertTermEqual(t, acTerm1.Term, MustParseTerm("input.x.a")) + acTerm2 := ac(mod5.Rules[1]) + assertTermEqual(t, acTerm2.Term, MustParseTerm("data.a.b.c.q.a")) + acTerm3 := ac(mod5.Rules[2]) + assertTermEqual(t, acTerm3.Body[0].Terms.([]*Term)[1], MustParseTerm("input.x.a")) + acTerm4 := ac(mod5.Rules[3]) + assertTermEqual(t, acTerm4.Body[0].Terms.([]*Term)[1], MustParseTerm("data.a.b.c.q[i]")) + acTerm5 := ac(mod5.Rules[4]) + assertTermEqual(t, acTerm5.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Term, MustParseTerm("input.x.a")) + acTerm6 := ac(mod5.Rules[5]) + assertTermEqual(t, acTerm6.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Body[0].Terms.([]*Term)[1], MustParseTerm("data.a.b.c.q[i]")) + + // Nested references. + mod6 := c.Modules["mod6"] + nested1 := mod6.Rules[0].Body[0].Terms.(*Term) + assertTermEqual(t, nested1, MustParseTerm("data.x[input.x[i].a[data.z.b[j]]]")) + + nested2 := mod6.Rules[1].Body[1].Terms.(*Term) + assertTermEqual(t, nested2, MustParseTerm("v[input.x[i]]")) + + nested3 := mod6.Rules[3].Body[0].Terms.(*Term) + assertTermEqual(t, nested3, MustParseTerm("data.x[data.a.b.nested.r]")) + + // Refs in head. + mod7 := c.Modules["head"] + assertTermEqual(t, mod7.Rules[0].Head.Key, MustParseTerm("input.x.y.foo[data.doc1[i]]")) + assertTermEqual(t, mod7.Rules[0].Head.Value, MustParseTerm(`{"baz": input.qux}`)) + + // Refs in else. + mod8 := c.Modules["elsekw"] + assertTermEqual(t, mod8.Rules[0].Else.Head.Value, MustParseTerm("input.x.y.foo")) + assertTermEqual(t, mod8.Rules[0].Else.Body[0].Terms.(*Term), MustParseTerm("data.doc1")) + assertTermEqual(t, mod8.Rules[0].Else.Else.Head.Value, MustParseTerm("input.baz")) + + // Refs in calls. + mod9 := c.Modules["nestedexprs"] + assertTermEqual(t, mod9.Rules[1].Body[0].Terms.([]*Term)[1], CallTerm(RefTerm(VarTerm("g")), MustParseTerm("data.nestedexprs.x"))) + + // Ignore assigned vars. 
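+ // (Editor's note: a var on the left-hand side of := keeps its local
+ // identity; only the right-hand side is resolved. Below, x stays a plain
+ // var while y resolves to data.assign.y, both in the rule body and inside
+ // the comprehension.)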
+ mod10 := c.Modules["assign"] + assertTermEqual(t, mod10.Rules[2].Body[0].Terms.([]*Term)[1], VarTerm("x")) + assertTermEqual(t, mod10.Rules[2].Body[0].Terms.([]*Term)[2], MustParseTerm("data.assign.y")) + assignCompr := mod10.Rules[2].Body[1].Terms.(*Term).Value.(*ArrayComprehension) + assertTermEqual(t, assignCompr.Body[0].Terms.([]*Term)[1], VarTerm("x")) + assertTermEqual(t, assignCompr.Body[0].Terms.([]*Term)[2], MustParseTerm("data.assign.y")) + + // Args + mod11 := c.Modules["donotresolve"] + assertTermEqual(t, mod11.Rules[1].Head.Args[0], VarTerm("x")) + assertExprEqual(t, mod11.Rules[1].Body[0], MustParseExpr("x = 2")) + + // Locations. + parsedLoc := getCompilerTestModules()["mod1"].Rules[0].Body[0].Terms.(*Term).Value.(Ref)[0].Location + compiledLoc := c.Modules["mod1"].Rules[0].Body[0].Terms.(*Term).Value.(Ref)[0].Location + if parsedLoc.Row != compiledLoc.Row { + t.Fatalf("Expected parsed location (%v) and compiled location (%v) to be equal", parsedLoc.Row, compiledLoc.Row) + } + + // Indirect references. + mod12 := c.Modules["indirectrefs"] + assertExprEqual(t, mod12.Rules[1].Body[0], MustParseExpr("data.indirectrefs.f(1)[0]")) + + // Comprehensions + mod13 := c.Modules["comprehensions"] + assertExprEqual(t, mod13.Rules[3].Body[0].Terms.(*Term).Value.(Ref)[3].Value.(*ArrayComprehension).Body[0], MustParseExpr("x = data.comprehensions.nums[_]")) + assertExprEqual(t, mod13.Rules[4].Head.Value.Value.(*ArrayComprehension).Body[0], MustParseExpr("y = data.comprehensions.f(1)[0]")) + + // Ignore vars assigned via `some x in xs`. + mod14 := c.Modules["someinassign"] + someInAssignCall := mod14.Rules[2].Body[0].Terms.(*SomeDecl).Symbols[0].Value.(Call) + assertTermEqual(t, someInAssignCall[1], VarTerm("x")) + collectionLastElem := someInAssignCall[2].Value.(*Array).Get(IntNumberTerm(2)) + assertTermEqual(t, collectionLastElem, MustParseTerm("data.someinassign.y")) + + // Ignore key and val vars assigned via `some k, v in xs`. + mod15 := c.Modules["someinassignwithkey"] + someInAssignCall = mod15.Rules[2].Body[0].Terms.(*SomeDecl).Symbols[0].Value.(Call) + assertTermEqual(t, someInAssignCall[1], VarTerm("k")) + assertTermEqual(t, someInAssignCall[2], VarTerm("v")) + collectionLastElem = someInAssignCall[3].Value.(*Array).Get(IntNumberTerm(2)) + assertTermEqual(t, collectionLastElem, MustParseTerm("data.someinassignwithkey.y")) + + mod16 := c.Modules["everykw"] + everyExpr := mod16.Rules[len(mod16.Rules)-1].Body[0].Terms.(*Every) + assertTermEqual(t, everyExpr.Body[0].Terms.(*Term), MustParseTerm("data.everykw.nums[x]")) + assertTermEqual(t, everyExpr.Domain, MustParseTerm("data.everykw.xs")) + + // 'x' is not resolved + assertTermEqual(t, everyExpr.Value, VarTerm("x")) + gt10 := MustParseExpr("x > 10") + gt10.Index++ // TODO(sr): why? + assertExprEqual(t, everyExpr.Body[1], gt10) + + // head refs are kept as-is, but their bodies are replaced. 
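+ // (Editor's note: concretely, the head ref this.is.dotted is not expanded
+ // to a data path, while the body ref this_is_not resolves to
+ // data.heads_with_dots.this_is_not.)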
+ mod := c.Modules["heads_with_dots"] + rule := mod.Rules[1] + body := rule.Body[0].Terms.(*Term) + assertTermEqual(t, body, MustParseTerm("data.heads_with_dots.this_is_not")) + if act, exp := rule.Head.Ref(), MustParseRef("this.is.dotted"); act.Compare(exp) != 0 { + t.Errorf("expected %v to match %v", act, exp) + } +} + +func TestCompilerResolveErrors(t *testing.T) { + + c := NewCompiler() + c.Modules = map[string]*Module{ + "shadow-globals": MustParseModule(` + package shadow_globals + + f([input]) { true } + `), + } + + compileStages(c, c.resolveAllRefs) + + expected := []string{ + `args must not shadow input`, + } + + assertCompilerErrorStrings(t, c, expected) +} + +func TestCompilerRewriteTermsInHead(t *testing.T) { + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + + tests := []struct { + note string + mod *Module + exp *Rule + }{ + { + note: "imports", + mod: MustParseModule(`package head +import data.doc1 as bar +import data.doc2 as corge +import input.x.y.foo +import input.qux as baz + +p[foo[bar[i]]] = {"baz": baz, "corge": corge} { true } +`), + exp: MustParseRule(`p[__local0__] = __local1__ { true; __local0__ = input.x.y.foo[data.doc1[i]]; __local1__ = {"baz": input.qux, "corge": data.doc2} }`), + }, + { + note: "array comprehension value", + mod: MustParseModule(`package head +q = [true | true] { true } +`), + exp: MustParseRule(`q = __local0__ { true; __local0__ = [true | true] }`), + }, + { + note: "array comprehension value in else head", + mod: MustParseModule(`package head +q { + false +} else = [true | true] { + true +} +`), + exp: MustParseRule(`q = true { false } else = __local0__ { true; __local0__ = [true | true] }`), + }, + { + note: "array comprehension value in head (comprehension-local var)", + mod: MustParseModule(`package head +q = [a | a := true] { + false +} else = [a | a := true] { + true +} +`), + exp: MustParseRule(`q = __local2__ { false; __local2__ = [__local0__ | __local0__ = true] } else = __local3__ { true; __local3__ = [__local1__ | __local1__ = true] }`), + }, + { + note: "array comprehension value in function head (comprehension-local var)", + mod: MustParseModule(`package head +f(x) = [a | a := true] { + false +} else = [a | a := true] { + true +} +`), + exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = [__local1__ | __local1__ = true] } else = __local4__ { true; __local4__ = [__local2__ | __local2__ = true] }`), + }, + { + note: "array comprehension value in else-func head (reused arg rewrite)", + mod: MustParseModule(`package head +f(x, y) = [x | y] { + false +} else = [x | y] { + true +} +`), + exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = [__local0__ | __local1__] } else = __local3__ { true; __local3__ = [__local0__ | __local1__] }`), + }, + { + note: "object comprehension value", + mod: MustParseModule(`package head +r = {"true": true | true} { true } +`), + exp: MustParseRule(`r = __local0__ { true; __local0__ = {"true": true | true} }`), + }, + { + note: "object comprehension value in else head", + mod: MustParseModule(`package head +q { + false +} else = {"true": true | true} { + true +} +`), + exp: MustParseRule(`q = true { false } else = __local0__ { true; __local0__ = {"true": true | true} }`), + }, + { + note: "object comprehension value in head (comprehension-local var)", + mod: MustParseModule(`package head +q = {"a": a | a := true} { + false +} else = {"a": a | a := true} { + true +} +`), + exp: MustParseRule(`q = __local2__ { false; __local2__ = {"a": 
__local0__ | __local0__ = true} } else = __local3__ { true; __local3__ = {"a": __local1__ | __local1__ = true} }`), + }, + { + note: "object comprehension value in function head (comprehension-local var)", + mod: MustParseModule(`package head +f(x) = {"a": a | a := true} { + false +} else = {"a": a | a := true} { + true +} +`), + exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = {"a": __local1__ | __local1__ = true} } else = __local4__ { true; __local4__ = {"a": __local2__ | __local2__ = true} }`), + }, + { + note: "object comprehension value in else-func head (reused arg rewrite)", + mod: MustParseModule(`package head +f(x, y) = {x: y | true} { + false +} else = {x: y | true} { + true +} +`), + exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = {__local0__: __local1__ | true} } else = __local3__ { true; __local3__ = {__local0__: __local1__ | true} }`), + }, + { + note: "set comprehension value", + mod: MustParseModule(`package head +s = {true | true} { true } +`), + exp: MustParseRule(`s = __local0__ { true; __local0__ = {true | true} }`), + }, + { + note: "set comprehension value in else head", + mod: MustParseModule(`package head +q = {false | false} { + false +} else = {true | true} { + true +} +`), + exp: MustParseRule(`q = __local0__ { false; __local0__ = {false | false} } else = __local1__ { true; __local1__ = {true | true} }`), + }, + { + note: "set comprehension value in head (comprehension-local var)", + mod: MustParseModule(`package head +q = {a | a := true} { + false +} else = {a | a := true} { + true +} +`), + exp: MustParseRule(`q = __local2__ { false; __local2__ = {__local0__ | __local0__ = true} } else = __local3__ { true; __local3__ = {__local1__ | __local1__ = true} }`), + }, + { + note: "set comprehension value in function head (comprehension-local var)", + mod: MustParseModule(`package head +f(x) = {a | a := true} { + false +} else = {a | a := true} { + true +} +`), + exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = {__local1__ | __local1__ = true} } else = __local4__ { true; __local4__ = {__local2__ | __local2__ = true} }`), + }, + { + note: "set comprehension value in else-func head (reused arg rewrite)", + mod: MustParseModule(`package head +f(x, y) = {x | y} { + false +} else = {x | y} { + true +} +`), + exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = {__local0__ | __local1__} } else = __local3__ { true; __local3__ = {__local0__ | __local1__} }`), + }, + { + note: "import in else value", + mod: MustParseModule(`package head +import input.qux as baz +elsekw { + false +} else = baz { + true +} +`), + exp: MustParseRule(`elsekw { false } else = __local0__ { true; __local0__ = input.qux }`), + }, + { + note: "import ref in last ref head term", + mod: MustParseModule(`package head +import data.doc1 as bar +x.y.z[bar[i]] = true +`), + exp: MustParseRule(`x.y.z[__local0__] = true { true; __local0__ = data.doc1[i] }`), + }, + { + note: "import ref in multi-value ref rule", + mod: MustParseModule(`package head +import future.keywords.if +import future.keywords.contains +import data.doc1 as bar +x.y.w contains bar[i] if true +`), + exp: func() *Rule { + exp, _ := ParseRuleWithOpts(`x.y.w contains __local0__ if {true; __local0__ = data.doc1[i] }`, popts) + return exp + }(), + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + c.Modules["head"] = tc.mod + compileStages(c, c.rewriteRefsInHead) + assertNotFailed(t, c) + act := 
c.Modules["head"].Rules[0] + assertRulesEqual(t, act, tc.exp) + }) + } +} + +func TestCompilerRefHeadsNeedCapability(t *testing.T) { + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + for _, tc := range []struct { + note string + mod *Module + err string + }{ + { + note: "one-dot ref, single-value rule, short+compat", + mod: MustParseModule(`package t +p[1] = 2`), + }, + { + note: "function, short", + mod: MustParseModule(`package t +p(1)`), + }, + { + note: "function", + mod: MustParseModuleWithOpts(`package t +p(1) if true`, popts), + }, + { + note: "function with value", + mod: MustParseModuleWithOpts(`package t +p(1) = 2 if true`, popts), + }, + { + note: "function with value", + mod: MustParseModule(`package t +p(1) = 2`), + }, + { + note: "one-dot ref, single-value rule, compat", + mod: MustParseModuleWithOpts(`package t +p[3] = 4 if true`, popts), + }, + { + note: "multi-value non-ref head", + mod: MustParseModuleWithOpts(`package t +p contains 1 if true`, popts), + }, + { // NOTE(sr): this was previously forbidden because we need the `if` for disambiguation + note: "one-dot ref head", + mod: MustParseModuleWithOpts(`package t +p[1] if true`, popts), + err: "rule heads with refs are not supported: p[1]", + }, + { + note: "single-value ref rule", + mod: MustParseModuleWithOpts(`package t +a.b.c[x] if x := input`, popts), + err: "rule heads with refs are not supported: a.b.c[x]", + }, + { + note: "ref head function", + mod: MustParseModuleWithOpts(`package t +a.b.c(x) if x == input`, popts), + err: "rule heads with refs are not supported: a.b.c", + }, + { + note: "multi-value ref rule", + mod: MustParseModuleWithOpts(`package t +a.b.c contains x if x := input`, popts), + err: "rule heads with refs are not supported: a.b.c", + }, + } { + t.Run(tc.note, func(t *testing.T) { + caps, err := LoadCapabilitiesVersion("v0.44.0") + if err != nil { + t.Fatal(err) + } + c := NewCompiler().WithCapabilities(caps) + c.Modules["test"] = tc.mod + compileStages(c, c.rewriteRefsInHead) + if tc.err != "" { + assertErrorWithMessage(t, c.Errors, tc.err) + } else { + assertNotFailed(t, c) + } + }) + } +} + +func TestCompilerRewriteRegoMetadataCalls(t *testing.T) { + tests := []struct { + note string + module string + exp string + }{ + { + note: "rego.metadata called, no metadata", + module: `package test + +p { + rego.metadata.chain()[0].path == ["test", "p"] + rego.metadata.rule() == {} +}`, + exp: `package test + +p = true { + __local2__ = [{"path": ["test", "p"]}] + __local3__ = {} + __local0__ = __local2__ + equal(__local0__[0].path, ["test", "p"]) + __local1__ = __local3__ + equal(__local1__, {}) +}`, + }, + { + note: "rego.metadata called, no output var, no metadata", + module: `package test + +p { + rego.metadata.chain() + rego.metadata.rule() +}`, + exp: `package test + +p = true { + __local0__ = [{"path": ["test", "p"]}] + __local1__ = {} + __local0__ + __local1__ +}`, + }, + { + note: "rego.metadata called, with metadata", + module: `# METADATA +# description: A test package +package test + +# METADATA +# title: My P Rule +p { + rego.metadata.chain()[0].title == "My P Rule" + rego.metadata.chain()[1].description == "A test package" +} + +# METADATA +# title: My Other P Rule +p { + rego.metadata.rule().title == "My Other P Rule" +}`, + exp: `# METADATA +# {"scope":"package","description":"A test package"} +package test + +# METADATA +# {"scope":"rule","title":"My P Rule"} +p = true { + __local3__ = [ + {"annotations": {"scope": "rule", "title": "My P Rule"}, "path": 
["test", "p"]}, + {"annotations": {"description": "A test package", "scope": "package"}, "path": ["test"]} + ] + __local0__ = __local3__ + equal(__local0__[0].title, "My P Rule") + __local1__ = __local3__ + equal(__local1__[1].description, "A test package") +} + +# METADATA +# {"scope":"rule","title":"My Other P Rule"} +p = true { + __local4__ = {"scope": "rule", "title": "My Other P Rule"} + __local2__ = __local4__ + equal(__local2__.title, "My Other P Rule") +}`, + }, + { + note: "rego.metadata referenced multiple times", + module: `# METADATA +# description: TEST +package test + +p { + rego.metadata.chain()[0].path == ["test", "p"] + rego.metadata.chain()[1].path == ["test"] +}`, + exp: `# METADATA +# {"scope":"package","description":"TEST"} +package test + +p = true { + __local2__ = [ + {"path": ["test", "p"]}, + {"annotations": {"description": "TEST", "scope": "package"}, "path": ["test"]} + ] + __local0__ = __local2__ + equal(__local0__[0].path, ["test", "p"]) + __local1__ = __local2__ + equal(__local1__[1].path, ["test"]) }`, + }, + { + note: "rego.metadata return value", + module: `package test + +p := rego.metadata.chain()`, + exp: `package test + +p := __local0__ { + __local1__ = [{"path": ["test", "p"]}] + true + __local0__ = __local1__ +}`, + }, + { + note: "rego.metadata argument in function call", + module: `package test + +p { + q(rego.metadata.chain()) +} + +q(s) { + s == ["test", "p"] +}`, + exp: `package test + +p = true { + __local2__ = [{"path": ["test", "p"]}] + __local1__ = __local2__ + data.test.q(__local1__) +} + +q(__local0__) = true { + equal(__local0__, ["test", "p"]) +}`, + }, + { + note: "rego.metadata used in array comprehension", + module: `package test + +p = [x | x := rego.metadata.chain()]`, + exp: `package test + +p = [__local0__ | __local1__ = __local2__; __local0__ = __local1__] { + __local2__ = [{"path": ["test", "p"]}] + true +}`, + }, + { + note: "rego.metadata used in nested array comprehension", + module: `package test + +p { + y := [x | x := rego.metadata.chain()] + y[0].path == ["test", "p"] +}`, + exp: `package test + +p = true { + __local3__ = [{"path": ["test", "p"]}]; + __local1__ = [__local0__ | __local2__ = __local3__; __local0__ = __local2__]; + equal(__local1__[0].path, ["test", "p"]) +}`, + }, + { + note: "rego.metadata used in set comprehension", + module: `package test + +p = {x | x := rego.metadata.chain()}`, + exp: `package test + +p = {__local0__ | __local1__ = __local2__; __local0__ = __local1__} { + __local2__ = [{"path": ["test", "p"]}] + true +}`, + }, + { + note: "rego.metadata used in nested set comprehension", + module: `package test + +p { + y := {x | x := rego.metadata.chain()} + y[0].path == ["test", "p"] +}`, + exp: `package test + +p = true { + __local3__ = [{"path": ["test", "p"]}] + __local1__ = {__local0__ | __local2__ = __local3__; __local0__ = __local2__} + equal(__local1__[0].path, ["test", "p"]) +}`, + }, + { + note: "rego.metadata used in object comprehension", + module: `package test + +p = {i: x | x := rego.metadata.chain()[i]}`, + exp: `package test + +p = {i: __local0__ | __local1__ = __local2__; __local0__ = __local1__[i]} { + __local2__ = [{"path": ["test", "p"]}] + true +}`, + }, + { + note: "rego.metadata used in nested object comprehension", + module: `package test + +p { + y := {i: x | x := rego.metadata.chain()[i]} + y[0].path == ["test", "p"] +}`, + exp: `package test + +p = true { + __local3__ = [{"path": ["test", "p"]}] + __local1__ = {i: __local0__ | __local2__ = __local3__; __local0__ = 
__local2__[i]} + equal(__local1__[0].path, ["test", "p"]) +}`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + c.Modules = map[string]*Module{ + "test.rego": MustParseModule(tc.module), + } + compileStages(c, c.rewriteRegoMetadataCalls) + assertNotFailed(t, c) + + result := c.Modules["test.rego"] + exp := MustParseModuleWithOpts(tc.exp, ParserOptions{ProcessAnnotation: true}) + + if result.Compare(exp) != 0 { + t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result) + } + }) + } +} + +func TestCompilerOverridingSelfCalls(t *testing.T) { + c := NewCompiler() + c.Modules = map[string]*Module{ + "self.rego": MustParseModule(`package self.metadata + +chain(x) = "foo" +rule := "bar"`), + "test.rego": MustParseModule(`package test +import data.self + +p := self.metadata.chain(42) +q := self.metadata.rule`), + } + + compileStages(c, nil) + assertNotFailed(t, c) +} + +func TestCompilerRewriteLocalAssignments(t *testing.T) { + + tests := []struct { + module string + exp interface{} + expRewrittenMap map[Var]Var + }{ + { + module: ` + package test + body { a := 1; a > 0 } + `, + exp: ` + package test + body = true { __local0__ = 1; gt(__local0__, 0) } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + }, + }, + { + module: ` + package test + head_vars(a) = b { b := a } + `, + exp: ` + package test + head_vars(__local0__) = __local1__ { __local1__ = __local0__ } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + Var("__local1__"): Var("b"), + }, + }, + { + module: ` + package test + head_key[a] { a := 1 } + `, + exp: ` + package test + head_key[__local0__] { __local0__ = 1 } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + }, + }, + { + module: ` + package test + head_unsafe_var[a] { some a } + `, + exp: ` + package test + head_unsafe_var[__local0__] { true } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + }, + }, + { + module: ` + package test + p = {1,2,3} + x = 4 + head_nested[p[x]] { + some x + }`, + exp: ` + package test + p = {1,2,3} + x = 4 + head_nested[data.test.p[__local0__]] + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + }, + }, + { + module: ` + package test + p = {1,2} + head_closure_nested[p[x]] { + y = [true | some x; x = 1] + } + `, + exp: ` + package test + p = {1,2} + head_closure_nested[data.test.p[x]] { + y = [true | __local0__ = 1] + } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + }, + }, + { + module: ` + package test + nested { + a := [1,2,3] + x := [true | a[i] > 1] + } + `, + exp: ` + package test + nested = true { __local0__ = [1, 2, 3]; __local1__ = [true | gt(__local0__[i], 1)] } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + Var("__local1__"): Var("x"), + }, + }, + { + module: ` + package test + x = 2 + shadow_globals[x] { x := 1 } + `, + exp: ` + package test + x = 2 { true } + shadow_globals[__local0__] { __local0__ = 1 } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + }, + }, + { + module: ` + package test + shadow_rule[shadow_rule] { shadow_rule := 1 } + `, + exp: ` + package test + shadow_rule[__local0__] { __local0__ = 1 } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("shadow_rule"), + }, + }, + { + module: ` + package test + shadow_roots_1 { data := 1; input := 2; input > data } + `, + exp: ` + package test + shadow_roots_1 = true { __local0__ = 1; __local1__ = 2; gt(__local1__, __local0__) } + `, + expRewrittenMap: 
map[Var]Var{ + Var("__local0__"): Var("data"), + Var("__local1__"): Var("input"), + }, + }, + { + module: ` + package test + shadow_roots_2 { input := {"a": 1}; input.a > 0 } + `, + exp: ` + package test + shadow_roots_2 = true { __local0__ = {"a": 1}; gt(__local0__.a, 0) } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("input"), + }, + }, + { + module: ` + package test + skip_with_target { a := 1; input := 2; data.p with input as a } + `, + exp: ` + package test + skip_with_target = true { __local0__ = 1; __local1__ = 2; data.p with input as __local0__ } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + Var("__local1__"): Var("input"), + }, + }, + { + module: ` + package test + shadow_comprehensions { + a := 1 + [true | a := 2; b := 1] + b := 2 + } + `, + exp: ` + package test + shadow_comprehensions = true { __local0__ = 1; [true | __local1__ = 2; __local2__ = 1]; __local3__ = 2 } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + Var("__local1__"): Var("a"), + Var("__local2__"): Var("b"), + Var("__local3__"): Var("b"), + }, + }, + { + module: ` + package test + scoping { + [true | a := 1] + [true | a := 2] + } + `, + exp: ` + package test + scoping = true { [true | __local0__ = 1]; [true | __local1__ = 2] } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + Var("__local1__"): Var("a"), + }, + }, + { + module: ` + package test + object_keys { + {k: v1, "k2": v2} := {"foo": 1, "k2": 2} + } + `, + exp: ` + package test + object_keys = true { {"k2": __local0__, k: __local1__} = {"foo": 1, "k2": 2} } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("v2"), + Var("__local1__"): Var("v1"), + }, + }, + { + module: ` + package test + head_array_comprehensions = [[x] | x := 1] + head_set_comprehensions = {[x] | x := 1} + head_object_comprehensions = {k: [x] | k := "foo"; x := 1} + `, + exp: ` + package test + head_array_comprehensions = [[__local0__] | __local0__ = 1] { true } + head_set_comprehensions = {[__local1__] | __local1__ = 1} { true } + head_object_comprehensions = {__local2__: [__local3__] | __local2__ = "foo"; __local3__ = 1} { true } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + Var("__local1__"): Var("x"), + Var("__local2__"): Var("k"), + Var("__local3__"): Var("x"), + }, + }, + { + module: ` + package test + rewritten_object_key { + k := "foo" + {k: 1} + } + `, + exp: ` + package test + rewritten_object_key = true { __local0__ = "foo"; {__local0__: 1} } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("k"), + }, + }, + { + module: ` + package test + rewritten_object_key_head[[{k: 1}]] { + k := "foo" + } + `, + exp: ` + package test + rewritten_object_key_head[[{__local0__: 1}]] { __local0__ = "foo" } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("k"), + }, + }, + { + module: ` + package test + rewritten_object_key_head_value = [{k: 1}] { + k := "foo" + } + `, + exp: ` + package test + rewritten_object_key_head_value = [{__local0__: 1}] { __local0__ = "foo" } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("k"), + }, + }, + { + module: ` + package test + skip_with_target_in_assignment { + input := 1 + a := [true | true with input as 2; true with input as 3] + } + `, + exp: ` + package test + skip_with_target_in_assignment = true { __local0__ = 1; __local1__ = [true | true with input as 2; true with input as 3] } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("input"), + Var("__local1__"): Var("a"), + }, + }, + { + 
module: ` + package test + rewrite_with_value_in_assignment { + a := 1 + b := 1 with input as [a] + } + `, + exp: ` + package test + rewrite_with_value_in_assignment = true { __local0__ = 1; __local1__ = 1 with input as [__local0__] } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + Var("__local1__"): Var("b"), + }, + }, + { + module: ` + package test + rewrite_with_value_in_expr { + a := 1 + a > 0 with input as [a] + } + `, + exp: ` + package test + rewrite_with_value_in_expr = true { __local0__ = 1; gt(__local0__, 0) with input as [__local0__] } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + }, + }, + { + module: ` + package test + rewrite_nested_with_value_in_expr { + a := 1 + a > 0 with input as object.union({"a": a}, {"max_a": max([a])}) + } + `, + exp: ` + package test + rewrite_nested_with_value_in_expr = true { __local0__ = 1; gt(__local0__, 0) with input as object.union({"a": __local0__}, {"max_a": max([__local0__])}) } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("a"), + }, + }, + { + module: ` + package test + global = {} + ref_shadowed { + global := {"a": 1} + global.a > 0 + } + `, + exp: ` + package test + global = {} { true } + ref_shadowed = true { __local0__ = {"a": 1}; gt(__local0__.a, 0) } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("global"), + }, + }, + { + module: ` + package test + f(x) = y { + x == 1 + y := 2 + } else = y { + x == 3 + y := 4 + } + `, + // Each "else" rule has a separate rule head and the vars in the + // args will be rewritten. Since we cannot currently redefine the + // args, we must parse the module and then manually update the args. + exp: func() *Module { + module := MustParseModule(` + package test + + f(__local0__) = __local1__ { __local0__ == 1; __local1__ = 2 } else = __local2__ { __local0__ == 3; __local2__ = 4 } + `) + module.Rules[0].Else.Head.Args[0].Value = Var("__local0__") + return module + }, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + Var("__local1__"): Var("y"), + Var("__local2__"): Var("y"), + }, + }, + { + module: ` + package test + f({"x": [x]}) = y { x == 1; y := 2 }`, + exp: ` + package test + + f({"x": [__local0__]}) = __local1__ { __local0__ == 1; __local1__ = 2 }`, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + Var("__local1__"): Var("y"), + }, + }, + { + module: ` + package test + + f(x, [x]) = x { x == 1 } + `, + exp: ` + package test + + f(__local0__, [__local0__]) = __local0__ { __local0__ == 1 } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + }, + }, + { + module: ` + package test + + f(x) = {x[0]: 1} { true } + `, + exp: ` + package test + + f(__local0__) = {__local0__[0]: 1} { true } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + }, + }, + { + module: ` + package test + + f({{t | t := 0}: 1}) { + true + } + `, + exp: ` + package test + + f({{__local0__ | __local0__ = 0}: 1}) { true } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("t"), + }, + }, + { + module: ` + package test + + f({{t | t := 0}}) { + true + } + `, + exp: ` + package test + + f({{__local0__ | __local0__ = 0}}) { true } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("t"), + }, + }, + } + + for i, tc := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + c := NewCompiler() + c.Modules = map[string]*Module{ + "test.rego": MustParseModule(tc.module), + } + compileStages(c, c.rewriteLocalVars) + assertNotFailed(t, c) + result := c.Modules["test.rego"] + var 
exp *Module + switch e := tc.exp.(type) { + case string: + exp = MustParseModule(e) + case func() *Module: + exp = e() + default: + panic("expected value must be string or func() *Module") + } + if result.Compare(exp) != 0 { + t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result) + } + if !reflect.DeepEqual(c.RewrittenVars, tc.expRewrittenMap) { + t.Fatalf("\nExpected Rewritten Vars:\n\n\t%+v\n\nGot:\n\n\t%+v\n\n", tc.expRewrittenMap, c.RewrittenVars) + } + }) + } + +} +func TestRewriteLocalVarDeclarationErrors(t *testing.T) { + + c := NewCompiler() + + c.Modules["test"] = MustParseModule(`package test + + redeclaration { + r1 = 1 + r1 := 2 + r2 := 1 + [b, r2] := [1, 2] + input.path == 1 + input := "foo" + _ := [1 | nested := 1; nested := 2] + } + + negation { + not a := 1 + } + + bad_assign { + null := x + true := x + 4.5 := x + "foo" := x + [true | true] := [] + {true | true} := set() + {"foo": true | true} := {} + x + 1 := 2 + data.foo := 1 + [z, 1] := [1, 2] + } + + arg_redeclared(arg1) { + arg1 := 1 + } + + arg_nested_redeclared({{arg_nested| arg_nested := 1; arg_nested := 2}}) { true } + `) + + compileStages(c, c.rewriteLocalVars) + + expectedErrors := []string{ + "var r1 referenced above", + "var r2 assigned above", + "var input referenced above", + "var nested assigned above", + "arg arg1 redeclared", + "var arg_nested assigned above", + "cannot assign vars inside negated expression", + "cannot assign to ref", + "cannot assign to arraycomprehension", + "cannot assign to setcomprehension", + "cannot assign to objectcomprehension", + "cannot assign to call", + "cannot assign to number", + "cannot assign to number", + "cannot assign to boolean", + "cannot assign to string", + "cannot assign to null", + } + + sort.Strings(expectedErrors) + + result := []string{} + + for i := range c.Errors { + result = append(result, c.Errors[i].Message) + } + + sort.Strings(result) + + if len(expectedErrors) != len(result) { + t.Fatalf("Expected %d errors but got %d:\n\n%v\n\nGot:\n\n%v", len(expectedErrors), len(result), strings.Join(expectedErrors, "\n"), strings.Join(result, "\n")) + } + + for i := range result { + if result[i] != expectedErrors[i] { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", strings.Join(expectedErrors, "\n"), strings.Join(result, "\n")) + } + } +} + +func TestRewriteDeclaredVarsStage(t *testing.T) { + + // Unlike the following test case, this only executes up to the + // RewriteLocalVars stage. This is done so that later stages like + // RewriteDynamics are not executed. 
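+ // (Editor's note: stopping at rewriteLocalVars keeps the output directly
+ // comparable with the plain-parsed expectation: a := {"a": "a"} becomes
+ // __local0__ = {"a": "a"} and later uses of a become __local0__, while the
+ // refs inside the object/set literals are left for RewriteDynamics.)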
+ + tests := []struct { + note string + module string + exp string + }{ + { + note: "object ref key", + module: ` + package test + + p { + a := {"a": "a"} + {a.a: a.a} + } + `, + exp: ` + package test + + p { + __local0__ = {"a": "a"} + {__local0__.a: __local0__.a} + } + `, + }, + { + note: "set ref element", + module: ` + package test + + p { + a := {"a": "a"} + {a.a} + } + `, + exp: ` + package test + + p { + __local0__ = {"a": "a"} + {__local0__.a} + } + `, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + + c := NewCompiler() + + c.Modules = map[string]*Module{ + "test.rego": MustParseModule(tc.module), + } + + compileStages(c, c.rewriteLocalVars) + + exp := MustParseModule(tc.exp) + result := c.Modules["test.rego"] + + if !exp.Equal(result) { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, result) + } + }) + } +} + +func TestRewriteDeclaredVars(t *testing.T) { + tests := []struct { + note string + module string + exp string + wantErr error + }{ + { + note: "rewrite unify", + module: ` + package test + x = 1 + y = 2 + p { some x; input = [x, y] } + `, + exp: ` + package test + x = 1 + y = 2 + p { __local1__ = data.test.y; input = [__local0__, __local1__] } + `, + }, + { + note: "rewrite call", + module: ` + package test + x = [] + y = {} + p { some x; walk(y, [x, y]) } + `, + exp: ` + package test + x = [] + y = {} + p { __local1__ = data.test.y; __local2__ = data.test.y; walk(__local1__, [__local0__, __local2__]) } + `, + }, + { + note: "rewrite term", + module: ` + package test + x = "a" + y = 1 + q[[2, "b"]] + p { some x; q[[y,x]] } + `, + exp: ` + package test + x = "a" + y = 1 + q[[2, "b"]] + p { __local1__ = data.test.y; data.test.q[[__local1__, __local0__]] } + `, + }, + { + note: "single-value rule with ref head", + module: ` + package test + + p.r.q[s] = t { + t := 1 + s := input.foo + } + `, + exp: ` + package test + + p.r.q[__local1__] = __local0__ { + __local0__ = 1 + __local1__ = input.foo + } + `, + }, + { + note: "rewrite some x in xs", + module: ` + package test + import future.keywords.in + xs = ["a", "b", "c"] + p { some x in xs; x == "a" } + `, + exp: ` + package test + xs = ["a", "b", "c"] + p { __local2__ = data.test.xs[__local1__]; __local2__ = "a" } + `, + }, + { + note: "rewrite some k, x in xs", + module: ` + package test + import future.keywords.in + xs = ["a", "b", "c"] + p { some k, x in xs; x == "a"; k == 2 } + `, + exp: ` + package test + xs = ["a", "b", "c"] + p { __local1__ = data.test.xs[__local0__]; __local1__ = "a"; __local0__ = 2 } + `, + }, + { + note: "rewrite some k, x in xs[i]", + module: ` + package test + import future.keywords.in + xs = [["a", "b", "c"], []] + p { + some i + some k, x in xs[i] + x == "a" + k == 2 + } + `, + exp: ` + package test + xs = [["a", "b", "c"], []] + p = true { __local2__ = data.test.xs[__local0__][__local1__]; __local2__ = "a"; __local1__ = 2 } + `, + }, + { + note: "rewrite some k, x in xs[i] with `i` as ref", + module: ` + package test + import future.keywords.in + i = 0 + xs = [["a", "b", "c"], []] + p { + some k, x in xs[i] + x == "a" + k == 2 + } + `, + exp: ` + package test + i = 0 + xs = [["a", "b", "c"], []] + p = true { __local2__ = data.test.i; __local1__ = data.test.xs[__local2__][__local0__]; __local1__ = "a"; __local0__ = 2 } + `, + }, + { + note: "rewrite some: with modifier on domain", + module: ` + package test + p { + some k, x in input with input as [1, 1, 1] + k == 0 + x == 1 + } + `, + exp: ` + package test + p { + __local1__ = input[__local0__] with input as [1, 1, 
1] + __local0__ = 0 + __local1__ = 1 + } + `, + }, + { + note: "rewrite every", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + i = 0 + xs = [1, 2] + k = "foo" + v = "bar" + p { + every k, v in xs { k + v > i } + } + `, + exp: ` + package test + i = 0 + xs = [1, 2] + k = "foo" + v = "bar" + p = true { + __local2__ = data.test.xs + every __local0__, __local1__ in __local2__ { + plus(__local0__, __local1__, __local3__) + __local4__ = data.test.i + gt(__local3__, __local4__) + } + } `, + }, + { + note: "rewrite every: unused key var", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every k, v in [1] { v >= 1 } + } + `, + wantErr: errors.New("declared var k unused"), + }, + { + // NOTE(sr): this would happen when compiling modules twice: + // the first run rewrites every to include a generated key var, + // the second one bails because it's not used. + // Seen in the wild when using `opa test -b` on a bundle that + // used `every`, https://github.com/open-policy-agent/opa/issues/4420 + note: "rewrite every: unused generated key var", + module: ` + package test + + p { + every __local0__, v in [1] { v >= 1 } + } + `, + exp: ` + package test + p = true { + __local3__ = [1] + every __local1__, __local2__ in __local3__ { __local2__ >= 1 } + } + `, + }, + { + note: "rewrite every: unused value var", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every v in [1] { true } + } + `, + wantErr: errors.New("declared var v unused"), + }, + { + note: "rewrite every: wildcard value var, used key", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every k, _ in [1] { k >= 0 } + } + `, + exp: ` + package test + p = true { + __local1__ = [1] + every __local0__, _ in __local1__ { gte(__local0__, 0) } + } + `, + }, + { + note: "rewrite every: wildcard key+value var", // NOTE(sr): may be silly, but valid + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every _, _ in [1] { true } + } + `, + exp: ` + package test + p = true { __local0__ = [1]; every _, _ in __local0__ { true } } + `, + }, + { + note: "rewrite every: declared vars with different scopes", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + some x + x = 10 + every x in [1] { x == 1 } + } + `, + exp: ` + package test + p = true { + __local0__ = 10 + __local3__ = [1] + every __local1__, __local2__ in __local3__ { __local2__ = 1 } + } + `, + }, + { + note: "rewrite every: declared vars used in body", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + some y + y = 10 + every x in [1] { x == y } + } + `, + exp: ` + package test + p = true { + __local0__ = 10 + __local3__ = [1] + every __local1__, __local2__ in __local3__ { + __local2__ = __local0__ + } + } + `, + }, + { + note: "rewrite every: pops declared var stack", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p[x] { + some x + x = 10 + every _ in [1] { true } + } + `, + exp: ` + package test + p[__local0__] { __local0__ = 10; __local2__ = [1]; every __local1__, _ in __local2__ { true } } + `, + }, + { + note: "rewrite every: nested", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + xs := [[1], [2]] + every v in [1] { + every w in xs[v] { + w == 2 + } + } + } + `, + exp: ` + 
package test + p = true { + __local0__ = [[1], [2]] + __local5__ = [1] + every __local1__, __local2__ in __local5__ { + __local6__ = __local0__[__local2__] + every __local3__, __local4__ in __local6__ { + __local4__ = 2 + } + } + } + `, + }, + { + note: "rewrite every: with modifier on domain", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every x in input { x == 1 } with input as [1, 1, 1] + } + `, + exp: ` + package test + p { + __local2__ = input with input as [1, 1, 1] + every __local0__, __local1__ in __local2__ { + __local1__ = 1 + } with input as [1, 1, 1] + } + `, + }, + { + note: "rewrite every: with modifier on domain with declared var", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + xs := [1, 2] + every x in input { x == 1 } with input as xs + } + `, + exp: ` + package test + p { + __local0__ = [1, 2] + __local3__ = input with input as __local0__ + every __local1__, __local2__ in __local3__ { + __local2__ = 1 + } with input as __local0__ + } + `, + }, + { + note: "rewrite every: with modifier on body", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every x in [2] { x == input } with input as 2 + } + `, + exp: ` + package test + p { + __local2__ = [2] with input as 2 + every __local0__, __local1__ in __local2__ { + __local1__ = input + } with input as 2 + } + `, + }, + { + note: "rewrite closures", + module: ` + package test + x = 1 + y = 2 + p { + some x, z + z = 3 + [x | x = 2; y = 2; some z; z = 4] + } + `, + exp: ` + package test + x = 1 + y = 2 + p { + __local1__ = 3 + [__local0__ | __local0__ = 2; data.test.y = 2; __local2__ = 4] + } + `, + }, + { + note: "rewrite head var", + module: ` + package test + x = "a" + y = 1 + z = 2 + p[x] = [y, z] { + some x, z + x = "b" + z = 4 + }`, + exp: ` + package test + x = "a" + y = 1 + z = 2 + p[__local0__] = __local2__ { + __local0__ = "b" + __local1__ = 4; + __local3__ = data.test.y + __local2__ = [__local3__, __local1__] + } + `, + }, + { + note: "rewrite call with root document ref as arg", + module: ` + package test + + p { + f(input, "bar") + } + + f(x, y) { + x[y] + } + `, + exp: ` + package test + + p = true { + __local2__ = input; + data.test.f(__local2__, "bar") + } + + f(__local0__, __local1__) = true { + __local0__[__local1__] + } + `, + }, + { + note: "redeclare err", + module: ` + package test + p { + some x + some x + } + `, + wantErr: errors.New("var x declared above"), + }, + { + note: "redeclare assigned err", + module: ` + package test + p { + x := 1 + some x + } + `, + wantErr: errors.New("var x assigned above"), + }, + { + note: "redeclare reference err", + module: ` + package test + p { + data.q[x] + some x + } + `, + wantErr: errors.New("var x referenced above"), + }, + { + note: "declare unused err", + module: ` + package test + p { + some x + } + `, + wantErr: errors.New("declared var x unused"), + }, + { + note: "declare unsafe err", + module: ` + package test + p[x] { + some x + x == 1 + } + `, + wantErr: errors.New("var x is unsafe"), + }, + { + note: "declare arg err", + module: ` + package test + + f([a]) { + some a + a = 1 + } + `, + wantErr: errors.New("arg a redeclared"), + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + opts := CompileOpts{ParserOptions: ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true}} + compiler, err := CompileModulesWithOpt(map[string]string{"test.rego": tc.module}, opts) + if tc.wantErr 
!= nil { + if err == nil { + t.Fatal("Expected error but got success") + } + if !strings.Contains(err.Error(), tc.wantErr.Error()) { + t.Fatalf("Expected %v but got %v", tc.wantErr, err) + } + } else if err != nil { + t.Fatal(err) + } else { + exp := MustParseModuleWithOpts(tc.exp, opts.ParserOptions) + result := compiler.Modules["test.rego"] + if exp.Compare(result) != 0 { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, result) + } + } + }) + } +} + +func TestCheckUnusedFunctionArgVars(t *testing.T) { + tests := []strictnessTestCase{ + { + note: "one of the two function args is not used - issue 5602 regression test", + module: `package test + func(x, y) { + x = 1 + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "one of the two ref-head function args is not used", + module: `package test + a.b.c.func(x, y) { + x = 1 + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("a.b.c.func(x, y)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "multiple unused argvar in scope - issue 5602 regression test", + module: `package test + func(x, y) { + input.baz = 1 + input.test == "foo" + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument x", + }, + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "some unused argvar in scope - issue 5602 regression test", + module: `package test + func(x, y) { + input.test == "foo" + x = 1 + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "wildcard argvar that's ignored - issue 5602 regression test", + module: `package test + func(x, _) { + input.test == "foo" + x = 1 + }`, + expectedErrors: Errors{}, + }, + { + note: "wildcard argvar that's ignored - issue 5602 regression test", + module: `package test + func(x, _) { + input.test == "foo" + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, _)"), "", 2, 4), + Message: "unused argument x", + }, + }, + }, + { + note: "argvar not used in body but in head - issue 5602 regression test", + module: `package test + func(x) := x { + input.test == "foo" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in body but in head value comprehension", + module: `package test + a := {"foo": 1} + func(x) := { x: v | v := a[x] } { + input.test == "foo" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in body but in else-head value comprehension", + module: `package test + a := {"foo": 1} + func(x) { + input.test == "foo" + } else := { x: v | v := a[x] } { + input.test == "bar" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in body and shadowed in head value comprehension", + module: `package test + a := {"foo": 1} + func(x) := { x: v | x := "foo"; v := a[x] } { + input.test == "foo" + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x) := { x: v | x := \"foo\"; v := a[x] }"), "", 3, 4), + Message: "unused argument x", + }, + }, + }, + { + note: "argvar used in primary body but not in else body", + module: `package test + func(x) { + input.test == x + } else := false { + input.test == 
"foo" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar used in primary body but not in else body (with wildcard)", + module: `package test + func(x, _) { + input.test == x + } else := false { + input.test == "foo" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in primary body but in else body", + module: `package test + func(x) { + input.test == "foo" + } else := false { + input.test == x + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in primary body but in else body (with wildcard)", + module: `package test + func(x, _) { + input.test == "foo" + } else := false { + input.test == x + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar used in primary body but not in implicit else body", + module: `package test + func(x) { + input.test == x + } else := false`, + expectedErrors: Errors{}, + }, + { + note: "argvars usage spread over multiple bodies", + module: `package test + func(x, y, z) { + input.test == x + } else { + input.test == y + } else { + input.test == z + }`, + expectedErrors: Errors{}, + }, + { + note: "argvars usage spread over multiple bodies, missing in first", + module: `package test + func(x, y, z) { + input.test == "foo" + } else { + input.test == y + } else { + input.test == z + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y, z)"), "", 2, 4), + Message: "unused argument x", + }, + }, + }, + { + note: "argvars usage spread over multiple bodies, missing in second", + module: `package test + func(x, y, z) { + input.test == x + } else { + input.test == "bar" + } else { + input.test == z + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y, z)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "argvars usage spread over multiple bodies, missing in third", + module: `package test + func(x, y, z) { + input.test == x + } else { + input.test == y + } else { + input.test == "baz" + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y, z)"), "", 2, 4), + Message: "unused argument z", + }, + }, + }, + } + + t.Helper() + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + compiler := NewCompiler().WithStrict(true) + compiler.Modules = map[string]*Module{ + "test": MustParseModule(tc.module), + } + compileStages(compiler, nil) + + assertErrors(t, compiler.Errors, tc.expectedErrors, true) + }) + } +} + +func TestCompileUnusedAssignedVarsErrorLocations(t *testing.T) { + tests := []strictnessTestCase{ + { + note: "one of the two function args is not used - issue 5662 regression test", + module: `package test + func(x, y) { + x = 1 + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "multiple unused assigned var in scope - issue 5662 regression test", + module: `package test + allow { + input.message == "world" + input.test == "foo" + input.x == "foo" + input.y == "baz" + a := 1 + b := 2 + x := { + "a": a, + "b": "bar", + } + input.z == "baz" + c := 3 + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("b := 2"), "", 8, 5), + Message: "assigned var b unused", + }, + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("x := {\n\t\t\t\t\t\"a\": a,\n\t\t\t\t\t\"b\": \"bar\",\n\t\t\t\t}"), "", 9, 5), + Message: "assigned var x unused", + }, + &Error{ + Code: CompileErr, + 
Location: NewLocation([]byte("c := 3"), "", 14, 5), + Message: "assigned var c unused", + }, + }, + }, + } + + t.Helper() + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + compiler := NewCompiler().WithStrict(true) + compiler.Modules = map[string]*Module{ + "test": MustParseModule(tc.module), + } + compileStages(compiler, nil) + assertErrors(t, compiler.Errors, tc.expectedErrors, true) + }) + } - // Ignore key and val vars assigned via `some k, v in xs`. - mod15 := c.Modules["someinassignwithkey"] - someInAssignCall = mod15.Rules[2].Body[0].Terms.(*SomeDecl).Symbols[0].Value.(Call) - assertTermEqual(t, someInAssignCall[1], VarTerm("k")) - assertTermEqual(t, someInAssignCall[2], VarTerm("v")) - collectionLastElem = someInAssignCall[3].Value.(*Array).Get(IntNumberTerm(2)) - assertTermEqual(t, collectionLastElem, MustParseTerm("data.someinassignwithkey.y")) } -func TestCompilerResolveErrors(t *testing.T) { +func TestCompileUnusedDeclaredVarsErrorLocations(t *testing.T) { + tests := []strictnessTestCase{ + { + note: "simple unused some var - issue 4238 regression test", + module: `package test - c := NewCompiler() - c.Modules = map[string]*Module{ - "shadow-globals": MustParseModule(` - package shadow_globals + foo { + print("Hello world") + some i + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some i"), "", 5, 5), + Message: "declared var i unused", + }, + }, + }, + { + note: "simple unused some vars, 2x rules", + module: `package test - f([input]) { true } - `), + foo { + print("Hello world") + some i + } + + bar { + print("Hello world") + some j + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some i"), "", 5, 5), + Message: "declared var i unused", + }, + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some j"), "", 10, 5), + Message: "declared var j unused", + }, + }, + }, + { + note: "multiple unused some vars", + module: `package test + + x := [1, 1, 1] + foo2 { + print("A") + some a, b, c + some i, j + some k + x[b] == 1 + print("B") + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some a, b, c"), "", 6, 5), + Message: "declared var a unused", + }, + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some a, b, c"), "", 6, 5), + Message: "declared var c unused", + }, + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some i, j"), "", 7, 5), + Message: "declared var i unused", + }, + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some i, j"), "", 7, 5), + Message: "declared var j unused", + }, + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("some k"), "", 8, 5), + Message: "declared var k unused", + }, + }, + }, } - compileStages(c, c.resolveAllRefs) + // This is similar to the logic for runStrictnessTestCase(), but expects + // unconditional compiler errors. 
+ t.Helper() + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + compiler := NewCompiler().WithStrict(true) + compiler.Modules = map[string]*Module{ + "test": MustParseModule(tc.module), + } + compileStages(compiler, nil) - expected := []string{ - `args must not shadow input`, + assertErrors(t, compiler.Errors, tc.expectedErrors, true) + }) } - - assertCompilerErrorStrings(t, c, expected) } -func TestCompilerRewriteTermsInHead(t *testing.T) { +func TestCompileInvalidEqAssignExpr(t *testing.T) { + c := NewCompiler() - c.Modules["head"] = MustParseModule(`package head -import data.doc1 as bar -import data.doc2 as corge -import input.x.y.foo -import input.qux as baz + c.Modules["error"] = MustParseModule(`package errors -p[foo[bar[i]]] = {"baz": baz, "corge": corge} { true } -q = [true | true] { true } -r = {"true": true | true} { true } -s = {true | true} { true } -elsekw { - false -} else = baz { - true -} -`) + p { + # Arity mismatches are caught in the checkUndefinedFuncs check, + # and invalid eq/assign calls are passed along until then. + assign() + assign(1) + eq() + eq(1) + }`) + + var prev func() + checkUndefinedFuncs := reflect.ValueOf(c.checkUndefinedFuncs) + + for _, stage := range c.stages { + if reflect.ValueOf(stage.f).Pointer() == checkUndefinedFuncs.Pointer() { + break + } + prev = stage.f + } - compileStages(c, c.rewriteRefsInHead) + compileStages(c, prev) assertNotFailed(t, c) +} + +func TestCompilerRewriteComprehensionTerm(t *testing.T) { - rule1 := c.Modules["head"].Rules[0] - expected1 := MustParseRule(`p[__local0__] = __local1__ { true; __local0__ = input.x.y.foo[data.doc1[i]]; __local1__ = {"baz": input.qux, "corge": data.doc2} }`) - assertRulesEqual(t, rule1, expected1) + c := NewCompiler() + c.Modules["head"] = MustParseModule(`package head + arr = [[1], [2], [3]] + arr2 = [["a"], ["b"], ["c"]] + arr_comp = [[x[i]] | arr[j] = x] + set_comp = {[x[i]] | arr[j] = x} + obj_comp = {x[i]: x[i] | arr2[j] = x} + `) - rule2 := c.Modules["head"].Rules[1] - expected2 := MustParseRule(`q = __local2__ { true; __local2__ = [true | true] }`) - assertRulesEqual(t, rule2, expected2) + compileStages(c, c.rewriteComprehensionTerms) + assertNotFailed(t, c) - rule3 := c.Modules["head"].Rules[2] - expected3 := MustParseRule(`r = __local3__ { true; __local3__ = {"true": true | true} }`) - assertRulesEqual(t, rule3, expected3) + arrCompRule := c.Modules["head"].Rules[2] + exp1 := MustParseRule(`arr_comp = [__local0__ | data.head.arr[j] = x; __local0__ = [x[i]]] { true }`) + assertRulesEqual(t, arrCompRule, exp1) - rule4 := c.Modules["head"].Rules[3] - expected4 := MustParseRule(`s = __local4__ { true; __local4__ = {true | true} }`) - assertRulesEqual(t, rule4, expected4) + setCompRule := c.Modules["head"].Rules[3] + exp2 := MustParseRule(`set_comp = {__local1__ | data.head.arr[j] = x; __local1__ = [x[i]]} { true }`) + assertRulesEqual(t, setCompRule, exp2) - rule5 := c.Modules["head"].Rules[4] - expected5 := MustParseRule(`elsekw { false } else = __local5__ { true; __local5__ = input.qux }`) - assertRulesEqual(t, rule5, expected5) + objCompRule := c.Modules["head"].Rules[4] + exp3 := MustParseRule(`obj_comp = {__local2__: __local3__ | data.head.arr2[j] = x; __local2__ = x[i]; __local3__ = x[i]} { true }`) + assertRulesEqual(t, objCompRule, exp3) } -func TestCompilerRewriteLocalAssignments(t *testing.T) { - +func TestCompilerRewriteDoubleEq(t *testing.T) { tests := []struct { - module string - exp interface{} - expRewrittenMap map[Var]Var + note string + input string + exp 
string }{ { - module: ` - package test - body { a := 1; a > 0 } - `, - exp: ` - package test - body = true { __local0__ = 1; gt(__local0__, 0) } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - }, - }, - { - module: ` - package test - head_vars(a) = b { b := a } - `, - exp: ` - package test - head_vars(__local0__) = __local1__ { __local1__ = __local0__ } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - Var("__local1__"): Var("b"), - }, - }, - { - module: ` - package test - head_key[a] { a := 1 } - `, - exp: ` - package test - head_key[__local0__] { __local0__ = 1 } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - }, - }, - { - module: ` - package test - head_unsafe_var[a] { some a } - `, - exp: ` - package test - head_unsafe_var[__local0__] { true } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - }, + note: "vars and constants", + input: "p { x = 1; x == 1; y = [1,2,3]; y == [1,2,3] }", + exp: `x = 1; x = 1; y = [1,2,3]; y = [1,2,3]`, }, { - module: ` - package test - p = {1,2,3} - x = 4 - head_nested[p[x]] { - some x - }`, - exp: ` - package test - p = {1,2,3} - x = 4 - head_nested[data.test.p[__local0__]] - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - }, + note: "refs", + input: "p { input.x == data.y }", + exp: `input.x = data.y`, }, { - module: ` - package test - p = {1,2} - head_closure_nested[p[x]] { - y = [true | some x; x = 1] - } - `, - exp: ` - package test - p = {1,2} - head_closure_nested[data.test.p[x]] { - y = [true | __local0__ = 1] - } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - }, + note: "comprehensions", + input: "p { [1|true] == [2|true] }", + exp: `[1|true] = [2|true]`, }, + // TODO(tsandall): improve support for calls so that extra unification step is + // not required. This requires more changes to the compiler as the initial + // stages that rewrite term exprs needs to be updated to handle == differently + // and then other stages need to be reviewed to make sure they can deal with + // nested calls. Alternatively, the compiler could keep track of == exprs that + // have been converted into = and then the safety check would need to be updated. 
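
For illustration, a minimal sketch of the behavior the TODO above describes, using the exported ast.CompileModules helper that is also exercised elsewhere in this file: the == over a call is rewritten into a call that emits into a generated output variable, followed by a separate unification of that variable with the right-hand side. The body shown in the comment is approximate, since later compiler stages may rewrite it further:

	package main

	import (
		"fmt"

		"github.com/open-policy-agent/opa/ast"
	)

	func main() {
		// "count([1,2]) == 2" compiles to a call with a generated output
		// variable plus an extra unification step, roughly:
		//   count([1, 2], __local0__); __local0__ = 2
		compiler, err := ast.CompileModules(map[string]string{
			"test.rego": "package test\np { count([1,2]) == 2 }",
		})
		if err != nil {
			panic(err)
		}
		fmt.Println(compiler.Modules["test.rego"].Rules[0].Body)
	}
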
{ - module: ` - package test - nested { - a := [1,2,3] - x := [true | a[i] > 1] - } - `, - exp: ` - package test - nested = true { __local0__ = [1, 2, 3]; __local1__ = [true | gt(__local0__[i], 1)] } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - Var("__local1__"): Var("x"), - }, + note: "calls", + input: "p { count([1,2]) == 2 }", + exp: `count([1,2], __local0__); __local0__ = 2`, }, { - module: ` - package test - x = 2 - shadow_globals[x] { x := 1 } - `, - exp: ` - package test - x = 2 { true } - shadow_globals[__local0__] { __local0__ = 1 } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - }, + note: "embedded", + input: "p { x = 1; y = [x == 0] }", + exp: `x = 1; equal(x, 0, __local0__); y = [__local0__]`, }, { - module: ` - package test - shadow_rule[shadow_rule] { shadow_rule := 1 } - `, - exp: ` - package test - shadow_rule[__local0__] { __local0__ = 1 } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("shadow_rule"), - }, + note: "embedded in call", + input: `p { x = 0; neq(true, x == 1) }`, + exp: `x = 0; equal(x, 1, __local0__); neq(true, __local0__)`, }, { - module: ` - package test - shadow_roots_1 { data := 1; input := 2; input > data } - `, - exp: ` - package test - shadow_roots_1 = true { __local0__ = 1; __local1__ = 2; gt(__local1__, __local0__) } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("data"), - Var("__local1__"): Var("input"), - }, + note: "comprehension in object key", + input: `p { {{1 | 0 == 0}: 2} }`, + exp: `{{1 | 0 = 0}: 2}`, }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + c.Modules["test"] = MustParseModule("package test\n" + tc.input) + compileStages(c, c.rewriteEquals) + assertNotFailed(t, c) + exp := MustParseBody(tc.exp) + result := c.Modules["test"].Rules[0].Body + if result.Compare(exp) != 0 { + t.Fatalf("\nExp: %v\nGot: %v", exp, result) + } + }) + } +} + +func TestCompilerRewriteDynamicTerms(t *testing.T) { + + fixture := ` + package test + str = "hello" + ` + + tests := []struct { + input string + expected string + }{ + {`arr { [str] }`, `__local0__ = data.test.str; [__local0__]`}, + {`arr2 { [[str]] }`, `__local0__ = data.test.str; [[__local0__]]`}, + {`obj { {"x": str} }`, `__local0__ = data.test.str; {"x": __local0__}`}, + {`obj2 { {"x": {"y": str}} }`, `__local0__ = data.test.str; {"x": {"y": __local0__}}`}, + {`set { {str} }`, `__local0__ = data.test.str; {__local0__}`}, + {`set2 { {{str}} }`, `__local0__ = data.test.str; {{__local0__}}`}, + {`ref { str[str] }`, `__local0__ = data.test.str; data.test.str[__local0__]`}, + {`ref2 { str[str[str]] }`, `__local0__ = data.test.str; __local1__ = data.test.str[__local0__]; data.test.str[__local1__]`}, + {`arr_compr { [1 | [str]] }`, `[1 | __local0__ = data.test.str; [__local0__]]`}, + {`arr_compr2 { [1 | [1 | [str]]] }`, `[1 | [1 | __local0__ = data.test.str; [__local0__]]]`}, + {`set_compr { {1 | [str]} }`, `{1 | __local0__ = data.test.str; [__local0__]}`}, + {`set_compr2 { {1 | {1 | [str]}} }`, `{1 | {1 | __local0__ = data.test.str; [__local0__]}}`}, + {`obj_compr { {"a": "b" | [str]} }`, `{"a": "b" | __local0__ = data.test.str; [__local0__]}`}, + {`obj_compr2 { {"a": "b" | {"a": "b" | [str]}} }`, `{"a": "b" | {"a": "b" | __local0__ = data.test.str; [__local0__]}}`}, + {`equality { str = str }`, `data.test.str = data.test.str`}, + {`equality2 { [str] = [str] }`, `__local0__ = data.test.str; __local1__ = data.test.str; [__local0__] = [__local1__]`}, + {`call { startswith(str, "") 
}`, `__local0__ = data.test.str; startswith(__local0__, "")`}, + {`call2 { count([str], n) }`, `__local0__ = data.test.str; count([__local0__], n)`}, + {`eq_with { [str] = [1] with input as 1 }`, `__local0__ = data.test.str with input as 1; [__local0__] = [1] with input as 1`}, + {`term_with { [[str]] with input as 1 }`, `__local0__ = data.test.str with input as 1; [[__local0__]] with input as 1`}, + {`call_with { count(str) with input as 1 }`, `__local0__ = data.test.str with input as 1; count(__local0__) with input as 1`}, + {`call_func { f(input, "foo") } f(x,y) { x[y] }`, `__local2__ = input; data.test.f(__local2__, "foo")`}, + {`call_func2 { f(input.foo, "foo") } f(x,y) { x[y] }`, `__local2__ = input.foo; data.test.f(__local2__, "foo")`}, + {`every_domain { every _ in str { true } }`, `__local1__ = data.test.str; every __local0__, _ in __local1__ { true }`}, + {`every_domain_call { every _ in numbers.range(1, 10) { true } }`, `numbers.range(1, 10, __local1__); every __local0__, _ in __local1__ { true }`}, + {`every_body { every _ in [] { [str] } }`, + `__local1__ = []; every __local0__, _ in __local1__ { __local2__ = data.test.str; [__local2__] }`}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + c := NewCompiler() + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + module := fixture + tc.input + c.Modules["test"] = MustParseModuleWithOpts(module, opts) + compileStages(c, c.rewriteDynamicTerms) + assertNotFailed(t, c) + expected := MustParseBodyWithOpts(tc.expected, opts) + result := c.Modules["test"].Rules[1].Body + if result.Compare(expected) != 0 { + t.Fatalf("\nExp: %v\nGot: %v", expected, result) + } + }) + } +} + +func TestCompilerRewriteWithValue(t *testing.T) { + fixture := `package test + + arr = ["hello", "goodbye"] + + ` + + tests := []struct { + note string + input string + opts func(*Compiler) *Compiler + expected string + expectedRule *Rule + wantErr error + }{ { - module: ` - package test - shadow_roots_2 { input := {"a": 1}; input.a > 0 } - `, - exp: ` - package test - shadow_roots_2 = true { __local0__ = {"a": 1}; gt(__local0__.a, 0) } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("input"), - }, + note: "nop", + input: `p { true with input as 1 }`, + expected: `p { true with input as 1 }`, }, { - module: ` - package test - skip_with_target { a := 1; input := 2; data.p with input as a } - `, - exp: ` - package test - skip_with_target = true { __local0__ = 1; __local1__ = 2; data.p with input as __local0__ } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - Var("__local1__"): Var("input"), - }, + note: "refs", + input: `p { true with input as arr }`, + expected: `p { __local0__ = data.test.arr; true with input as __local0__ }`, }, { - module: ` - package test - shadow_comprehensions { - a := 1 - [true | a := 2; b := 1] - b := 2 - } - `, - exp: ` - package test - shadow_comprehensions = true { __local0__ = 1; [true | __local1__ = 2; __local2__ = 1]; __local3__ = 2 } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - Var("__local1__"): Var("a"), - Var("__local2__"): Var("b"), - Var("__local3__"): Var("b"), - }, + note: "array comprehension", + input: `p { true with input as [true | true] }`, + expected: `p { __local0__ = [true | true]; true with input as __local0__ }`, }, { - module: ` - package test - scoping { - [true | a := 1] - [true | a := 2] - } - `, - exp: ` - package test - scoping = true { [true | __local0__ = 1]; [true | __local1__ = 2] } - `, - 
expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - Var("__local1__"): Var("a"), - }, + note: "set comprehension", + input: `p { true with input as {true | true} }`, + expected: `p { __local0__ = {true | true}; true with input as __local0__ }`, }, { - module: ` - package test - object_keys { - {k: v1, "k2": v2} := {"foo": 1, "k2": 2} - } - `, - exp: ` - package test - object_keys = true { {"k2": __local0__, k: __local1__} = {"foo": 1, "k2": 2} } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("v2"), - Var("__local1__"): Var("v1"), - }, + note: "object comprehension", + input: `p { true with input as {"k": true | true} }`, + expected: `p { __local0__ = {"k": true | true}; true with input as __local0__ }`, }, { - module: ` - package test - head_array_comprehensions = [[x] | x := 1] - head_set_comprehensions = {[x] | x := 1} - head_object_comprehensions = {k: [x] | k := "foo"; x := 1} - `, - exp: ` - package test - head_array_comprehensions = [[__local0__] | __local0__ = 1] { true } - head_set_comprehensions = {[__local1__] | __local1__ = 1} { true } - head_object_comprehensions = {__local2__: [__local3__] | __local2__ = "foo"; __local3__ = 1} { true } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - Var("__local1__"): Var("x"), - Var("__local2__"): Var("k"), - Var("__local3__"): Var("x"), - }, + note: "comprehension nested", + input: `p { true with input as [true | true with input as arr] }`, + expected: `p { __local0__ = [true | __local1__ = data.test.arr; true with input as __local1__]; true with input as __local0__ }`, }, { - module: ` - package test - rewritten_object_key { - k := "foo" - {k: 1} - } - `, - exp: ` - package test - rewritten_object_key = true { __local0__ = "foo"; {__local0__: 1} } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("k"), - }, + note: "multiple", + input: `p { true with input.a as arr[0] with input.b as arr[1] }`, + expected: `p { __local0__ = data.test.arr[0]; __local1__ = data.test.arr[1]; true with input.a as __local0__ with input.b as __local1__ }`, }, { - module: ` - package test - rewritten_object_key_head[[{k: 1}]] { - k := "foo" - } - `, - exp: ` - package test - rewritten_object_key_head[[{__local0__: 1}]] { __local0__ = "foo" } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("k"), - }, + note: "invalid target", + input: `p { true with foo.q as 1 }`, + wantErr: fmt.Errorf("rego_type_error: with keyword target must reference existing input, data, or a function"), }, - { - module: ` - package test - rewritten_object_key_head_value = [{k: 1}] { - k := "foo" - } - `, - exp: ` - package test - rewritten_object_key_head_value = [{__local0__: 1}] { __local0__ = "foo" } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("k"), - }, + { + note: "built-in function: replaced by (unknown) var", + input: `p { true with time.now_ns as foo }`, + expected: `p { true with time.now_ns as foo }`, // `foo` still a Var here }, { - module: ` - package test - skip_with_target_in_assignment { - input := 1 - a := [true | true with input as 2; true with input as 3] - } - `, - exp: ` - package test - skip_with_target_in_assignment = true { __local0__ = 1; __local1__ = [true | true with input as 2; true with input as 3] } + note: "built-in function: valid, arity 0", + input: ` + p { true with time.now_ns as now } + now() = 1 `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("input"), - Var("__local1__"): Var("a"), - }, + expected: `p { true with time.now_ns as data.test.now }`, }, { - 
module: ` - package test - rewrite_value_in_assignment { - a := 1 - b := 1 with input as [a] - } + note: "built-in function: valid func ref, arity 1", + input: ` + p { true with http.send as mock_http_send } + mock_http_send(_) = { "body": "yay" } `, - exp: ` - package test - rewrite_value_in_assignment = true { __local0__ = 1; __local1__ = 1 with input as [__local0__] } + expected: `p { true with http.send as data.test.mock_http_send }`, + }, + { + note: "built-in function: replaced by value", + input: ` + p { true with http.send as { "body": "yay" } } `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("a"), - Var("__local1__"): Var("b"), - }, + expected: `p { true with http.send as {"body": "yay"} }`, }, { - module: ` - package test - global = {} - ref_shadowed { - global := {"a": 1} - global.a > 0 + note: "built-in function: replaced by var", + input: ` + p { + resp := { "body": "yay" } + true with http.send as resp } `, - exp: ` - package test - global = {} { true } - ref_shadowed = true { __local0__ = {"a": 1}; gt(__local0__.a, 0) } - `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("global"), - }, + expected: `p { __local0__ = {"body": "yay"}; true with http.send as __local0__ }`, }, { - module: ` - package test - f(x) = y { - x == 1 - y := 2 - } else = y { - x == 3 - y := 4 + note: "non-built-in function: replaced by var", + input: ` + p { + resp := true + f(true) with f as resp } + f(false) { true } `, - // Each "else" rule has a separate rule head and the vars in the - // args will be rewritten. Since we cannot currently redefine the - // args, we must parse the module and then manually update the args. - exp: func() *Module { - module := MustParseModule(` - package test - - f(__local0__) = __local1__ { __local0__ == 1; __local1__ = 2 } else = __local3__ { __local2__ == 3; __local3__ = 4 } - `) - module.Rules[0].Else.Head.Args[0].Value = Var("__local2__") - return module - }, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - Var("__local1__"): Var("y"), - Var("__local2__"): Var("x"), - Var("__local3__"): Var("y"), - }, + expected: `p { __local0__ = true; data.test.f(true) with data.test.f as __local0__ }`, }, { - module: ` - package test - f({"x": [x]}) = y { x == 1; y := 2 }`, - exp: ` - package test - - f({"x": [__local0__]}) = __local1__ { __local0__ == 1; __local1__ = 2 }`, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - Var("__local1__"): Var("y"), - }, + note: "built-in function: replaced by comprehension", + input: ` + p { true with http.send as { x: true | x := ["a", "b"][_] } } + `, + expected: `p { __local2__ = {__local0__: true | __local1__ = ["a", "b"]; __local0__ = __local1__[_]}; true with http.send as __local2__ }`, }, { - module: ` - package test - - f(x, [x]) = x { x == 1 } - `, - exp: ` - package test - - f(__local0__, [__local0__]) = __local0__ { __local0__ == 1 } + note: "built-in function: replaced by ref", + input: ` + p { true with http.send as resp } + resp := { "body": "yay" } `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - }, + expected: `p { true with http.send as data.test.resp }`, }, { - module: ` - package test - - f(x) = {x[0]: 1} { true } - `, - exp: ` - package test - - f(__local0__) = {__local0__[0]: 1} { true } + note: "built-in function: replaced by another built-in (ref)", + input: ` + p { true with http.send as object.union_n } `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("x"), - }, + expected: `p { true with http.send as object.union_n }`, }, { - module: ` 
- package test - - f({{t | t := 0}: 1}) { - true - } - `, - exp: ` - package test - - f({{__local0__ | __local0__ = 0}: 1}) { true } + note: "built-in function: replaced by another built-in (simple)", + input: ` + p { true with http.send as count } `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("t"), - }, + expectedRule: func() *Rule { + r := MustParseRule(`p { true with http.send as count }`) + r.Body[0].With[0].Value.Value = Ref([]*Term{VarTerm("count")}) + return r + }(), }, { - module: ` - package test - - f({{t | t := 0}}) { - true - } + note: "built-in function: replaced by another built-in that's marked unsafe", + input: ` + q := is_object({"url": "https://httpbin.org", "method": "GET"}) + p { q with is_object as http.send } `, - exp: ` - package test - - f({{__local0__ | __local0__ = 0}}) { true } + opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) }, + wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""), + }, + { + note: "non-built-in function: replaced by another built-in that's marked unsafe", + input: ` + r(_) = {} + q := r({"url": "https://httpbin.org", "method": "GET"}) + p { + q with r as http.send + }`, + opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) }, + wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""), + }, + { + note: "built-in function: valid, arity 1, non-compound name", + input: ` + p { concat("/", input) with concat as mock_concat } + mock_concat(_, _) = "foo/bar" `, - expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("t"), - }, + expectedRule: func() *Rule { + r := MustParseRule(`p { concat("/", input) with concat as data.test.mock_concat }`) + r.Body[0].With[0].Target.Value = Ref([]*Term{VarTerm("concat")}) + return r + }(), }, } - for i, tc := range tests { - t.Run(fmt.Sprint(i), func(t *testing.T) { + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { c := NewCompiler() - c.Modules = map[string]*Module{ - "test.rego": MustParseModule(tc.module), - } - compileStages(c, c.rewriteLocalVars) - assertNotFailed(t, c) - result := c.Modules["test.rego"] - var exp *Module - switch e := tc.exp.(type) { - case string: - exp = MustParseModule(e) - case func() *Module: - exp = e() - default: - panic("expected value must be string or func() *Module") - } - if result.Compare(exp) != 0 { - t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result) + if tc.opts != nil { + c = tc.opts(c) } - if !reflect.DeepEqual(c.RewrittenVars, tc.expRewrittenMap) { - t.Fatalf("\nExpected Rewritten Vars:\n\n\t%+v\n\nGot:\n\n\t%+v\n\n", tc.expRewrittenMap, c.RewrittenVars) + module := fixture + tc.input + c.Modules["test"] = MustParseModule(module) + compileStages(c, c.rewriteWithModifiers) + if tc.wantErr == nil { + assertNotFailed(t, c) + expected := tc.expectedRule + if expected == nil { + expected = MustParseRule(tc.expected) + } + result := c.Modules["test"].Rules[1] + if result.Compare(expected) != 0 { + t.Fatalf("\nExp: %v\nGot: %v", expected, result) + } + } else { + assertCompilerErrorStrings(t, c, []string{tc.wantErr.Error()}) } }) } +} -} -func TestRewriteLocalVarDeclarationErrors(t *testing.T) { - - c := NewCompiler() - - c.Modules["test"] = MustParseModule(`package test - - redeclaration { - r1 = 1 - r1 := 2 - r2 := 1 - [b, r2] := [1, 2] - input.path == 1 - input := "foo" - _ := [1 | nested := 1; 
nested := 2] - } - - negation { - not a := 1 - } - - bad_assign { - null := x - true := x - 4.5 := x - "foo" := x - [true | true] := [] - {true | true} := set() - {"foo": true | true} := {} - x + 1 := 2 - data.foo := 1 - [z, 1] := [1, 2] - } - - arg_redeclared(arg1) { - arg1 := 1 - } - - arg_nested_redeclared({{arg_nested| arg_nested := 1; arg_nested := 2}}) { true } - `) - - compileStages(c, c.rewriteLocalVars) - - expectedErrors := []string{ - "var r1 referenced above", - "var r2 assigned above", - "var input referenced above", - "var nested assigned above", - "arg arg1 redeclared", - "var arg_nested assigned above", - "cannot assign vars inside negated expression", - "cannot assign to ref", - "cannot assign to arraycomprehension", - "cannot assign to setcomprehension", - "cannot assign to objectcomprehension", - "cannot assign to call", - "cannot assign to number", - "cannot assign to number", - "cannot assign to boolean", - "cannot assign to string", - "cannot assign to null", - } - - sort.Strings(expectedErrors) - - result := []string{} - - for i := range c.Errors { - result = append(result, c.Errors[i].Message) - } - - sort.Strings(result) - - if len(expectedErrors) != len(result) { - t.Fatalf("Expected %d errors but got %d:\n\n%v\n\nGot:\n\n%v", len(expectedErrors), len(result), strings.Join(expectedErrors, "\n"), strings.Join(result, "\n")) - } - - for i := range result { - if result[i] != expectedErrors[i] { - t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", strings.Join(expectedErrors, "\n"), strings.Join(result, "\n")) - } - } -} - -func TestRewriteDeclaredVarsStage(t *testing.T) { - - // Unlike the following test case, this only executes up to the - // RewriteLocalVars stage. This is done so that later stages like - // RewriteDynamics are not executed. 
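
The print-handling tests that follow exercise two compiler modes: with print statements disabled, print() calls are erased (and a body emptied by erasure is replaced with true); with them enabled, each operand is captured in a set comprehension and passed to internal.print. A minimal sketch of both modes through the exported API, with the expected rule shapes (taken from the test cases below) shown in comments:

	package main

	import (
		"fmt"

		"github.com/open-policy-agent/opa/ast"
	)

	func main() {
		for _, enabled := range []bool{false, true} {
			c := ast.NewCompiler().WithEnablePrintStatements(enabled)
			c.Compile(map[string]*ast.Module{
				"test.rego": ast.MustParseModule("package test\np { print(1) }"),
			})
			if c.Failed() {
				panic(c.Errors)
			}
			// enabled=false: p = true { true } -- the print call is erased.
			// enabled=true:  p = true { __local1__ = {__local0__ | __local0__ = 1};
			//                           internal.print([__local1__]) }
			fmt.Println(enabled, c.Modules["test.rego"].Rules[0])
		}
	}
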
+func TestCompilerRewritePrintCallsErasure(t *testing.T) { - tests := []struct { + cases := []struct { note string module string exp string }{ { - note: "object ref key", - module: ` - package test + note: "no-op", + module: `package test + p { true }`, + exp: `package test + p { true }`, + }, + { + note: "replace empty body with true", + module: `package test - p { - a := {"a": "a"} - {a.a: a.a} - } + p { print(1) } `, - exp: ` - package test + exp: `package test - p { - __local0__ = {"a": "a"} - {__local0__.a: __local0__.a} - } + p { true } `, + }, + { + note: "rule body", + module: `package test + + p { false; print(1) } `, + exp: `package test + + p { false } `, }, { - note: "set ref element", - module: ` - package test + note: "set comprehension body", + module: `package test - p { - a := {"a": "a"} - {a.a} - } + p { {1 | false; print(1)} } `, - exp: ` - package test + exp: `package test - p { - __local0__ = {"a": "a"} - {__local0__.a} - } + p { {1 | false} } `, + }, + { + note: "array comprehension body", + module: `package test + + p { [1 | false; print(1)] } `, + exp: `package test + + p { [1 | false] } `, }, - } + { + note: "object comprehension body", + module: `package test - for _, tc := range tests { - t.Run(tc.note, func(t *testing.T) { + p { {"x": 1 | false; print(1)} } + `, + exp: `package test - c := NewCompiler() + p { {"x": 1 | false} } `, + }, + { + note: "every body", + module: `package test - c.Modules = map[string]*Module{ - "test.rego": MustParseModule(tc.module), - } + p { every _ in [] { false; print(1) } } + `, + exp: `package test - compileStages(c, c.rewriteLocalVars) + p = true { __local1__ = []; every __local0__, _ in __local1__ { false } }`, + }, + { + note: "in head", + module: `package test - exp := MustParseModule(tc.exp) - result := c.Modules["test.rego"] + p = {1 | print("x")}`, + exp: `package test - if !exp.Equal(result) { - t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, result) + p = __local0__ { true; __local0__ = {1 | true} }`, + }, + } + + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler().WithEnablePrintStatements(false) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c.Compile(map[string]*Module{ + "test.rego": MustParseModuleWithOpts(tc.module, opts), + }) + if c.Failed() { + t.Fatal(c.Errors) + } + exp := MustParseModuleWithOpts(tc.exp, opts) + if !exp.Equal(c.Modules["test.rego"]) { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"]) } }) } } -func TestRewriteDeclaredVars(t *testing.T) { - tests := []struct { - note string - module string - exp string - wantErr error +func TestCompilerRewritePrintCallsErrors(t *testing.T) { + cases := []struct { + note string + module string + exp error }{ { - note: "rewrite unify", - module: ` - package test - x = 1 - y = 2 - p { some x; input = [x, y] } - `, - exp: ` - package test - x = 1 - y = 2 - p { __local1__ = data.test.y; input = [__local0__, __local1__] } - `, - }, - { - note: "rewrite call", - module: ` - package test - x = [] - y = {} - p { some x; walk(y, [x, y]) } - `, - exp: ` - package test - x = [] - y = {} - p { __local1__ = data.test.y; __local2__ = data.test.y; walk(__local1__, [__local0__, __local2__]) } - `, + note: "non-existent var", + module: `package test + + p { print(x) }`, + exp: errors.New("var x is undeclared"), }, { - note: "rewrite term", - module: ` - package test - x = "a" - y = 1 - q[[2, "b"]] - p { some x; q[[y,x]] } - `, - exp: ` - package test - x = "a" - y = 1 - q[[2, "b"]] - p { 
__local1__ = data.test.y; data.test.q[[__local1__, __local0__]] } - `, + note: "declared after print", + module: `package test + + p { print(x); x = 7 }`, + exp: errors.New("var x is undeclared"), }, { - note: "rewrite some x in xs", - module: ` - package test - import future.keywords.in - xs = ["a", "b", "c"] - p { some x in xs; x == "a" } - `, - exp: ` - package test - xs = ["a", "b", "c"] - p { __local2__ = data.test.xs[__local1__]; __local2__ = "a" } + note: "inside comprehension", + module: `package test + p { {1 | print(x)} } `, + exp: errors.New("var x is undeclared"), }, + } + + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler().WithEnablePrintStatements(true) + c.Compile(map[string]*Module{ + "test.rego": MustParseModule(tc.module), + }) + if !c.Failed() { + t.Fatal("expected error") + } + if c.Errors[0].Code != CompileErr || c.Errors[0].Message != tc.exp.Error() { + t.Fatal("unexpected error:", c.Errors) + } + }) + } +} + +func TestCompilerRewritePrintCalls(t *testing.T) { + cases := []struct { + note string + module string + exp string + }{ { - note: "rewrite some k, x in xs", - module: ` - package test - import future.keywords.in - xs = ["a", "b", "c"] - p { some k, x in xs; x == "a"; k == 2 } - `, - exp: ` - package test - xs = ["a", "b", "c"] - p { __local1__ = data.test.xs[__local0__]; __local1__ = "a"; __local0__ = 2 } - `, + note: "print one", + module: `package test + + p { print(1) }`, + exp: `package test + + p = true { __local1__ = {__local0__ | __local0__ = 1}; internal.print([__local1__]) }`, }, { - note: "rewrite some k, x in xs[i]", - module: ` - package test - import future.keywords.in - xs = [["a", "b", "c"], []] - p { - some i - some k, x in xs[i] - x == "a" - k == 2 - } - `, - exp: ` - package test - xs = [["a", "b", "c"], []] - p = true { __local2__ = data.test.xs[__local0__][__local1__]; __local2__ = "a"; __local1__ = 2 } - `, + note: "print multiple", + module: `package test + + p { print(1, 2) }`, + exp: `package test + + p = true { __local2__ = {__local0__ | __local0__ = 1}; __local3__ = {__local1__ | __local1__ = 2}; internal.print([__local2__, __local3__]) }`, }, { - note: "rewrite some k, x in xs[i] with `i` as ref", - module: ` - package test - import future.keywords.in - i = 0 - xs = [["a", "b", "c"], []] - p { - some k, x in xs[i] - x == "a" - k == 2 - } - `, - exp: ` - package test - i = 0 - xs = [["a", "b", "c"], []] - p = true { __local2__ = data.test.i; __local1__ = data.test.xs[__local2__][__local0__]; __local1__ = "a"; __local0__ = 2 } - `, + note: "print inside set comprehension", + module: `package test + + p { x = 1; {2 | print(x)} }`, + exp: `package test + + p = true { x = 1; {2 | __local1__ = {__local0__ | __local0__ = x}; internal.print([__local1__])} }`, }, { - note: "rewrite closures", - module: ` - package test - x = 1 - y = 2 - p { - some x, z - z = 3 - [x | x = 2; y = 2; some z; z = 4] - } - `, - exp: ` - package test - x = 1 - y = 2 - p { - __local1__ = 3 - [__local0__ | __local0__ = 2; data.test.y = 2; __local2__ = 4] - } - `, + note: "print inside array comprehension", + module: `package test + + p { x = 1; [2 | print(x)] }`, + exp: `package test + + p = true { x = 1; [2 | __local1__ = {__local0__ | __local0__ = x}; internal.print([__local1__])] }`, }, { - note: "rewrite head var", - module: ` - package test - x = "a" - y = 1 - z = 2 - p[x] = [y, z] { - some x, z - x = "b" - z = 4 - }`, - exp: ` - package test - x = "a" - y = 1 - z = 2 - p[__local0__] = __local2__ { - __local0__ = "b" - 
__local1__ = 4; - __local3__ = data.test.y - __local2__ = [__local3__, __local1__] - } - `, + note: "print inside object comprehension", + module: `package test + + p { x = 1; {"x": 2 | print(x)} }`, + exp: `package test + + p = true { x = 1; {"x": 2 | __local1__ = {__local0__ | __local0__ = x}; internal.print([__local1__])} }`, }, { - note: "rewrite call with root document ref as arg", - module: ` - package test + note: "print inside every", + module: `package test - p { - f(input, "bar") - } + p { every x in [1,2] { print(x) } }`, + exp: `package test - f(x, y) { - x[y] + p = true { + __local3__ = [1, 2] + every __local0__, __local1__ in __local3__ { + __local4__ = {__local2__ | __local2__ = __local1__} + internal.print([__local4__]) } - `, - exp: ` - package test + }`, + }, + { + note: "print output of nested call", + module: `package test - p = true { - __local2__ = input; - data.test.f(__local2__, "bar") - } + p { + x := split("abc", "")[y] + print(x, y) + }`, + exp: `package test - f(__local0__, __local1__) = true { - __local0__[__local1__] - } - `, + p = true { split("abc", "", __local3__); __local0__ = __local3__[y]; __local4__ = {__local1__ | __local1__ = __local0__}; __local5__ = {__local2__ | __local2__ = y}; internal.print([__local4__, __local5__]) }`, }, { - note: "redeclare err", - module: ` - package test - p { - some x - some x - } - `, - wantErr: errors.New("var x declared above"), + note: "print call in head", + module: `package test + + p = {1 | print("x") }`, + exp: `package test + + p = __local1__ { + true + __local1__ = {1 | __local2__ = { __local0__ | __local0__ = "x"}; internal.print([__local2__])} + }`, }, { - note: "redeclare assigned err", - module: ` - package test - p { - x := 1 - some x - } + note: "print call in head - args treated as safe", + module: `package test + + f(a) = {1 | a[x]; print(x)}`, + exp: `package test + + f(__local0__) = __local2__ { true; __local2__ = {1 | __local0__[x]; __local3__ = {__local1__ | __local1__ = x}; internal.print([__local3__])} } `, - wantErr: errors.New("var x assigned above"), }, { - note: "redeclare reference err", - module: ` - package test - p { - data.q[x] - some x - } + note: "print call of var in head key", + module: `package test + f(_) = [1, 2, 3] + p[x] { [_, x, _] := f(true); print(x) }`, + exp: `package test + f(__local0__) = [1, 2, 3] { true } + p[__local2__] { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } `, - wantErr: errors.New("var x referenced above"), }, { - note: "declare unused err", - module: ` - package test - p { - some x - } + note: "print call of var in head value", + module: `package test + f(_) = [1, 2, 3] + p = x { [_, x, _] := f(true); print(x) }`, + exp: `package test + f(__local0__) = [1, 2, 3] { true } + p = __local2__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } `, - wantErr: errors.New("declared var x unused"), }, { - note: "declare unsafe err", - module: ` - package test - p[x] { - some x - x == 1 - } + note: "print call of vars in head key and value", + module: `package test + f(_) = [1, 2, 3] + p[x] = y { [_, x, y] := f(true); print(x) }`, + exp: `package test + f(__local0__) = [1, 2, 3] { true } + p[__local2__] = __local3__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = 
__local2__}; internal.print([__local6__]) } `, - wantErr: errors.New("var x is unsafe"), }, { - note: "declare arg err", - module: ` - package test - - f([a]) { - some a - a = 1 - } - `, - wantErr: errors.New("arg a redeclared"), + note: "print call of vars altered with 'with' and call", + module: `package test + q = input + p { + x := q with input as json.unmarshal("{}") + print(x) + }`, + exp: `package test + q = __local3__ { true; __local3__ = input } + p = true { + json.unmarshal("{}", __local2__) + __local0__ = data.test.q with input as __local2__ + __local4__ = {__local1__ | __local1__ = __local0__} + internal.print([__local4__]) + }`, }, } - for _, tc := range tests { + for _, tc := range cases { t.Run(tc.note, func(t *testing.T) { - compiler, err := CompileModules(map[string]string{"test.rego": tc.module}) - if tc.wantErr != nil { - if err == nil { - t.Fatal("Expected error but got success") - } - if !strings.Contains(err.Error(), tc.wantErr.Error()) { - t.Fatalf("Expected %v but got %v", tc.wantErr, err) - } - } else if err != nil { - t.Fatal(err) - } else { - exp := MustParseModule(tc.exp) - result := compiler.Modules["test.rego"] - if exp.Compare(result) != 0 { - t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, result) - } + c := NewCompiler().WithEnablePrintStatements(true) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c.Compile(map[string]*Module{ + "test.rego": MustParseModuleWithOpts(tc.module, opts), + }) + if c.Failed() { + t.Fatal(c.Errors) + } + exp := MustParseModuleWithOpts(tc.exp, opts) + if !exp.Equal(c.Modules["test.rego"]) { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"]) } }) } } -func TestCompileInvalidEqAssignExpr(t *testing.T) { - - c := NewCompiler() - - c.Modules["error"] = MustParseModule(`package errors - - - p { - # Arity mismatches are caught in the checkUndefinedFuncs check, - # and invalid eq/assign calls are passed along until then. 
- assign() - assign(1) - eq() - eq(1) - }`) - - var prev func() - checkUndefinedFuncs := reflect.ValueOf(c.checkUndefinedFuncs) - - for _, stage := range c.stages { - if reflect.ValueOf(stage.f).Pointer() == checkUndefinedFuncs.Pointer() { - break - } - prev = stage.f +func TestRewritePrintCallsWithElseImplicitArgs(t *testing.T) { + + module := `package test + + f(x, y) { + x = y } - compileStages(c, prev) - assertNotFailed(t, c) -} + else = false { + print(x, y) + }` -func TestCompilerRewriteComprehensionTerm(t *testing.T) { + c := NewCompiler().WithEnablePrintStatements(true) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c.Compile(map[string]*Module{ + "test.rego": MustParseModuleWithOpts(module, opts), + }) - c := NewCompiler() - c.Modules["head"] = MustParseModule(`package head - arr = [[1], [2], [3]] - arr2 = [["a"], ["b"], ["c"]] - arr_comp = [[x[i]] | arr[j] = x] - set_comp = {[x[i]] | arr[j] = x} - obj_comp = {x[i]: x[i] | arr2[j] = x} - `) + if c.Failed() { + t.Fatal(c.Errors) + } - compileStages(c, c.rewriteComprehensionTerms) - assertNotFailed(t, c) + exp := MustParseModuleWithOpts(`package test - arrCompRule := c.Modules["head"].Rules[2] - exp1 := MustParseRule(`arr_comp = [__local0__ | data.head.arr[j] = x; __local0__ = [x[i]]] { true }`) - assertRulesEqual(t, arrCompRule, exp1) + f(__local0__, __local1__) = true { __local0__ = __local1__ } + else = false { __local4__ = {__local2__ | __local2__ = __local0__}; __local5__ = {__local3__ | __local3__ = __local1__}; internal.print([__local4__, __local5__]) } + `, opts) - setCompRule := c.Modules["head"].Rules[3] - exp2 := MustParseRule(`set_comp = {__local1__ | data.head.arr[j] = x; __local1__ = [x[i]]} { true }`) - assertRulesEqual(t, setCompRule, exp2) + // NOTE(tsandall): we have to patch the implicit args on the else rule + // because of how the parser copies the arg names across from the first + // rule. + exp.Rules[0].Else.Head.Args[0] = VarTerm("__local0__") + exp.Rules[0].Else.Head.Args[1] = VarTerm("__local1__") - objCompRule := c.Modules["head"].Rules[4] - exp3 := MustParseRule(`obj_comp = {__local2__: __local3__ | data.head.arr2[j] = x; __local2__ = x[i]; __local3__ = x[i]} { true }`) - assertRulesEqual(t, objCompRule, exp3) + if !exp.Equal(c.Modules["test.rego"]) { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"]) + } } -func TestCompilerRewriteDoubleEq(t *testing.T) { +func TestCompilerMockFunction(t *testing.T) { tests := []struct { - note string - input string - exp string + note string + module, extra string + err string }{ { - note: "vars and constants", - input: "p { x = 1; x == 1; y = [1,2,3]; y == [1,2,3] }", - exp: `x = 1; x = 1; y = [1,2,3]; y = [1,2,3]`, + note: "simple valid", + module: `package test + now() = 123 + p { true with time.now_ns as now } + `, }, { - note: "refs", - input: "p { input.x == data.y }", - exp: `input.x = data.y`, + note: "simple valid, simple name", + module: `package test + mock_concat(_, _) = "foo/bar" + p { concat("/", input) with concat as mock_concat } + `, }, { - note: "comprehensions", - input: "p { [1|true] == [2|true] }", - exp: `[1|true] = [2|true]`, + note: "invalid ref: nonexistant", + module: `package test + p { true with time.now_ns as now } + `, + err: "rego_unsafe_var_error: var now is unsafe", // we're running all compiler stages here }, - // TODO(tsandall): improve support for calls so that extra unification step is - // not required. 
This requires more changes to the compiler as the initial - // stages that rewrite term exprs needs to be updated to handle == differently - // and then other stages need to be reviewed to make sure they can deal with - // nested calls. Alternatively, the compiler could keep track of == exprs that - // have been converted into = and then the safety check would need to be updated. { - note: "calls", - input: "p { count([1,2]) == 2 }", - exp: `count([1,2], __local0__); __local0__ = 2`, + note: "valid ref: not a function, but arity = 0", + module: `package test + now = 1 + p { true with time.now_ns as now } + `, }, { - note: "embedded", - input: "p { x = 1; y = [x == 0] }", - exp: `x = 1; equal(x, 0, __local0__); y = [__local0__]`, + note: "ref: not a function, arity > 0", + module: `package test + http_send = { "body": "nope" } + p { true with http.send as http_send } + `, }, { - note: "embedded in call", - input: `p { x = 0; neq(true, x == 1) }`, - exp: `x = 0; equal(x, 1, __local0__); neq(true, __local0__)`, + note: "invalid ref: arity mismatch", + module: `package test + http_send(_, _) = { "body": "nope" } + p { true with http.send as http_send } + `, + err: "rego_type_error: http.send: arity mismatch\n\thave: (any, any)\n\twant: (request: object[string: any])", }, { - note: "comprehension in object key", - input: `p { {{1 | 0 == 0}: 2} }`, - exp: `{{1 | 0 = 0}: 2}`, + note: "invalid ref: arity mismatch (in call)", + module: `package test + http_send(_, _) = { "body": "nope" } + p { http.send({}) with http.send as http_send } + `, + err: "rego_type_error: http.send: arity mismatch\n\thave: (any, any)\n\twant: (request: object[string: any])", + }, + { + note: "invalid ref: value another built-in with different type", + module: `package test + p { true with http.send as net.lookup_ip_addr } + `, + err: "rego_type_error: http.send: arity mismatch\n\thave: (string)\n\twant: (request: object[string: any])", + }, + { + note: "ref: value another built-in with compatible type", + module: `package test + p { true with count as object.union_n } + `, + }, + { + note: "valid: package import", + extra: `package mocks + http_send(_) = {} + `, + module: `package test + import data.mocks + p { true with http.send as mocks.http_send } + `, + }, + { + note: "valid: function import", + extra: `package mocks + http_send(_) = {} + `, + module: `package test + import data.mocks.http_send + p { true with http.send as http_send } + `, + }, + { + note: "invalid target: relation", + module: `package test + my_walk(_, _) + p { true with walk as my_walk } + `, + err: "rego_compile_error: with keyword replacing built-in function: target must not be a relation", + }, + { + note: "invalid target: eq", + module: `package test + my_eq(_, _) + p { true with eq as my_eq } + `, + err: `rego_compile_error: with keyword replacing built-in function: replacement of "eq" invalid`, + }, + { + note: "invalid target: rego.metadata.chain", + module: `package test + p { true with rego.metadata.chain as [] } + `, + err: `rego_compile_error: with keyword replacing built-in function: replacement of "rego.metadata.chain" invalid`, + }, + { + note: "invalid target: rego.metadata.rule", + module: `package test + p { true with rego.metadata.rule as {} } + `, + err: `rego_compile_error: with keyword replacing built-in function: replacement of "rego.metadata.rule" invalid`, + }, + { + note: "invalid target: internal.print", + module: `package test + my_print(_, _) + p { true with internal.print as my_print } + `, + err: `rego_compile_error: 
with keyword replacing built-in function: replacement of internal function "internal.print" invalid`, + }, + { + note: "mocking custom built-in", + module: `package test + mock(_) + mock_mock(_) + p { bar(foo.bar("one")) with bar as mock with foo.bar as mock_mock } + `, + }, + { + note: "non-built-in function replaced value", + module: `package test + original(_) + p { original(true) with original as 123 } + `, + }, + { + note: "non-built-in function replaced by another, arity 0", + module: `package test + original() = 1 + mock() = 2 + p { original() with original as mock } + `, + err: "rego_type_error: undefined function data.test.original", // TODO(sr): file bug -- this doesn't depend on "with" used or not + }, + { + note: "non-built-in function replaced by another, arity 1", + module: `package test + original(_) + mock(_) + p { original(true) with original as mock } + `, + }, + { + note: "non-built-in function replaced by built-in", + module: `package test + original(_) + p { original([1]) with original as count } + `, + }, + { + note: "non-built-in function replaced by another, arity mismatch", + module: `package test + original(_) + mock(_, _) + p { original([1]) with original as mock } + `, + err: "rego_type_error: data.test.original: arity mismatch\n\thave: (any, any)\n\twant: (any)", + }, + { + note: "non-built-in function replaced by built-in, arity mismatch", + module: `package test + original(_) + p { original([1]) with original as concat } + `, + err: "rego_type_error: data.test.original: arity mismatch\n\thave: (string, any)\n\twant: (any)", }, } + for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - c := NewCompiler() - c.Modules["test"] = MustParseModule("package test\n" + tc.input) - compileStages(c, c.rewriteEquals) - assertNotFailed(t, c) - exp := MustParseBody(tc.exp) - result := c.Modules["test"].Rules[0].Body - if result.Compare(exp) != 0 { - t.Fatalf("\nExp: %v\nGot: %v", exp, result) + c := NewCompiler().WithBuiltins(map[string]*Builtin{ + "bar": { + Name: "bar", + Decl: types.NewFunction([]types.Type{types.S}, types.A), + }, + "foo.bar": { + Name: "foo.bar", + Decl: types.NewFunction([]types.Type{types.S}, types.A), + }, + }) + if tc.extra != "" { + c.Modules["extra"] = MustParseModule(tc.extra) } - }) - } -} - -func TestCompilerRewriteDynamicTerms(t *testing.T) { + c.Modules["test"] = MustParseModule(tc.module) - fixture := ` - package test - str = "hello" - ` - - tests := []struct { - input string - expected string - }{ - {`arr { [str] }`, `__local0__ = data.test.str; [__local0__]`}, - {`arr2 { [[str]] }`, `__local0__ = data.test.str; [[__local0__]]`}, - {`obj { {"x": str} }`, `__local0__ = data.test.str; {"x": __local0__}`}, - {`obj2 { {"x": {"y": str}} }`, `__local0__ = data.test.str; {"x": {"y": __local0__}}`}, - {`set { {str} }`, `__local0__ = data.test.str; {__local0__}`}, - {`set2 { {{str}} }`, `__local0__ = data.test.str; {{__local0__}}`}, - {`ref { str[str] }`, `__local0__ = data.test.str; data.test.str[__local0__]`}, - {`ref2 { str[str[str]] }`, `__local0__ = data.test.str; __local1__ = data.test.str[__local0__]; data.test.str[__local1__]`}, - {`arr_compr { [1 | [str]] }`, `[1 | __local0__ = data.test.str; [__local0__]]`}, - {`arr_compr2 { [1 | [1 | [str]]] }`, `[1 | [1 | __local0__ = data.test.str; [__local0__]]]`}, - {`set_compr { {1 | [str]} }`, `{1 | __local0__ = data.test.str; [__local0__]}`}, - {`set_compr2 { {1 | {1 | [str]}} }`, `{1 | {1 | __local0__ = data.test.str; [__local0__]}}`}, - {`obj_compr { {"a": "b" | [str]} }`, `{"a": "b" | 
__local0__ = data.test.str; [__local0__]}`}, - {`obj_compr2 { {"a": "b" | {"a": "b" | [str]}} }`, `{"a": "b" | {"a": "b" | __local0__ = data.test.str; [__local0__]}}`}, - {`equality { str = str }`, `data.test.str = data.test.str`}, - {`equality2 { [str] = [str] }`, `__local0__ = data.test.str; __local1__ = data.test.str; [__local0__] = [__local1__]`}, - {`call { startswith(str, "") }`, `__local0__ = data.test.str; startswith(__local0__, "")`}, - {`call2 { count([str], n) }`, `__local0__ = data.test.str; count([__local0__], n)`}, - {`eq_with { [str] = [1] with input as 1 }`, `__local0__ = data.test.str with input as 1; [__local0__] = [1] with input as 1`}, - {`term_with { [[str]] with input as 1 }`, `__local0__ = data.test.str with input as 1; [[__local0__]] with input as 1`}, - {`call_with { count(str) with input as 1 }`, `__local0__ = data.test.str with input as 1; count(__local0__) with input as 1`}, - {`call_func { f(input, "foo") } f(x,y) { x[y] }`, `__local2__ = input; data.test.f(__local2__, "foo")`}, - {`call_func2 { f(input.foo, "foo") } f(x,y) { x[y] }`, `__local2__ = input.foo; data.test.f(__local2__, "foo")`}, - } + // NOTE(sr): We're running all compiler stages here, since the type checking of + // built-in function replacements happens at the type check stage. + c.Compile(c.Modules) - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - c := NewCompiler() - module := fixture + tc.input - c.Modules["test"] = MustParseModule(module) - compileStages(c, c.rewriteDynamicTerms) - assertNotFailed(t, c) - expected := MustParseBody(tc.expected) - result := c.Modules["test"].Rules[1].Body - if result.Compare(expected) != 0 { - t.Fatalf("\nExp: %v\nGot: %v", expected, result) + if tc.err != "" { + if !strings.Contains(c.Errors.Error(), tc.err) { + t.Errorf("expected error to contain %q, got %q", tc.err, c.Errors.Error()) + } + } else if len(c.Errors) > 0 { + t.Errorf("expected no errors, got %v", c.Errors) } }) } + } -func TestCompilerRewriteWithValue(t *testing.T) { - fixture := `package test +func TestCompilerMockVirtualDocumentPartially(t *testing.T) { + c := NewCompiler() - arr = ["hello", "goodbye"] + c.Modules["test"] = MustParseModule(` + package test + p = {"a": 1} + q = x { p = x with p.a as 2 } + `) - ` + compileStages(c, c.rewriteWithModifiers) + assertCompilerErrorStrings(t, c, []string{"rego_compile_error: with keyword cannot partially replace virtual document(s)"}) +} - tests := []struct { - note string - input string - expected string - wantErr error - }{ +func TestCompilerCheckUnusedAssignedVar(t *testing.T) { + type testCase struct { + note string + module string + expectedErrors Errors + } + + cases := []testCase{ { - note: "nop", - input: `p { true with input as 1 }`, - expected: `p { true with input as 1 }`, + note: "global var", + module: `package test + x := 1 + `, + }, + { + note: "simple rule with wildcard", + module: `package test + p { + _ := 1 + } + `, + }, + { + note: "simple rule", + module: `package test + p { + x := 1 + y := 2 + z := x + 3 + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + &Error{Message: "assigned var z unused"}, + }, + }, + { + note: "rule with return", + module: `package test + p = x { + x := 2 + y := 3 + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, + }, + { + note: "rule with function call", + module: `package test + p { + x := 2 + y := f(x) + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "refs", - input: `p 
{ true with input as arr }`, - expected: `p { __local0__ = data.test.arr; true with input as __local0__ }`, + note: "rule with nested array comprehension", + module: `package test + p { + x := 2 + y := [z | z := 2 * x] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "array comprehension", - input: `p { true with input as [true | true] }`, - expected: `p { __local0__ = [true | true]; true with input as __local0__ }`, + note: "rule with nested array comprehension and shadowing", + module: `package test + p { + x := 2 + y := [x | x := 2 * x] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "set comprehension", - input: `p { true with input as {true | true} }`, - expected: `p { __local0__ = {true | true}; true with input as __local0__ }`, + note: "rule with nested array comprehension and shadowing (unused shadowed var)", + module: `package test + p { + x := 2 + y := [x | x := 2] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var x unused"}, + &Error{Message: "assigned var y unused"}, + }, }, { - note: "object comprehension", - input: `p { true with input as {"k": true | true} }`, - expected: `p { __local0__ = {"k": true | true}; true with input as __local0__ }`, + note: "rule with nested array comprehension and shadowing (unused shadowing var)", + module: `package test + p { + x := 2 + x > 1 + [1 | x := 2] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var x unused"}, + }, }, { - note: "comprehension nested", - input: `p { true with input as [true | true with input as arr] }`, - expected: `p { __local0__ = [true | __local1__ = data.test.arr; true with input as __local1__]; true with input as __local0__ }`, + note: "rule with nested array comprehension and some declaration", + module: `package test + p { + some i + _ := [z | z := [1, 2][i]] + } + `, }, { - note: "multiple", - input: `p { true with input.a as arr[0] with input.b as arr[1] }`, - expected: `p { __local0__ = data.test.arr[0]; __local1__ = data.test.arr[1]; true with input.a as __local0__ with input.b as __local1__ }`, + note: "rule with nested set comprehension", + module: `package test + p { + x := 2 + y := {z | z := 2 * x} + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "invalid target", - input: `p { true with foo.q as 1 }`, - wantErr: fmt.Errorf("rego_type_error: with keyword target must start with input or data"), + note: "rule with nested set comprehension and unused inner var", + module: `package test + p { + x := 2 + y := {z | z := 2 * x; a := 2} + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var a unused"}, // y isn't reported, as we abort early on errors when moving through the stack + }, }, - } - - for _, tc := range tests { - t.Run(tc.note, func(t *testing.T) { - c := NewCompiler() - module := fixture + tc.input - c.Modules["test"] = MustParseModule(module) - compileStages(c, c.rewriteWithModifiers) - if tc.wantErr == nil { - assertNotFailed(t, c) - expected := MustParseRule(tc.expected) - result := c.Modules["test"].Rules[1] - if result.Compare(expected) != 0 { - t.Fatalf("\nExp: %v\nGot: %v", expected, result) + { + note: "rule with nested object comprehension", + module: `package test + p { + x := 2 + y := {z: x | z := 2 * x} } - } else { - assertCompilerErrorStrings(t, c, []string{tc.wantErr.Error()}) - } - }) - } -} - -func TestCompilerRewritePrintCallsErasure(t *testing.T) { - - cases := []struct { - note string - module 
string - exp string - }{ + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, + }, { - note: "no-op", + note: "rule with nested closure", module: `package test - p { true }`, - exp: `package test - p { true }`, + p { + x := 1 + a := 1 + { y | y := [ z | z:=[1,2,3][a]; z > 1 ][_] } + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var x unused"}, + }, }, { - note: "replace empty body with true", + note: "rule with nested closure and unused inner var", module: `package test - - p { print(1) } + p { + x := 1 + { y | y := [ z | z:=[1,2,3][x]; z > 1; a := 2 ][_] } + } `, - exp: `package test - - p { true } `, + expectedErrors: Errors{ + &Error{Message: "assigned var a unused"}, + }, }, { - note: "rule body", + note: "simple function", module: `package test - - p { false; print(1) } + f() { + x := 1 + y := 2 + } `, - exp: `package test - - p { false } `, + expectedErrors: Errors{ + &Error{Message: "assigned var x unused"}, + &Error{Message: "assigned var y unused"}, + }, }, { - note: "set comprehension body", + note: "simple function with wildcard", module: `package test - - p { {1 | false; print(1)} } + f() { + x := 1 + _ := 2 + } `, - exp: `package test - - p { {1 | false} } `, + expectedErrors: Errors{ + &Error{Message: "assigned var x unused"}, + }, }, { - note: "array comprehension body", + note: "function with return", module: `package test - - p { [1 | false; print(1)] } + f() = x { + x := 1 + y := 2 + } `, - exp: `package test - - p { [1 | false] } `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "object comprehension body", + note: "array comprehension", module: `package test - - p { {"x": 1 | false; print(1)} } + comp = [ 1 | + x := [1, 2, 3] + y := 2 + z := x[_] + ] `, - exp: `package test - - p { {"x": 1 | false} } `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + &Error{Message: "assigned var z unused"}, + }, }, { - note: "in head", + note: "array comprehension nested", module: `package test - - p = {1 | print("x")}`, - exp: `package test - - p = __local0__ { true; __local0__ = {1 | true} }`, + comp := [ 1 | + x := 1 + y := [a | a := x] + ] + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, - } - - for _, tc := range cases { - t.Run(tc.note, func(t *testing.T) { - c := NewCompiler().WithEnablePrintStatements(false) - c.Compile(map[string]*Module{ - "test.rego": MustParseModule(tc.module), - }) - if c.Failed() { - t.Fatal(c.Errors) - } - exp := MustParseModule(tc.exp) - if !exp.Equal(c.Modules["test.rego"]) { - t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"]) - } - }) - } -} - -func TestCompilerRewritePrintCallsErrors(t *testing.T) { - cases := []struct { - note string - module string - exp error - }{ { - note: "non-existent var", + note: "array comprehension with wildcard", module: `package test - - p { print(x) }`, - exp: errors.New("var x is undeclared"), + comp = [ 1 | + x := [1, 2, 3] + _ := 2 + z := x[_] + ] + `, + expectedErrors: Errors{ + &Error{Message: "assigned var z unused"}, + }, }, { - note: "declared after print", + note: "array comprehension with return", module: `package test - - p { print(x); x = 7 }`, - exp: errors.New("var x is undeclared"), + comp = [ z | + x := [1, 2, 3] + y := 2 + z := x[_] + ] + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "inside comprehension", + note: "array comprehension with some", module: `package test - p { {1 | print(x)} } + comp = 
[ i | + some i + y := 2 + ] `, - exp: errors.New("var x is undeclared"), + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, - } - - for _, tc := range cases { - t.Run(tc.note, func(t *testing.T) { - c := NewCompiler().WithEnablePrintStatements(true) - c.Compile(map[string]*Module{ - "test.rego": MustParseModule(tc.module), - }) - if !c.Failed() { - t.Fatal("expected error") - } - if c.Errors[0].Code != CompileErr || c.Errors[0].Message != tc.exp.Error() { - t.Fatal("unexpected error:", c.Errors) - } - }) - } -} - -func TestCompilerRewritePrintCalls(t *testing.T) { - cases := []struct { - note string - module string - exp string - }{ { - note: "print one", + note: "set comprehension", module: `package test - - p { print(1) }`, - exp: `package test - - p = true { __local1__ = {__local0__ | __local0__ = 1}; internal.print([__local1__]) }`, + comp = { 1 | + x := [1, 2, 3] + y := 2 + z := x[_] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + &Error{Message: "assigned var z unused"}, + }, }, { - note: "print multiple", + note: "set comprehension nested", module: `package test - - p { print(1, 2) }`, - exp: `package test - - p = true { __local2__ = {__local0__ | __local0__ = 1}; __local3__ = {__local1__ | __local1__ = 2}; internal.print([__local2__, __local3__]) }`, + comp := { 1 | + x := 1 + y := [a | a := x] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "print inside set comprehension", + note: "set comprehension with wildcard", module: `package test - - p { x = 1; {2 | print(x)} }`, - exp: `package test - - p = true { x = 1; {2 | __local1__ = {__local0__ | __local0__ = x}; internal.print([__local1__])} }`, + comp = { 1 | + x := [1, 2, 3] + _ := 2 + z := x[_] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var z unused"}, + }, }, { - note: "print inside array comprehension", + note: "set comprehension with return", module: `package test - - p { x = 1; [2 | print(x)] }`, - exp: `package test - - p = true { x = 1; [2 | __local1__ = {__local0__ | __local0__ = x}; internal.print([__local1__])] }`, + comp = { z | + x := [1, 2, 3] + y := 2 + z := x[_] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "print inside object comprehension", + note: "set comprehension with some", module: `package test - - p { x = 1; {"x": 2 | print(x)} }`, - exp: `package test - - p = true { x = 1; {"x": 2 | __local1__ = {__local0__ | __local0__ = x}; internal.print([__local1__])} }`, + comp = { i | + some i + y := 2 + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "print output of nested call", + note: "object comprehension", module: `package test - - p { - x := split("abc", "")[y] - print(x, y) - }`, - exp: `package test - - p = true { split("abc", "", __local3__); __local0__ = __local3__[y]; __local4__ = {__local1__ | __local1__ = __local0__}; __local5__ = {__local2__ | __local2__ = y}; internal.print([__local4__, __local5__]) }`, + comp = { 1: 2 | + x := [1, 2, 3] + y := 2 + z := x[_] + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + &Error{Message: "assigned var z unused"}, + }, }, { - note: "print call in head", + note: "object comprehension nested", module: `package test - - p = {1 | print("x") }`, - exp: `package test - - p = __local1__ { - true - __local1__ = {1 | __local2__ = { __local0__ | __local0__ = "x"}; internal.print([__local2__])} - }`, + comp := { 1: 1 | + 
x := 1 + y := {a: x | a := x} + } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "print call in head - args treated as safe", + note: "object comprehension with wildcard", module: `package test - - f(a) = {1 | a[x]; print(x)}`, - exp: `package test - - f(__local0__) = __local2__ { true; __local2__ = {1 | __local0__[x]; __local3__ = {__local1__ | __local1__ = x}; internal.print([__local3__])} } + comp = { 1: 2 | + x := [1, 2, 3] + _ := 2 + z := x[_] + } `, + expectedErrors: Errors{ + &Error{Message: "assigned var z unused"}, + }, }, { - note: "print call of var in head key", + note: "object comprehension with return", module: `package test - f(_) = [1, 2, 3] - p[x] { [_, x, _] := f(true); print(x) }`, - exp: `package test - f(__local0__) = [1, 2, 3] { true } - p[__local2__] { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } + comp = { z: x | + x := [1, 2, 3] + y := 2 + z := x[_] + } `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "print call of var in head value", + note: "object comprehension with some", module: `package test - f(_) = [1, 2, 3] - p = x { [_, x, _] := f(true); print(x) }`, - exp: `package test - f(__local0__) = [1, 2, 3] { true } - p = __local2__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } + comp = { i | + some i + y := 2 + } `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, { - note: "print call of vars in head key and value", + note: "every: unused assigned var in body", module: `package test - f(_) = [1, 2, 3] - p[x] = y { [_, x, y] := f(true); print(x) }`, - exp: `package test - f(__local0__) = [1, 2, 3] { true } - p[__local2__] = __local3__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } + p { every i in [1] { y := 10; i == 1 } } `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, }, } - for _, tc := range cases { - t.Run(tc.note, func(t *testing.T) { - c := NewCompiler().WithEnablePrintStatements(true) - c.Compile(map[string]*Module{ - "test.rego": MustParseModule(tc.module), - }) - if c.Failed() { - t.Fatal(c.Errors) + makeTestRunner := func(tc testCase, strict bool) func(t *testing.T) { + return func(t *testing.T) { + compiler := NewCompiler().WithStrict(strict) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + compiler.Modules = map[string]*Module{ + "test": MustParseModuleWithOpts(tc.module, opts), } - exp := MustParseModule(tc.exp) - if !exp.Equal(c.Modules["test.rego"]) { - t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"]) + compileStages(compiler, compiler.rewriteLocalVars) + + if strict { + assertErrors(t, compiler.Errors, tc.expectedErrors, false) + } else { + assertNotFailed(t, compiler) } - }) + } } -} - -func TestCompilerMockFunction(t *testing.T) { - c := NewCompiler() - c.Modules["test"] = MustParseModule(` - package test - is_allowed(label) { - label == "test_label" + for _, tc := range cases { + t.Run(tc.note+"_strict", makeTestRunner(tc, true)) + t.Run(tc.note+"_non-strict", makeTestRunner(tc, false)) } - - p {true with data.test.is_allowed as "blah" } - `) - compileStages(c, c.rewriteWithModifiers) - 
assertCompilerErrorStrings(t, c, []string{"rego_compile_error: with keyword cannot replace functions"}) -} - -func TestCompilerMockVirtualDocumentPartially(t *testing.T) { - c := NewCompiler() - - c.Modules["test"] = MustParseModule(` - package test - p = {"a": 1} - q = x { p = x with p.a as 2 } - `) - - compileStages(c, c.rewriteWithModifiers) - assertCompilerErrorStrings(t, c, []string{"rego_compile_error: with keyword cannot partially replace virtual document(s)"}) } func TestCompilerSetGraph(t *testing.T) { @@ -3517,39 +7343,53 @@ dataref = true { data }`, x = "foo.bar" foo(x, y) }`), + "everyMod": MustParseModule(`package everymod + import future.keywords.every + everyp { + every x in [true, false] { x; everyp } + } + everyq[1] { + every x in everyq { x == 1 } + }`), } compileStages(c, c.checkRecursion) - makeRuleErrMsg := func(rule string, loop ...string) string { - return fmt.Sprintf("rego_recursion_error: rule %v is recursive: %v", rule, strings.Join(loop, " -> ")) + makeRuleErrMsg := func(pkg, rule string, loop ...string) string { + l := make([]string, len(loop)) + for i, lo := range loop { + l[i] = "data." + pkg + "." + lo + } + return fmt.Sprintf("rego_recursion_error: rule data.%s.%s is recursive: %v", pkg, rule, strings.Join(l, " -> ")) } expected := []string{ - makeRuleErrMsg("s", "s", "t", "s"), - makeRuleErrMsg("t", "t", "s", "t"), - makeRuleErrMsg("a", "a", "b", "c", "e", "a"), - makeRuleErrMsg("b", "b", "c", "e", "a", "b"), - makeRuleErrMsg("c", "c", "e", "a", "b", "c"), - makeRuleErrMsg("e", "e", "a", "b", "c", "e"), - makeRuleErrMsg("p", "p", "q", "p"), - makeRuleErrMsg("q", "q", "p", "q"), - makeRuleErrMsg("acq", "acq", "acp", "acq"), - makeRuleErrMsg("acp", "acp", "acq", "acp"), - makeRuleErrMsg("np", "np", "nq", "np"), - makeRuleErrMsg("nq", "nq", "np", "nq"), - makeRuleErrMsg("prefix", "prefix", "prefix"), - makeRuleErrMsg("dataref", "dataref", "dataref"), - makeRuleErrMsg("else_self", "else_self", "else_self"), - makeRuleErrMsg("elsetop", "elsetop", "elsemid", "elsebottom", "elsetop"), - makeRuleErrMsg("elsemid", "elsemid", "elsebottom", "elsetop", "elsemid"), - makeRuleErrMsg("elsebottom", "elsebottom", "elsetop", "elsemid", "elsebottom"), - makeRuleErrMsg("fn", "fn", "fn"), - makeRuleErrMsg("foo", "foo", "bar", "foo"), - makeRuleErrMsg("bar", "bar", "foo", "bar"), - makeRuleErrMsg("bar", "bar", "p", "foo", "bar"), - makeRuleErrMsg("foo", "foo", "bar", "p", "foo"), - makeRuleErrMsg("p", "p", "foo", "bar", "p"), + makeRuleErrMsg("rec", "s", "s", "t", "s"), + makeRuleErrMsg("rec", "t", "t", "s", "t"), + makeRuleErrMsg("rec", "a", "a", "b", "c", "e", "a"), + makeRuleErrMsg("rec", "b", "b", "c", "e", "a", "b"), + makeRuleErrMsg("rec", "c", "c", "e", "a", "b", "c"), + makeRuleErrMsg("rec", "e", "e", "a", "b", "c", "e"), + `rego_recursion_error: rule data.rec3.p[x] is recursive: data.rec3.p[x] -> data.rec4.q[x] -> data.rec3.p[x]`, // NOTE(sr): these two are hardcoded: they are + `rego_recursion_error: rule data.rec4.q[x] is recursive: data.rec4.q[x] -> data.rec3.p[x] -> data.rec4.q[x]`, // the only ones not fitting the pattern. 
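+		// (a comment for clarity: the rec3/rec4 loop crosses two packages, while
+		// makeRuleErrMsg prefixes every loop element with one and the same package,
+		// so these two entries cannot be generated by the helper)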
+ makeRuleErrMsg("rec5", "acq", "acq", "acp", "acq"), + makeRuleErrMsg("rec5", "acp", "acp", "acq", "acp"), + makeRuleErrMsg("rec6", "np[x]", "np[x]", "nq[x]", "np[x]"), + makeRuleErrMsg("rec6", "nq[x]", "nq[x]", "np[x]", "nq[x]"), + makeRuleErrMsg("rec7", "prefix", "prefix", "prefix"), + makeRuleErrMsg("rec8", "dataref", "dataref", "dataref"), + makeRuleErrMsg("rec9", "else_self", "else_self", "else_self"), + makeRuleErrMsg("rec9", "elsetop", "elsetop", "elsemid", "elsebottom", "elsetop"), + makeRuleErrMsg("rec9", "elsemid", "elsemid", "elsebottom", "elsetop", "elsemid"), + makeRuleErrMsg("rec9", "elsebottom", "elsebottom", "elsetop", "elsemid", "elsebottom"), + makeRuleErrMsg("f0", "fn", "fn", "fn"), + makeRuleErrMsg("f1", "foo", "foo", "bar", "foo"), + makeRuleErrMsg("f1", "bar", "bar", "foo", "bar"), + makeRuleErrMsg("f2", "bar", "bar", "p[x]", "foo", "bar"), + makeRuleErrMsg("f2", "foo", "foo", "bar", "p[x]", "foo"), + makeRuleErrMsg("f2", "p[x]", "p[x]", "foo", "bar", "p[x]"), + makeRuleErrMsg("everymod", "everyp", "everyp", "everyp"), + makeRuleErrMsg("everymod", "everyq", "everyq", "everyq"), } result := compilerErrsToStringSlice(c.Errors) @@ -3571,28 +7411,38 @@ func TestCompilerCheckDynamicRecursion(t *testing.T) { // references. For more background info, see // . - for note, mod := range map[string]*Module{ - "recursion": MustParseModule(` + for _, tc := range []struct { + note, err string + mod *Module + }{ + { + note: "recursion", + mod: MustParseModule(` package recursion pkg = "recursion" foo[x] { data[pkg]["foo"][x] } `), - "system.main": MustParseModule(` + err: "rego_recursion_error: rule data.recursion.foo is recursive: data.recursion.foo -> data.recursion.foo", + }, + {note: "system.main", + mod: MustParseModule(` package system.main foo { - data[input] + data[input] } `), + err: "rego_recursion_error: rule data.system.main.foo is recursive: data.system.main.foo -> data.system.main.foo", + }, } { - t.Run(note, func(t *testing.T) { + t.Run(tc.note, func(t *testing.T) { c := NewCompiler() - c.Modules = map[string]*Module{note: mod} + c.Modules = map[string]*Module{tc.note: tc.mod} compileStages(c, c.checkRecursion) result := compilerErrsToStringSlice(c.Errors) - expected := "rego_recursion_error: rule foo is recursive: foo -> foo" + expected := tc.err if len(result) != 1 || result[0] != expected { t.Errorf("Expected %v but got: %v", expected, result) @@ -3868,6 +7718,7 @@ func TestCompilerGetRulesDynamic(t *testing.T) { "mod1": `package a.b.c.d r1 = 1`, "mod2": `package a.b.c.e +default r2 = false r2 = 2`, "mod3": `package a.b r3 = 3`, @@ -3878,7 +7729,8 @@ r4 = 4`, compileStages(compiler, nil) rule1 := compiler.Modules["mod1"].Rules[0] - rule2 := compiler.Modules["mod2"].Rules[0] + rule2d := compiler.Modules["mod2"].Rules[0] + rule2 := compiler.Modules["mod2"].Rules[1] rule3 := compiler.Modules["mod3"].Rules[0] rule4 := compiler.Modules["hidden"].Rules[0] @@ -3888,15 +7740,16 @@ r4 = 4`, excludeHidden bool }{ {input: "data.a.b.c.d.r1", expected: []*Rule{rule1}}, - {input: "data.a.b[x]", expected: []*Rule{rule1, rule2, rule3}}, + {input: "data.a.b[x]", expected: []*Rule{rule1, rule2d, rule2, rule3}}, {input: "data.a.b[x].d", expected: []*Rule{rule1, rule3}}, - {input: "data.a.b.c", expected: []*Rule{rule1, rule2}}, + {input: "data.a.b.c", expected: []*Rule{rule1, rule2d, rule2}}, {input: "data.a.b.d"}, - {input: "data[x]", expected: []*Rule{rule1, rule2, rule3, rule4}}, - {input: "data[data.complex_computation].b[y]", expected: []*Rule{rule1, rule2, rule3}}, - {input: 
"data[x][y].c.e", expected: []*Rule{rule2}}, + {input: "data", expected: []*Rule{rule1, rule2d, rule2, rule3, rule4}}, + {input: "data[x]", expected: []*Rule{rule1, rule2d, rule2, rule3, rule4}}, + {input: "data[data.complex_computation].b[y]", expected: []*Rule{rule1, rule2d, rule2, rule3}}, + {input: "data[x][y].c.e", expected: []*Rule{rule2d, rule2}}, {input: "data[x][y].r3", expected: []*Rule{rule3}}, - {input: "data[x][y]", expected: []*Rule{rule1, rule2, rule3}, excludeHidden: true}, // old behaviour of GetRulesDynamic + {input: "data[x][y]", expected: []*Rule{rule1, rule2d, rule2, rule3}, excludeHidden: true}, // old behaviour of GetRulesDynamic } for _, tc := range tests { @@ -4128,7 +7981,7 @@ func TestCompilerWithStageAfterWithMetrics(t *testing.T) { m := metrics.New() c := NewCompiler().WithStageAfter( "CheckRecursion", - CompilerStageDefinition{"MockStage", "mock_stage", mockStageFunctionCallNoErr}, + CompilerStageDefinition{"MockStage", "mock_stage", func(*Compiler) *Error { return nil }}, ) c.WithMetrics(m) @@ -4412,6 +8265,19 @@ func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) { } } +func TestCompilerAllowMultipleAssignments(t *testing.T) { + + _, err := CompileModules(map[string]string{"test.rego": ` + package test + + p := 7 + p := 8 + `}) + if err != nil { + t.Fatal(err) + } +} + func TestQueryCompiler(t *testing.T) { tests := []struct { note string @@ -4502,7 +8368,7 @@ func TestQueryCompiler(t *testing.T) { q: "x = 1 with foo.p as null", pkg: "", imports: nil, - expected: fmt.Errorf("1 error occurred: 1:12: rego_type_error: with keyword target must start with input or data"), + expected: fmt.Errorf("1 error occurred: 1:12: rego_type_error: with keyword target must reference existing input, data, or a function"), }, { note: "rewrite with value", @@ -4516,7 +8382,7 @@ func TestQueryCompiler(t *testing.T) { q: `startswith("x")`, pkg: "", imports: nil, - expected: fmt.Errorf("1 error occurred: 1:1: rego_type_error: startswith: arity mismatch\n\thave: (string)\n\twant: (string, string)"), + expected: fmt.Errorf("1 error occurred: 1:1: rego_type_error: startswith: arity mismatch\n\thave: (string)\n\twant: (search: string, base: string)"), }, { note: "built-in function arity mismatch (arity 0)", @@ -4530,7 +8396,7 @@ func TestQueryCompiler(t *testing.T) { q: "count(sum())", pkg: "", imports: nil, - expected: fmt.Errorf("1 error occurred: 1:7: rego_type_error: sum: arity mismatch\n\thave: (???)\n\twant: (any)"), + expected: fmt.Errorf("1 error occurred: 1:7: rego_type_error: sum: arity mismatch\n\thave: (???)\n\twant: (collection: any)"), }, { note: "check types", @@ -4662,7 +8528,7 @@ func TestQueryCompilerWithStageAfterWithMetrics(t *testing.T) { QueryCompilerStageDefinition{ "MockStage", "mock_stage", - func(qc QueryCompiler, b Body) (Body, error) { + func(_ QueryCompiler, b Body) (Body, error) { return b, nil }, }) @@ -4679,13 +8545,164 @@ func TestQueryCompilerWithStageAfterWithMetrics(t *testing.T) { } func TestQueryCompilerWithUnsafeBuiltins(t *testing.T) { - c := NewCompiler().WithUnsafeBuiltins(map[string]struct{}{ - "count": {}, - }) + tests := []struct { + note string + query string + compiler *Compiler + opts func(QueryCompiler) QueryCompiler + err string + }{ + { + note: "builtin unsafe via compiler", + query: "count([])", + compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}), + err: "unsafe built-in function calls in expression: count", + }, + { + note: "builtin unsafe via query compiler", + query: "count([])", + compiler: 
NewCompiler(), + opts: func(qc QueryCompiler) QueryCompiler { + return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}}) + }, + err: "unsafe built-in function calls in expression: count", + }, + { + note: "builtin unsafe via compiler, 'with' mocking", + query: "is_array([]) with is_array as count", + compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}), + err: `with keyword replacing built-in function: target must not be unsafe: "count"`, + }, + { + note: "builtin unsafe via query compiler, 'with' mocking", + query: "is_array([]) with is_array as count", + compiler: NewCompiler(), + opts: func(qc QueryCompiler) QueryCompiler { + return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}}) + }, + err: `with keyword replacing built-in function: target must not be unsafe: "count"`, + }, + } - _, err := c.QueryCompiler().WithUnsafeBuiltins(map[string]struct{}{}).Compile(MustParseBody("count([])")) - if err != nil { - t.Fatal(err) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + qc := tc.compiler.QueryCompiler() + if tc.opts != nil { + qc = tc.opts(qc) + } + _, err := qc.Compile(MustParseBody(tc.query)) + var errs Errors + if !errors.As(err, &errs) { + t.Fatalf("expected error type %T, got %v %[2]T", errs, err) + } + if exp, act := 1, len(errs); exp != act { + t.Fatalf("expected %d error(s), got %d", exp, act) + } + if exp, act := tc.err, errs[0].Message; exp != act { + t.Errorf("expected message %q, got %q", exp, act) + } + }) + } +} + +func TestQueryCompilerWithDeprecatedBuiltins(t *testing.T) { + cases := []strictnessQueryTestCase{ + { + note: "all() built-in", + query: "all([true, false])", + expectedErrors: fmt.Errorf("1 error occurred: 1:1: rego_type_error: deprecated built-in function calls in expression: all"), + }, + { + note: "any() built-in", + query: "any([true, false])", + expectedErrors: fmt.Errorf("1 error occurred: 1:1: rego_type_error: deprecated built-in function calls in expression: any"), + }, + } + + runStrictnessQueryTestCase(t, cases) +} + +func TestQueryCompilerWithUnusedAssignedVar(t *testing.T) { + cases := []strictnessQueryTestCase{ + { + note: "array comprehension", + query: "[1 | x := 2]", + expectedErrors: fmt.Errorf("1 error occurred: 1:6: rego_compile_error: assigned var x unused"), + }, + { + note: "set comprehension", + query: "{1 | x := 2}", + expectedErrors: fmt.Errorf("1 error occurred: 1:6: rego_compile_error: assigned var x unused"), + }, + { + note: "object comprehension", + query: "{1: 2 | x := 2}", + expectedErrors: fmt.Errorf("1 error occurred: 1:9: rego_compile_error: assigned var x unused"), + }, + { + note: "every: unused var in body", + query: "every _ in [] { x := 10 }", + expectedErrors: fmt.Errorf("1 error occurred: 1:17: rego_compile_error: assigned var x unused"), + }, + } + + runStrictnessQueryTestCase(t, cases) +} + +func TestQueryCompilerCheckKeywordOverrides(t *testing.T) { + cases := []strictnessQueryTestCase{ + { + note: "input assigned", + query: "input := 1", + expectedErrors: fmt.Errorf("1 error occurred: 1:1: rego_compile_error: variables must not shadow input (use a different variable name)"), + }, + { + note: "data assigned", + query: "data := 1", + expectedErrors: fmt.Errorf("1 error occurred: 1:1: rego_compile_error: variables must not shadow data (use a different variable name)"), + }, + { + note: "nested input assigned", + query: "d := [input | input := 1]", + expectedErrors: fmt.Errorf("1 error occurred: 1:15: rego_compile_error: variables must not shadow input (use a 
different variable name)"), + }, + } + + runStrictnessQueryTestCase(t, cases) +} + +type strictnessQueryTestCase struct { + note string + query string + expectedErrors error +} + +func runStrictnessQueryTestCase(t *testing.T, cases []strictnessQueryTestCase) { + t.Helper() + makeTestRunner := func(tc strictnessQueryTestCase, strict bool) func(t *testing.T) { + return func(t *testing.T) { + c := NewCompiler().WithStrict(strict) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + result, err := c.QueryCompiler().Compile(MustParseBodyWithOpts(tc.query, opts)) + + if strict { + if err == nil { + t.Fatalf("Expected error from %v but got: %v", tc.query, result) + } + if !strings.Contains(err.Error(), tc.expectedErrors.Error()) { + t.Fatalf("Expected error %v but got: %v", tc.expectedErrors, err) + } + } else { + if err != nil { + t.Fatalf("Unexpected error from %v: %v", tc.query, err) + } + } + } + } + + for _, tc := range cases { + t.Run(tc.note+"_strict", makeTestRunner(tc, true)) + t.Run(tc.note+"_non-strict", makeTestRunner(tc, false)) } } @@ -4709,14 +8726,6 @@ func assertNotFailed(t *testing.T, c *Compiler) { } } -func mockStageFunctionCall(c *Compiler) *Error { - return NewError(CompileErr, &Location{}, "mock stage error") -} - -func mockStageFunctionCallNoErr(c *Compiler) *Error { - return nil -} - func getCompilerWithParsedModules(mods map[string]string) *Compiler { parsed := map[string]*Module{} @@ -4985,869 +8994,402 @@ func TestCompilerPassesTypeCheckNegative(t *testing.T) { } } -func TestWithSchema(t *testing.T) { - c := NewCompiler() - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, objectSchema) - c.WithSchemas(schemaSet) - if c.schemaSet == nil { - t.Fatalf("WithSchema did not set the schema correctly in the compiler") - } -} - -func TestAnyOfObjectSchema1(t *testing.T) { - c := NewCompiler() - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, anyOfExtendCoreSchema) - c.WithSchemas(schemaSet) - if c.schemaSet == nil { - t.Fatalf("Did not correctly compile an object type schema with anyOf outside core schema") - } -} - -func TestAnyOfObjectSchema2(t *testing.T) { - c := NewCompiler() - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, anyOfInsideCoreSchema) - c.WithSchemas(schemaSet) - if c.schemaSet == nil { - t.Fatalf("Did not correctly compile an object type schema with anyOf inside core schema") - } -} - -func TestAnyOfArraySchema(t *testing.T) { - c := NewCompiler() - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, anyOfArraySchema) - c.WithSchemas(schemaSet) - if c.schemaSet == nil { - t.Fatalf("Did not correctly compile an array type schema with anyOf") - } -} - -func TestAnyOfObjectMissing(t *testing.T) { - c := NewCompiler() - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, anyOfObjectMissing) - c.WithSchemas(schemaSet) - if c.schemaSet == nil { - t.Fatalf("Did not correctly compile an object type schema with anyOf where one of the props did not explicitly claim type") - } -} - -func TestAnyOfArrayMissing(t *testing.T) { - c := NewCompiler() - schemaSet := NewSchemaSet() - schemaSet.Put(SchemaRootRef, anyOfArrayMissing) - c.WithSchemas(schemaSet) - if c.schemaSet == nil { - t.Fatalf("Did not correctly compile an array type schema with anyOf where items are inside anyOf") - } -} - -const objectSchema = `{ - "$schema": "http://json-schema.org/draft-07/schema", - "$id": "http://example.com/example.json", - "type": "object", - "title": "The root schema", - "description": "The root schema comprises the 
entire JSON document.", - "required": [ - "foo", - "b" - ], - "properties": { - "foo": { - "$id": "#/properties/foo", - "type": "string", - "title": "The foo schema", - "description": "An explanation about the purpose of this instance." - }, - "b": { - "$id": "#/properties/b", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "additionalItems": false, - "items": { - "$id": "#/properties/b/items", - "type": "object", - "title": "The items schema", - "description": "An explanation about the purpose of this instance.", - "required": [ - "a", - "b", - "c" - ], - "properties": { - "a": { - "$id": "#/properties/b/items/properties/a", - "type": "integer", - "title": "The a schema", - "description": "An explanation about the purpose of this instance." - }, - "b": { - "$id": "#/properties/b/items/properties/b", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "additionalItems": false, - "items": { - "$id": "#/properties/b/items/properties/b/items", - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." - } - }, - "c": { - "$id": "#/properties/b/items/properties/c", - "type": "null", - "title": "The c schema", - "description": "An explanation about the purpose of this instance." - } - }, - "additionalProperties": false - } - } - }, - "additionalProperties": false -}` +func TestKeepModules(t *testing.T) { -const arrayNoItemsSchema = `{ - "$schema": "http://json-schema.org/draft-07/schema", - "$id": "http://example.com/example.json", - "type": "object", - "title": "The root schema", - "description": "The root schema comprises the entire JSON document.", - "required": [ - "b" - ], - "properties": { - "b": { - "$id": "#/properties/b", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "additionalItems": true - } - }, - "additionalProperties": false -}` + t.Run("no keep", func(t *testing.T) { + c := NewCompiler() // no keep is default -const noChildrenObjectSchema = `{ - "$schema": "http://json-schema.org/draft-07/schema", - "$id": "http://example.com/example.json", - "type": "object", - "title": "The root schema", - "description": "The root schema comprises the entire JSON document.", - "additionalProperties": true -}` + // This one is overwritten by c.Compile() + c.Modules["foo.rego"] = MustParseModule("package foo\np = true") -const untypedFieldObjectSchema = `{ - "$schema": "http://json-schema.org/draft-07/schema", - "$id": "http://example.com/example.json", - "type": "object", - "title": "The root schema", - "description": "The root schema comprises the entire JSON document.", - "required": [ - "foo" - ], - "properties": { - "foo": { - "$id": "#/properties/foo" - } - }, - "additionalProperties": false -}` + c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")}) -const booleanSchema = `{ - "$schema": "http://json-schema.org/draft-07/schema", - "$id": "http://example.com/example.json", - "type": "object", - "title": "The root schema", - "description": "The root schema comprises the entire JSON document.", - "required": [ - "a" - ], - "properties": { - "a": { - "$id": "#/properties/foo", - "type": "boolean", - "title": "The foo schema", - "description": "An explanation about the purpose of this instance." 
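+		// with keep disabled (the default), compilation should succeed, but the
+		// parsed modules are not retained: ParsedModules() is expected to return
+		// nil below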
+ if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) } - }, - "additionalProperties": false -}` -const refSchema = ` -{ - "description": "Pod is a collection of containers that can run on a host. This resource is created by clients and scheduled onto hosts.", - "type": "object", - "properties": { - "apiVersion": { - "description": "APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#resources", - "type": [ - "string", - "null" - ] - }, - - "kind": { - "description": "Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#types-kinds", - "type": [ - "string", - "null" - ], - "enum": [ - "Pod" - ] - }, - "metadata": { - "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.apimachinery.pkg.apis.meta.v1.ObjectMeta", - "description": "Standard object's metadata. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#metadata" - } - } -} -` -const podSchema = ` -{ - "description": "Pod is a collection of containers that can run on a host. This resource is created by clients and scheduled onto hosts.", - "properties": { - "apiVersion": { - "description": "APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#resources", - "type": [ - "string", - "null" - ] - }, - "kind": { - "description": "Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#types-kinds", - "type": [ - "string", - "null" - ], - "enum": [ - "Pod" - ] - }, - "metadata": { - "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.apimachinery.pkg.apis.meta.v1.ObjectMeta", - "description": "Standard object's metadata. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#metadata" - }, - "spec": { - "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.api.core.v1.PodSpec", - "description": "Specification of the desired behavior of the pod. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#spec-and-status" - }, - "status": { - "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.api.core.v1.PodStatus", - "description": "Most recently observed status of the pod. This data may not be up to date. Populated by the system. Read-only. 
More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#spec-and-status" - } - }, - "type": "object", - "x-kubernetes-group-version-kind": [ - { - "group": "", - "kind": "Pod", - "version": "v1" - } - ], - "$schema": "http://json-schema.org/schema#" - }` - -const anyOfArraySchema = `{ - "type": "object", - "properties": { - "familyMembers": { - "type": "array", - "items": { - "anyOf": [ - { - "type": "object", - "properties": { - "age": { "type": "integer" }, - "name": {"type": "string"} - } - },{ - "type": "object", - "properties": { - "personality": { "type": "string" }, - "nickname": { "type": "string" } - } - } - ] - } + mods := c.ParsedModules() + if mods != nil { + t.Errorf("expected ParsedModules == nil, got %v", mods) } - } -}` + }) -const anyOfExtendCoreSchema = `{ - "type": "object", - "properties": { - "AddressLine": { "type": "string" } - }, - "anyOf": [ - { - "type": "object", - "properties": { - "State": { "type": "string" }, - "ZipCode": { "type": "string" } - } - }, - { - "type": "object", - "properties": { - "County": { "type": "string" }, - "PostCode": { "type": "integer" } - } - } - ] -}` + t.Run("keep", func(t *testing.T) { -const allOfObjectSchema = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "object", - "title": "My schema", - "properties": { - "AddressLine1": { "type": "string" }, - "AddressLine2": { "type": "string" }, - "City": { "type": "string" } - }, - "allOf": [ - { - "type": "object", - "properties": { - "State": { "type": "string" }, - "ZipCode": { "type": "string" } - }, - }, - { - "type": "object", - "properties": { - "County": { "type": "string" }, - "PostCode": { "type": "string" } - }, - } - ] -}` + c := NewCompiler().WithKeepModules(true) + + // This one is overwritten by c.Compile() + c.Modules["foo.rego"] = MustParseModule("package foo\np = true") -const allOfArraySchema = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." - }, - "allOf": [ - { - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." + c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")}) + if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) } - }, - { - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." 
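+		// WithKeepModules(true) makes the pristine (uncompiled) modules available
+		// through ParsedModules(), keyed by filename; as asserted further down,
+		// the map is rebuilt on every Compile call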
+ + mods := c.ParsedModules() + if exp, act := 1, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) } + for k := range mods { + if k != "bar.rego" { + t.Errorf("unexpected key: %v, want 'bar.rego'", k) + } } - ] -}` -const allOfSchemaParentVariation = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "allOf": [ - { - "type": "object", - "properties": { - "State": { "type": "string" }, - "ZipCode": { "type": "string" } - }, - }, - { - "type": "object", - "properties": { - "County": { "type": "string" }, - "PostCode": { "type": "string" } - }, - } - ] -}` + for k := range mods { + compiled := c.Modules[k] + if compiled.Equal(mods[k]) { + t.Errorf("expected module %v to not be compiled: %v", k, mods[k]) + } + } -const emptySchema = `{ - "allof" : [] - }` - -const allOfArrayOfArrays = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "array", - "title": "The items schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." + // expect ParsedModules to be reset + c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")}) + mods = c.ParsedModules() + if exp, act := 1, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) } - }, - "allOf": [{ - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "array", - "title": "The items schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." - } + for k := range mods { + if k != "baz.rego" { + t.Errorf("unexpected key: %v, want 'baz.rego'", k) } - }, - { - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." 
+ } + + for k := range mods { + compiled := c.Modules[k] + if compiled.Equal(mods[k]) { + t.Errorf("expected module %v to not be compiled: %v", k, mods[k]) } } - ] -}` -const anyOfInsideCoreSchema = ` { - "type": "object", - "properties": { - "AddressLine": { "type": "string" }, - "RandomInfo": { - "anyOf": [ - { "type": "object", - "properties": { - "accessMe": {"type": "string"} - } - }, - { "type": "number", "minimum": 0 } - ] + // expect ParsedModules to be reset to nil + c = c.WithKeepModules(false) + c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")}) + mods = c.ParsedModules() + if mods != nil { + t.Errorf("expected ParsedModules == nil, got %v", mods) } - } -}` + }) -const anyOfObjectMissing = `{ - "type": "object", - "properties": { - "AddressLine": { "type": "string" } - }, - "anyOf": [ - { - "type": "object", - "properties": { - "State": { "type": "string" }, - "ZipCode": { "type": "string" } - } - }, - { - "properties": { - "County": { "type": "string" }, - "PostCode": { "type": "integer" } + t.Run("no copies", func(t *testing.T) { + extra := MustParseModule("package extra\np = input") + done := false + testLoader := func(map[string]*Module) (map[string]*Module, error) { + if done { + return nil, nil } + done = true + return map[string]*Module{"extra.rego": extra}, nil } - ] -}` -const allOfArrayOfObjects = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "object", - "title": "The items schema", - "description": "An explanation about the purpose of this instance.", - "properties": { - "State": { - "type": "string" - }, - "ZipCode": { - "type": "string" - } - }, - "allOf": [{ - "type": "object", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "properties": { - "County": { - "type": "string" - }, - "PostCode": { - "type": "string" - } - } - }, - { - "type": "object", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "properties": { - "Street": { - "type": "string" - }, - "House": { - "type": "string" - } - } - } - ] - } -}` + c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true) -const allOfObjectAndArray = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "object", - "title": "My schema", - "properties": { - "AddressLine1": { - "type": "string" - }, - "AddressLine2": { - "type": "string" - }, - "City": { - "type": "string" + mod := MustParseModule("package bar\np = input") + c.Compile(map[string]*Module{"bar.rego": mod}) + if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) } - }, - "allOf": [{ - "type": "object", - "properties": { - "State": { - "type": "string" - }, - "ZipCode": { - "type": "string" - } - } - }, - { - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "items": { - "type": "integer", - "title": "The items schema", - "description": "An explanation about the purpose of this instance." 
- } + + mods := c.ParsedModules() + if exp, act := 2, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) } - ] -}` + newName := Var("q") + mods["bar.rego"].Rules[0].Head.Name = newName + if exp, act := newName, mod.Rules[0].Head.Name; !exp.Equal(act) { + t.Errorf("expected modified rule name %v, found %v", exp, act) + } + mods["extra.rego"].Rules[0].Head.Name = newName + if exp, act := newName, extra.Rules[0].Head.Name; !exp.Equal(act) { + t.Errorf("expected modified rule name %v, found %v", exp, act) + } + }) -const allOfObjectMissing = `{ - "type": "object", - "properties": { - "AddressLine": { "type": "string" } - }, - "allOf": [ - { - "type": "object", - "properties": { - "State": { "type": "string" }, - "ZipCode": { "type": "string" } - } - }, - { - "properties": { - "County": { "type": "string" }, - "PostCode": { "type": "integer" } + t.Run("keep, with loader", func(t *testing.T) { + extra := MustParseModule("package extra\np = input") + done := false + testLoader := func(map[string]*Module) (map[string]*Module, error) { + if done { + return nil, nil } + done = true + return map[string]*Module{"extra.rego": extra}, nil } - ] -}` -const allOfArrayDifTypes = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "array", - "items": [{ - "type": "string" - }, - { - "type": "integer" - } - ] - }, - { - "type": "array", - "items": [{ - "type": "string" - }, - { - "type": "integer" - } - ] + c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true) + + // This one is overwritten by c.Compile() + c.Modules["foo.rego"] = MustParseModule("package foo\np = true") + + c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")}) + + if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) } - ] -}` -const allOfArrayInsideObject = `{ - "type": "object", - "properties": { - "familyMembers": { - "type": "array", - "items": { - "allOf": [{ - "type": "object", - "properties": { - "age": { - "type": "integer" - }, - "name": { - "type": "string" - } - } - }, { - "type": "object", - "properties": { - "personality": { - "type": "string" - }, - "nickname": { - "type": "string" - } - } - }] + mods := c.ParsedModules() + if exp, act := 2, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) + } + for k := range mods { + if k != "bar.rego" && k != "extra.rego" { + t.Errorf("unexpected key: %v, want 'extra.rego' and 'bar.rego'", k) } } - } -}` -const anyOfArrayMissing = `{ - "type": "array", - "anyOf": [ - { - "items": [ - {"type": "number"}, - {"type": "string"}] - }, - { "items": [ - {"type": "integer"}] + for k := range mods { + compiled := c.Modules[k] + if compiled.Equal(mods[k]) { + t.Errorf("expected module %v to not be compiled: %v", k, mods[k]) + } } - ] -}` + }) +} -const allOfArrayMissing = `{ - "type": "array", - "allOf": [{ - "items": [{ - "type": "integer" - }, - { - "type": "integer" - } - ] - }, - { - "items": [{ - "type": "integer" - }] - } - ] +// see https://github.com/open-policy-agent/opa/issues/5166 +func TestCompilerWithRecursiveSchema(t *testing.T) { + + jsonSchema := `{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/open-policy-agent/opa/issues/5166", + "type": "object", + "properties": { + "Something": { + "$ref": "#/$defs/X" + } + }, + "$defs": { + "X": { + "type": "object", 
+ "properties": { + "Name": { "type": "string" }, + "Y": { + "$ref": "#/$defs/Y" + } + } + }, + "Y": { + "type": "object", + "properties": { + "X": { + "$ref": "#/$defs/X" + } + } + } + } }` -const anyOfSchemaParentVariation = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "anyOf": [ - { - "type": "object", - "properties": { - "State": { "type": "string" }, - "ZipCode": { "type": "string" } - }, - }, - { - "type": "object", - "properties": { - "County": { "type": "string" }, - "PostCode": { "type": "string" } - }, - } - ] + exampleModule := `# METADATA +# schemas: +# - input: schema.input +package opa.recursion + +deny { + input.Something.Y.X.Name == "Something" +} +` + + c := NewCompiler() + var schema interface{} + if err := json.Unmarshal([]byte(jsonSchema), &schema); err != nil { + t.Fatal(err) } -}` + schemaSet := NewSchemaSet() + schemaSet.Put(MustParseRef("schema.input"), schema) + c.WithSchemas(schemaSet) -const allOfInsideCoreSchema = `{ - "type": "object", - "properties": { - "AddressLine": { "type": "string" }, - "RandomInfo": { - "allOf": [ - { "type": "object", - "properties": { - "accessMe": {"type": "string"} - } - }, - { "type": "object", - "properties": { - "accessYou": {"type": "string"} - }} - ] - } + m := MustParseModuleWithOpts(exampleModule, ParserOptions{ProcessAnnotation: true}) + c.Compile(map[string]*Module{"testMod": m}) + if c.Failed() { + t.Errorf("Expected compilation to succeed, but got errors: %v", c.Errors) } -}` +} -const allOfArrayDifTypesWithError = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "array", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "array", - "items": [{ - "type": "string" - }, - { - "type": "integer" - } - ] - }, - { - "type": "array", - "items": [{ - "type": "boolean" - }, - { - "type": "integer" - } - ] - } - ] +// see https://github.com/open-policy-agent/opa/issues/5166 +func TestCompilerWithRecursiveSchemaAndInvalidSource(t *testing.T) { + + jsonSchema := `{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/open-policy-agent/opa/issues/5166", + "type": "object", + "properties": { + "Something": { + "$ref": "#/$defs/X" + } + }, + "$defs": { + "X": { + "type": "object", + "properties": { + "Name": { "type": "string" }, + "Y": { + "$ref": "#/$defs/Y" + } + } + }, + "Y": { + "type": "object", + "properties": { + "X": { + "$ref": "#/$defs/X" + } + } + } + } }` -const allOfStringSchema = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "string", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "string", - }, - { - "type": "string", - } - ] -}` + exampleModule := `# METADATA +# schemas: +# - input: schema.input +package opa.recursion -const allOfIntegerSchema = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "integer", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "integer", - }, - { - "type": "integer", - } - ] -}` +deny { + input.Something.Y.X.ThisDoesNotExist == "Something" +} +` -const allOfBooleanSchema = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "boolean", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "boolean", - }, - { - "type": "boolean", - } - ] -}` + c := NewCompiler(). 
+ WithUseTypeCheckAnnotations(true) + var schema interface{} + if err := json.Unmarshal([]byte(jsonSchema), &schema); err != nil { + t.Fatal(err) + } + schemaSet := NewSchemaSet() + schemaSet.Put(MustParseRef("schema.input"), schema) + c.WithSchemas(schemaSet) -const allOfTypeErrorSchema = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "string", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "string", - }, - { - "type": "integer", - } - ] -}` + m := MustParseModuleWithOpts(exampleModule, ParserOptions{ProcessAnnotation: true}) + c.Compile(map[string]*Module{"testMod": m}) + if !c.Failed() { + t.Errorf("Expected compilation to fail, but it succeeded") + } else if !strings.HasPrefix(c.Errors.Error(), "1 error occurred: 7:2: rego_type_error: undefined ref: input.Something.Y.X.ThisDoesNotExist") { + t.Errorf("unexpected error: %v", c.Errors.Error()) + } +} -const allOfStringSchemaWithError = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "string", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "string", - }, - { - "type": "string", - }, - { - "type": "boolean", +func modules(ms ...string) []*Module { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + mods := make([]*Module, len(ms)) + for i, m := range ms { + var err error + mods[i], err = ParseModuleWithOpts(fmt.Sprintf("mod%d.rego", i), m, opts) + if err != nil { + panic(err) } - ] -}` + } + return mods +} -const allOfSchemaWithParentError = `{ - "$schema": "http://json-schema.org/draft-04/schema#", - "type": "string", - "title": "The b schema", - "description": "An explanation about the purpose of this instance.", - "allOf": [{ - "type": "integer", - }, - { - "type": "integer", - } - ] +func TestCompilerWithRecursiveSchemaAvoidRace(t *testing.T) { + + jsonSchema := `{ + "type": "object", + "properties": { + "aws": { + "type": "object", + "$ref": "#/$defs/example.pkg.providers.aws.AWS" + } + }, + "$defs": { + "example.pkg.providers.aws.AWS": { + "type": "object", + "properties": { + "iam": { + "type": "object", + "$ref": "#/$defs/example.pkg.providers.aws.iam.IAM" + }, + "sqs": { + "type": "object", + "$ref": "#/$defs/example.pkg.providers.aws.sqs.SQS" + } + } + }, + "example.pkg.providers.aws.iam.Document": { + "type": "object" + }, + "example.pkg.providers.aws.iam.IAM": { + "type": "object", + "properties": { + "policies": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/$defs/example.pkg.providers.aws.iam.Policy" + } + } + } + }, + "example.pkg.providers.aws.iam.Policy": { + "type": "object", + "properties": { + "builtin": { + "type": "object", + "properties": { + "value": { + "type": "boolean" + } + } + }, + "document": { + "type": "object", + "$ref": "#/$defs/example.pkg.providers.aws.iam.Document" + } + } + }, + "example.pkg.providers.aws.sqs.Queue": { + "type": "object", + "properties": { + "policies": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/$defs/example.pkg.providers.aws.iam.Policy" + } + } + } + }, + "example.pkg.providers.aws.sqs.SQS": { + "type": "object", + "properties": { + "queues": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/$defs/example.pkg.providers.aws.sqs.Queue" + } + } + } + } + } }` -const allOfSchemaWithUnevenArray = `{ - "type": "array", - "allOf": [{ - "items": [{ - "type": "integer" - }, - { - "type": "integer" - } - ] - }, - { - 
"items": [{ - "type": "integer" - }, - { - "type": "integer" - }, - { - "type": "string" - }] - } - ] -}` + exampleModule := `# METADATA +# schemas: +# - input: schema.input +package race.condition + +deny { + queue := input.aws.sqs.queues[_] + policy := queue.policies[_] + doc := json.unmarshal(policy.document.value) + statement = doc.Statement[_] + action := statement.Action[_] + action == "*" +} +` + + c := NewCompiler() + var schema interface{} + if err := json.Unmarshal([]byte(jsonSchema), &schema); err != nil { + t.Fatal(err) + } + schemaSet := NewSchemaSet() + schemaSet.Put(MustParseRef("schema.input"), schema) + c.WithSchemas(schemaSet) + + m := MustParseModuleWithOpts(exampleModule, ParserOptions{ProcessAnnotation: true}) + c.Compile(map[string]*Module{"testMod": m}) + if c.Failed() { + t.Fatal(c.Errors) + } +} diff --git a/ast/compilehelper.go b/ast/compilehelper.go index ca75dfabae..dd48884f9d 100644 --- a/ast/compilehelper.go +++ b/ast/compilehelper.go @@ -13,6 +13,7 @@ func CompileModules(modules map[string]string) (*Compiler, error) { // CompileOpts defines a set of options for the compiler. type CompileOpts struct { EnablePrintStatements bool + ParserOptions ParserOptions } // CompileModulesWithOpt takes a set of Rego modules represented as strings and @@ -24,7 +25,7 @@ func CompileModulesWithOpt(modules map[string]string, opts CompileOpts) (*Compil for f, module := range modules { var pm *Module var err error - if pm, err = ParseModule(f, module); err != nil { + if pm, err = ParseModuleWithOpts(f, module, opts.ParserOptions); err != nil { return nil, err } parsed[f] = pm diff --git a/ast/conflicts.go b/ast/conflicts.go index d1013ccedd..c2713ad576 100644 --- a/ast/conflicts.go +++ b/ast/conflicts.go @@ -27,7 +27,12 @@ func CheckPathConflicts(c *Compiler, exists func([]string) (bool, error)) Errors func checkDocumentConflicts(node *TreeNode, exists func([]string) (bool, error), path []string) Errors { - path = append(path, string(node.Key.(String))) + switch key := node.Key.(type) { + case String: + path = append(path, string(key)) + default: // other key types cannot conflict with data + return nil + } if len(node.Values) > 0 { s := strings.Join(path, "/") diff --git a/ast/doc.go b/ast/doc.go index 363660cf98..62b04e301e 100644 --- a/ast/doc.go +++ b/ast/doc.go @@ -8,29 +8,29 @@ // // Rego policies are typically defined in text files and then parsed and compiled by the policy engine at runtime. The parsing stage takes the text or string representation of the policy and converts it into an abstract syntax tree (AST) that consists of the types mentioned above. The AST is organized as follows: // -// Module -// | -// +--- Package (Reference) -// | -// +--- Imports -// | | -// | +--- Import (Term) -// | -// +--- Rules -// | -// +--- Rule -// | -// +--- Head -// | | -// | +--- Name (Variable) -// | | -// | +--- Key (Term) -// | | -// | +--- Value (Term) -// | -// +--- Body -// | -// +--- Expression (Term | Terms | Variable Declaration) +// Module +// | +// +--- Package (Reference) +// | +// +--- Imports +// | | +// | +--- Import (Term) +// | +// +--- Rules +// | +// +--- Rule +// | +// +--- Head +// | | +// | +--- Name (Variable) +// | | +// | +--- Key (Term) +// | | +// | +--- Value (Term) +// | +// +--- Body +// | +// +--- Expression (Term | Terms | Variable Declaration) // // At query time, the policy engine expects policies to have been compiled. The compilation stage takes one or more modules and compiles them into a format that the policy engine supports. 
package ast diff --git a/ast/env.go b/ast/env.go index 60006baafd..e3a4e44284 100644 --- a/ast/env.go +++ b/ast/env.go @@ -5,6 +5,8 @@ package ast import ( + "fmt" + "github.com/open-policy-agent/opa/types" "github.com/open-policy-agent/opa/util" ) @@ -59,6 +61,8 @@ func (env *TypeEnv) Get(x interface{}) types.Type { return types.NewArray(static, dynamic) + case *lazyObj: + return env.Get(x.force()) case *object: static := []*types.StaticProperty{} var dynamic *types.DynamicProperty @@ -195,9 +199,17 @@ func (env *TypeEnv) getRefRecExtent(node *typeTreeNode) types.Type { child := v.(*typeTreeNode) tpe := env.getRefRecExtent(child) - // TODO(tsandall): handle non-string keys? - if s, ok := key.(String); ok { - children = append(children, types.NewStaticProperty(string(s), tpe)) + + // NOTE(sr): Converting to Golang-native types here is an extension of what we did + // before -- only supporting strings. But since we cannot differentiate sets and arrays + // that way, we could reconsider. + switch key.(type) { + case String, Number, Boolean: // skip anything else + propKey, err := JSON(key) + if err != nil { + panic(fmt.Errorf("unreachable, ValueToInterface: %w", err)) + } + children = append(children, types.NewStaticProperty(propKey, tpe)) } return false }) diff --git a/ast/errors.go b/ast/errors.go index 11348b3d7a..066dfcdd68 100644 --- a/ast/errors.go +++ b/ast/errors.go @@ -121,12 +121,3 @@ func NewError(code string, loc *Location, f string, a ...interface{}) *Error { Message: fmt.Sprintf(f, a...), } } - -var ( - errPartialRuleAssignOperator = fmt.Errorf("partial rules must use = operator (not := operator)") - errFunctionAssignOperator = fmt.Errorf("functions must use = operator (not := operator)") -) - -func errTermAssignOperator(x interface{}) error { - return fmt.Errorf("cannot assign to %v", TypeName(x)) -} diff --git a/ast/fuzz.go b/ast/fuzz.go deleted file mode 100644 index 6ff7e35a8e..0000000000 --- a/ast/fuzz.go +++ /dev/null @@ -1,16 +0,0 @@ -// +build gofuzz - -package ast - -func Fuzz(data []byte) int { - - str := string(data) - _, _, err := ParseStatements("", str) - - if err == nil { - CompileModules(map[string]string{"": str}) - return 1 - } - - return 0 -} diff --git a/ast/fuzz_test.go b/ast/fuzz_test.go index acc0c6d2ba..c129b4ac36 100644 --- a/ast/fuzz_test.go +++ b/ast/fuzz_test.go @@ -5,16 +5,40 @@ //go:build go1.18 // +build go1.18 +// nolint package ast -import "testing" +import ( + "testing" -func FuzzParseStatementsAndCompileModules(f *testing.F) { + "github.com/open-policy-agent/opa/test/cases" +) + +var testcases = cases.MustLoad("../test/cases/testdata").Sorted().Cases + +func FuzzCompileModules(f *testing.F) { + for _, tc := range testcases { + for _, mod := range tc.Modules { + f.Add(mod) + } + } f.Fuzz(func(t *testing.T, input string) { - t.Parallel() // seed corpus tests can run in parallel - _, _, err := ParseStatements("", input) - if err != nil { - CompileModules(map[string]string{"": input}) + t.Parallel() + CompileModules(map[string]string{"": input}) + }) +} + +func FuzzCompileModulesWithPrintAndAllFutureKWs(f *testing.F) { + for _, tc := range testcases { + for _, mod := range tc.Modules { + f.Add(mod) } + } + f.Fuzz(func(t *testing.T, input string) { + t.Parallel() + CompileModulesWithOpt(map[string]string{"": input}, CompileOpts{ + EnablePrintStatements: true, + ParserOptions: ParserOptions{AllFutureKeywords: true}, + }) }) } diff --git a/ast/index.go b/ast/index.go index bcbb5c1765..ed7a2be26b 100644 --- a/ast/index.go +++ b/ast/index.go @@ -32,15 
+32,16 @@ type RuleIndex interface { // IndexResult contains the result of an index lookup. type IndexResult struct { - Kind DocKind - Rules []*Rule - Else map[*Rule][]*Rule - Default *Rule - EarlyExit bool + Kind RuleKind + Rules []*Rule + Else map[*Rule][]*Rule + Default *Rule + EarlyExit bool + OnlyGroundRefs bool } // NewIndexResult returns a new IndexResult object. -func NewIndexResult(kind DocKind) *IndexResult { +func NewIndexResult(kind RuleKind) *IndexResult { return &IndexResult{ Kind: kind, Else: map[*Rule][]*Rule{}, @@ -53,18 +54,20 @@ func (ir *IndexResult) Empty() bool { } type baseDocEqIndex struct { - skipIndexing Set - isVirtual func(Ref) bool - root *trieNode - defaultRule *Rule - kind DocKind + skipIndexing Set + isVirtual func(Ref) bool + root *trieNode + defaultRule *Rule + kind RuleKind + onlyGroundRefs bool } func newBaseDocEqIndex(isVirtual func(Ref) bool) *baseDocEqIndex { return &baseDocEqIndex{ - skipIndexing: NewSet(NewTerm(InternalPrint.Ref())), - isVirtual: isVirtual, - root: newTrieNodeImpl(), + skipIndexing: NewSet(NewTerm(InternalPrint.Ref())), + isVirtual: isVirtual, + root: newTrieNodeImpl(), + onlyGroundRefs: true, } } @@ -73,7 +76,7 @@ func (i *baseDocEqIndex) Build(rules []*Rule) bool { return false } - i.kind = rules[0].Head.DocKind() + i.kind = rules[0].Head.RuleKind() indices := newrefindices(i.isVirtual) // build indices for each rule. @@ -83,6 +86,9 @@ func (i *baseDocEqIndex) Build(rules []*Rule) bool { i.defaultRule = rule return false } + if i.onlyGroundRefs { + i.onlyGroundRefs = rule.Head.Reference.IsGround() + } var skip bool for _, expr := range rule.Body { if op := expr.OperatorTerm(); op != nil && i.skipIndexing.Contains(op) { @@ -134,6 +140,7 @@ func (i *baseDocEqIndex) Lookup(resolver ValueResolver) (*IndexResult, error) { result := NewIndexResult(i.kind) result.Default = i.defaultRule + result.OnlyGroundRefs = i.onlyGroundRefs result.Rules = make([]*Rule, 0, len(tr.ordering)) for _, pos := range tr.ordering { @@ -166,6 +173,7 @@ func (i *baseDocEqIndex) AllRules(resolver ValueResolver) (*IndexResult, error) result := NewIndexResult(i.kind) result.Default = i.defaultRule + result.OnlyGroundRefs = i.onlyGroundRefs result.Rules = make([]*Rule, 0, len(tr.ordering)) for _, pos := range tr.ordering { @@ -246,15 +254,20 @@ func (i *refindices) Update(rule *Rule, expr *Expr) { op := expr.Operator() - if op.Equal(Equality.Ref()) { + switch { + case op.Equal(Equality.Ref()): i.updateEq(rule, expr) - } else if op.Equal(Equal.Ref()) && len(expr.Operands()) == 2 { + + case op.Equal(Equal.Ref()) && len(expr.Operands()) == 2: // NOTE(tsandall): if equal() is called with more than two arguments the // output value is being captured in which case the indexer cannot // exclude the rule if the equal() call would return false (because the // false value must still be produced.) i.updateEq(rule, expr) - } else if op.Equal(GlobMatch.Ref()) { + + case op.Equal(GlobMatch.Ref()) && len(expr.Operands()) == 3: + // NOTE(sr): Same as with equal() above -- 4 operands means the output + // of `glob.match` is captured and the rule can thus not be excluded. 
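	// For example, `glob.match("/a/*/c", ["/"], x)` (3 operands) is
	// indexable, while `glob.match("/a/*/c", ["/"], x, out)` captures the
	// result, so the call must run even when it yields false.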
i.updateGlobMatch(rule, expr) } } @@ -430,7 +443,7 @@ type trieNode struct { next *trieNode any *trieNode undefined *trieNode - scalars map[Value]*trieNode + scalars *util.HashMap array *trieNode rules []*ruleNode } @@ -453,11 +466,14 @@ func (node *trieNode) String() string { if node.array != nil { flags = append(flags, fmt.Sprintf("array:%p", node.array)) } - if len(node.scalars) > 0 { - buf := make([]string, 0, len(node.scalars)) - for k, v := range node.scalars { - buf = append(buf, fmt.Sprintf("scalar(%v):%p", k, v)) - } + if node.scalars.Len() > 0 { + buf := make([]string, 0, node.scalars.Len()) + node.scalars.Iter(func(k, v util.T) bool { + key := k.(Value) + val := v.(*trieNode) + buf = append(buf, fmt.Sprintf("scalar(%v):%p", key, val)) + return false + }) sort.Strings(buf) flags = append(flags, strings.Join(buf, " ")) } @@ -493,7 +509,7 @@ type ruleNode struct { func newTrieNodeImpl() *trieNode { return &trieNode{ - scalars: map[Value]*trieNode{}, + scalars: util.NewHashMap(valueEq, valueHash), } } @@ -508,9 +524,13 @@ func (node *trieNode) Do(walker trieWalker) { if node.undefined != nil { node.undefined.Do(next) } - for _, child := range node.scalars { + + node.scalars.Iter(func(_, v util.T) bool { + child := v.(*trieNode) child.Do(next) - } + return false + }) + if node.array != nil { node.array.Do(next) } @@ -567,12 +587,12 @@ func (node *trieNode) insertValue(value Value) *trieNode { } return node.any case Null, Boolean, Number, String: - child, ok := node.scalars[value] + child, ok := node.scalars.Get(value) if !ok { child = newTrieNodeImpl() - node.scalars[value] = child + node.scalars.Put(value, child) } - return child + return child.(*trieNode) case *Array: if node.array == nil { node.array = newTrieNodeImpl() @@ -596,12 +616,12 @@ func (node *trieNode) insertArray(arr *Array) *trieNode { } return node.any.insertArray(arr.Slice(1, -1)) case Null, Boolean, Number, String: - child, ok := node.scalars[head] + child, ok := node.scalars.Get(head) if !ok { child = newTrieNodeImpl() - node.scalars[head] = child + node.scalars.Put(head, child) } - return child.insertArray(arr.Slice(1, -1)) + return child.(*trieNode).insertArray(arr.Slice(1, -1)) } panic("illegal value") @@ -662,11 +682,11 @@ func (node *trieNode) traverseValue(resolver ValueResolver, tr *trieTraversalRes return node.array.traverseArray(resolver, tr, value) case Null, Boolean, Number, String: - child, ok := node.scalars[value] + child, ok := node.scalars.Get(value) if !ok { return nil } - return child.Traverse(resolver, tr) + return child.(*trieNode).Traverse(resolver, tr) } return nil @@ -691,12 +711,11 @@ func (node *trieNode) traverseArray(resolver ValueResolver, tr *trieTraversalRes } } - child, ok := node.scalars[head] + child, ok := node.scalars.Get(head) if !ok { return nil } - - return child.traverseArray(resolver, tr, arr.Slice(1, -1)) + return child.(*trieNode).traverseArray(resolver, tr, arr.Slice(1, -1)) } func (node *trieNode) traverseUnknown(resolver ValueResolver, tr *trieTraversalResult) error { @@ -721,13 +740,16 @@ func (node *trieNode) traverseUnknown(resolver ValueResolver, tr *trieTraversalR return err } - for _, child := range node.scalars { - if err := child.traverseUnknown(resolver, tr); err != nil { - return err + var iterErr error + node.scalars.Iter(func(_, v util.T) bool { + child := v.(*trieNode) + if iterErr = child.traverseUnknown(resolver, tr); iterErr != nil { + return true } - } + return false + }) - return nil + return iterErr } // If term `a` is one of the function's operands, 
we store a Ref: `args[0]` diff --git a/ast/index_test.go b/ast/index_test.go index 01e92283fa..5026c2c16f 100644 --- a/ast/index_test.go +++ b/ast/index_test.go @@ -41,6 +41,48 @@ func (r testResolver) Resolve(ref Ref) (Value, error) { } func TestBaseDocEqIndexing(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + + expectOnlyGroundRefs := func(exp bool) func(*testing.T, *IndexResult) { + return func(t *testing.T, res *IndexResult) { + t.Helper() + if act := res.OnlyGroundRefs; exp != act { + t.Errorf("OnlyGroundRefs: expected %v, got %v", exp, act) + } + } + } + + everyMod := MustParseModuleWithOpts(`package test + p { every _ in [] { input.a = 1 } }`, opts) + + // NOTE(sr): This looks a bit silly; but it's what + // + // every x in input.a { input.x == x } + // + // will get rewritten to -- so to assert that the domain of 'every' expressions + // get respected in the rule indexing, we'll need to provide this "pseudo-compiled" + // module source here. + everyModWithDomain := MustParseModuleWithOpts(`package test + p { + __local0__ = input.a + every x in __local0__ { input.x = x } + } { + input.b = 1 + }`, opts) + + refMod := MustParseModuleWithOpts(`package test + + ref.single.value.ground = x if x := input.x + + ref.single.value.key[k] = v if { k := input.k; v := input.v } + + ref.multi.value.ground contains x if x := input.x + + ref.multiple.single.value.ground = x if x := input.x + ref.multiple.single.value[y] = x if { x := input.x; y := index.y } + + # ref.multi.value.key[k] contains v if { k := input.k; v := input.v } # not supported yet + `, opts) module := MustParseModule(` package test @@ -52,6 +94,7 @@ func TestBaseDocEqIndexing(t *testing.T) { input.x = 3 input.y = 4 } + scalars { input.x = 0 @@ -188,14 +231,16 @@ func TestBaseDocEqIndexing(t *testing.T) { `) tests := []struct { - note string - module *Module - ruleset string - input string - unknowns []string - args []Value - expectedRS interface{} - expectedDR *Rule + note string + module *Module + ruleset string + ruleRef Ref + input string + unknowns []string + args []Value + expectedRS interface{} + expectedDR *Rule + checkResult func(*testing.T, *IndexResult) }{ { note: "exact match", @@ -204,6 +249,7 @@ func TestBaseDocEqIndexing(t *testing.T) { expectedRS: []string{ `exact { input.x = 3; input.y = 4 }`, }, + checkResult: expectOnlyGroundRefs(true), // covering base case }, { note: "undefined match", @@ -461,6 +507,17 @@ func TestBaseDocEqIndexing(t *testing.T) { input: `{"x": [0]}`, expectedRS: []string{}, }, + { + note: "glob.match: do not index captured output", + module: MustParseModule(`package test + p { x = input.x; glob.match("/a/*/c", ["/"], x, false) } + `), + ruleset: "p", + input: `{"x": "wrong"}`, + expectedRS: []string{ + `p { x = input.x; glob.match("/a/*/c", ["/"], x, false) }`, + }, + }, { note: "functions: args match", module: MustParseModule(`package test @@ -600,6 +657,59 @@ func TestBaseDocEqIndexing(t *testing.T) { `f(x) = y { equal(x, 1, z); y = z }`, }, }, + { + note: "every: do not index body", + module: everyMod, + ruleset: "p", + input: `{"a": 2}`, + expectedRS: RuleSet(everyMod.Rules), + }, + { + note: "every: index domain", + module: everyModWithDomain, + ruleset: "p", + input: `{"a": [1]}`, + expectedRS: RuleSet([]*Rule{everyModWithDomain.Rules[0]}), + }, + { + note: "ref: single value, ground ref", + module: refMod, + ruleRef: MustParseRef("ref.single.value.ground"), + input: `{"x": 1}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[0]}), + checkResult: 
expectOnlyGroundRefs(true), + }, + { + note: "ref: single value, ground ref and non-ground ref", + module: refMod, + ruleRef: MustParseRef("ref.multiple.single.value"), + input: `{"x": 1, "y": "Y"}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[3], refMod.Rules[4]}), + checkResult: expectOnlyGroundRefs(false), + }, + { + note: "ref: single value, var in ref", + module: refMod, + ruleRef: MustParseRef("ref.single.value.key[k]"), + input: `{"k": 1, "v": 2}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[1]}), + checkResult: expectOnlyGroundRefs(false), + }, + { + note: "ref: multi value, ground ref", + module: refMod, + ruleRef: MustParseRef("ref.multi.value.ground"), + input: `{"x": 1}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[2]}), + checkResult: expectOnlyGroundRefs(true), + }, + // { + // note: "ref: multi value, var in ref", + // module: refMod, + // ruleRef: MustParseRef("ref.multi.value.key[k]"), + // input: `{"k": 1, "v": 2}`, + // expectedRS: RuleSet([]*Rule{refMod.Rules[3]}), + // }, } for _, tc := range tests { @@ -610,10 +720,19 @@ func TestBaseDocEqIndexing(t *testing.T) { } rules := []*Rule{} for _, rule := range module.Rules { - if rule.Head.Name == Var(tc.ruleset) { - rules = append(rules, rule) + if tc.ruleRef == nil { + if rule.Head.Name == Var(tc.ruleset) { + rules = append(rules, rule) + } + } else { + if rule.Head.Ref().HasPrefix(tc.ruleRef) { + rules = append(rules, rule) + } } } + if len(rules) == 0 { + t.Fatal("selected empty ruleset") + } var input *Term if tc.input != "" { @@ -630,7 +749,7 @@ func TestBaseDocEqIndexing(t *testing.T) { case RuleSet: expectedRS = e default: - panic("Unexpected test case expected value") + panic("Unexpected test case: expected value") } index := newBaseDocEqIndex(func(Ref) bool { @@ -655,6 +774,10 @@ func TestBaseDocEqIndexing(t *testing.T) { t.Fatalf("Unexpected error during index lookup: %v", err) } + if tc.checkResult != nil { + tc.checkResult(t, result) + } + if !NewRuleSet(result.Rules...).Equal(expectedRS) { t.Fatalf("Expected ruleset %v but got: %v", expectedRS, result.Rules) } diff --git a/ast/internal/scanner/scanner.go b/ast/internal/scanner/scanner.go index 9402749e34..7174d092b9 100644 --- a/ast/internal/scanner/scanner.go +++ b/ast/internal/scanner/scanner.go @@ -7,7 +7,6 @@ package scanner import ( "fmt" "io" - "io/ioutil" "unicode" "unicode/utf8" @@ -47,7 +46,7 @@ type Position struct { // through the source code provided by the io.Reader. func New(r io.Reader) (*Scanner, error) { - bs, err := ioutil.ReadAll(r) + bs, err := io.ReadAll(r) if err != nil { return nil, err } @@ -96,6 +95,11 @@ func (s *Scanner) Keyword(lit string) tokens.Token { // AddKeyword adds a string -> token mapping to this Scanner instance. func (s *Scanner) AddKeyword(kw string, tok tokens.Token) { s.keywords[kw] = tok + + switch tok { + case tokens.Every: // importing 'every' means also importing 'in' + s.keywords["in"] = tokens.In + } } // WithKeywords returns a new copy of the Scanner struct `s`, with the set @@ -112,6 +116,21 @@ func (s *Scanner) WithKeywords(kws map[string]tokens.Token) *Scanner { return &cpy } +// WithoutKeywords returns a new copy of the Scanner struct `s`, with the +// set of known keywords being that of `s` with `kws` removed. +// The previously known keywords are returned for a convenient reset. 
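// A usage sketch, mirroring presentParser in ast/parser.go (where the
// futureKeywords map lives): drop the future keywords, scan a fragment in
// which they are plain identifiers, then restore:
//
//	cpy, prev := s.WithoutKeywords(futureKeywords)
//	// ... scan with cpy; "in", "every", etc. are ordinary idents ...
//	restored := cpy.WithKeywords(prev)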
+func (s *Scanner) WithoutKeywords(kws map[string]tokens.Token) (*Scanner, map[string]tokens.Token) { + cpy := *s + kw := s.keywords + cpy.keywords = make(map[string]tokens.Token, len(s.keywords)-len(kws)) + for kw, tok := range s.keywords { + if _, ok := kws[kw]; !ok { + cpy.AddKeyword(kw, tok) + } + } + return &cpy, kw +} + // Scan will increment the scanners position in the source // code until the next token is found. The token, starting position // of the token, string literal, and any errors encountered are diff --git a/ast/internal/tokens/tokens.go b/ast/internal/tokens/tokens.go index ce053c18a9..29bf971d34 100644 --- a/ast/internal/tokens/tokens.go +++ b/ast/internal/tokens/tokens.go @@ -65,6 +65,10 @@ const ( Lte Dot Semicolon + + Every + Contains + If ) var strings = [...]string{ @@ -112,6 +116,9 @@ var strings = [...]string{ Lte: "lte", Dot: ".", Semicolon: ";", + Every: "every", + Contains: "contains", + If: "if", } var keywords = map[string]Token{ @@ -136,3 +143,9 @@ func Keywords() map[string]Token { } return cpy } + +// IsKeyword returns if a token is a keyword +func IsKeyword(tok Token) bool { + _, ok := keywords[strings[tok]] + return ok +} diff --git a/ast/location/location.go b/ast/location/location.go index 13ae6e35d7..9ef1e6dfd4 100644 --- a/ast/location/location.go +++ b/ast/location/location.go @@ -3,9 +3,8 @@ package location import ( "bytes" + "errors" "fmt" - - "github.com/pkg/errors" ) // Location records a position in source code @@ -39,7 +38,7 @@ func (loc *Location) Errorf(f string, a ...interface{}) error { // Wrapf returns a new error value that wraps an existing error with a message formatted // to include the location info (e.g., line, column, filename, etc.) func (loc *Location) Wrapf(err error, f string, a ...interface{}) error { - return errors.Wrap(err, loc.Format(f, a...)) + return fmt.Errorf(loc.Format(f, a...)+": %w", err) } // Format returns a formatted string prefixed with the location information. diff --git a/ast/marshal.go b/ast/marshal.go new file mode 100644 index 0000000000..891945db8b --- /dev/null +++ b/ast/marshal.go @@ -0,0 +1,7 @@ +package ast + +// customJSON is an interface that can be implemented by AST nodes that +// allows the parser to set options for JSON operations on that node. 
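// A sketch of a conforming node (the concrete implementation is assumed
// here, but the tests below show nodes storing the options in a
// jsonOptions field and consulting them at MarshalJSON time):
//
//	func (t *Term) setJSONOptions(opts JSONOptions) {
//		t.jsonOptions = opts
//	}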
+type customJSON interface { + setJSONOptions(JSONOptions) +} diff --git a/ast/marshal_test.go b/ast/marshal_test.go new file mode 100644 index 0000000000..38b5216279 --- /dev/null +++ b/ast/marshal_test.go @@ -0,0 +1,972 @@ +package ast + +import ( + "encoding/json" + "testing" + + "github.com/open-policy-agent/opa/util" +) + +func TestTerm_MarshalJSON(t *testing.T) { + testCases := map[string]struct { + Term *Term + ExpectedJSON string + }{ + "base case": { + Term: func() *Term { + v, _ := InterfaceToValue("example") + return &Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + } + }(), + ExpectedJSON: `{"type":"string","value":"example"}`, + }, + "location excluded": { + Term: func() *Term { + v, _ := InterfaceToValue("example") + return &Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Term: false, + }, + }, + }, + } + }(), + ExpectedJSON: `{"type":"string","value":"example"}`, + }, + "location included": { + Term: func() *Term { + v, _ := InterfaceToValue("example") + return &Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Term: true, + }, + }, + }, + } + }(), + ExpectedJSON: `{"location":{"file":"example.rego","row":1,"col":2},"type":"string","value":"example"}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Term) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestTerm_UnmarshalJSON(t *testing.T) { + testCases := map[string]struct { + JSON string + ExpectedTerm *Term + }{ + "base case": { + JSON: `{"type":"string","value":"example"}`, + ExpectedTerm: func() *Term { + v, _ := InterfaceToValue("example") + return &Term{ + Value: v, + } + }(), + }, + "location case": { + JSON: `{"location":{"file":"example.rego","row":1,"col":2},"type":"string","value":"example"}`, + ExpectedTerm: func() *Term { + v, _ := InterfaceToValue("example") + return &Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + } + }(), + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + var term Term + err := json.Unmarshal([]byte(data.JSON), &term) + if err != nil { + t.Fatal(err) + } + + if !term.Equal(data.ExpectedTerm) { + t.Fatalf("expected:\n%#v got\n%#v", data.ExpectedTerm, term) + } + if data.ExpectedTerm.Location != nil { + if !term.Location.Equal(data.ExpectedTerm.Location) { + t.Fatalf("expected location:\n%#v got\n%#v", data.ExpectedTerm, term) + } + } + }) + } +} + +func TestPackage_MarshalJSON(t *testing.T) { + testCases := map[string]struct { + Package *Package + ExpectedJSON string + }{ + "base case": { + Package: &Package{ + Path: EmptyRef(), + }, + ExpectedJSON: `{"path":[]}`, + }, + "location excluded": { + Package: &Package{ + Path: EmptyRef(), + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Package: false, + }, + }, + }, + }, + ExpectedJSON: `{"path":[]}`, + }, + "location included": { + Package: &Package{ + Path: EmptyRef(), + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Package: true, + }, + }, + 
}, + }, + ExpectedJSON: `{"location":{"file":"example.rego","row":1,"col":2},"path":[]}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Package) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +// TODO: Comment has inconsistent JSON field names starting with an upper case letter. Comment Location is +// also always included for legacy reasons +func TestComment_MarshalJSON(t *testing.T) { + testCases := map[string]struct { + Comment *Comment + ExpectedJSON string + }{ + "base case": { + Comment: &Comment{ + Text: []byte("comment"), + }, + ExpectedJSON: `{"Text":"Y29tbWVudA==","Location":null}`, + }, + "location excluded, still included for legacy reasons": { + Comment: &Comment{ + Text: []byte("comment"), + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Comment: false, // ignored + }, + }, + }, + }, + ExpectedJSON: `{"Text":"Y29tbWVudA==","Location":{"file":"example.rego","row":1,"col":2}}`, + }, + "location included": { + Comment: &Comment{ + Text: []byte("comment"), + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Comment: true, // ignored + }, + }, + }, + }, + ExpectedJSON: `{"Text":"Y29tbWVudA==","Location":{"file":"example.rego","row":1,"col":2}}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Comment) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestImport_MarshalJSON(t *testing.T) { + testCases := map[string]struct { + Import *Import + ExpectedJSON string + }{ + "base case": { + Import: func() *Import { + v, _ := InterfaceToValue("example") + term := Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + } + return &Import{Path: &term} + }(), + ExpectedJSON: `{"path":{"type":"string","value":"example"}}`, + }, + "location excluded": { + Import: func() *Import { + v, _ := InterfaceToValue("example") + term := Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + } + return &Import{ + Path: &term, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Import: false, + }, + }, + }, + } + }(), + ExpectedJSON: `{"path":{"type":"string","value":"example"}}`, + }, + "location included": { + Import: func() *Import { + v, _ := InterfaceToValue("example") + term := Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + } + return &Import{ + Path: &term, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Import: true, + }, + }, + }, + } + }(), + ExpectedJSON: `{"location":{"file":"example.rego","row":1,"col":2},"path":{"type":"string","value":"example"}}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Import) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestRule_MarshalJSON(t *testing.T) { + rawModule := ` + package foo + + # comment + + allow 
{ true } + ` + + module, err := ParseModuleWithOpts("example.rego", rawModule, ParserOptions{}) + if err != nil { + t.Fatal(err) + } + + rule := module.Rules[0] + + testCases := map[string]struct { + Rule *Rule + ExpectedJSON string + }{ + "base case": { + Rule: rule.Copy(), + ExpectedJSON: `{"body":[{"index":0,"terms":{"type":"boolean","value":true}}],"head":{"name":"allow","value":{"type":"boolean","value":true},"ref":[{"type":"var","value":"allow"}]}}`, + }, + "location excluded": { + Rule: func() *Rule { + r := rule.Copy() + r.jsonOptions = JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Rule: false, + }, + }, + } + return r + }(), + ExpectedJSON: `{"body":[{"index":0,"terms":{"type":"boolean","value":true}}],"head":{"name":"allow","value":{"type":"boolean","value":true},"ref":[{"type":"var","value":"allow"}]}}`, + }, + "location included": { + Rule: func() *Rule { + r := rule.Copy() + r.jsonOptions = JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Rule: true, + }, + }, + } + return r + }(), + ExpectedJSON: `{"body":[{"index":0,"terms":{"type":"boolean","value":true}}],"head":{"name":"allow","value":{"type":"boolean","value":true},"ref":[{"type":"var","value":"allow"}]},"location":{"file":"example.rego","row":6,"col":2}}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Rule) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestHead_MarshalJSON(t *testing.T) { + rawModule := ` + package foo + + # comment + + allow { true } + ` + + module, err := ParseModuleWithOpts("example.rego", rawModule, ParserOptions{}) + if err != nil { + t.Fatal(err) + } + + head := module.Rules[0].Head + + testCases := map[string]struct { + Head *Head + ExpectedJSON string + }{ + "base case": { + Head: head.Copy(), + ExpectedJSON: `{"name":"allow","value":{"type":"boolean","value":true},"ref":[{"type":"var","value":"allow"}]}`, + }, + "location excluded": { + Head: func() *Head { + h := head.Copy() + h.jsonOptions = JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Head: false, + }, + }, + } + + return h + }(), + ExpectedJSON: `{"name":"allow","value":{"type":"boolean","value":true},"ref":[{"type":"var","value":"allow"}]}`, + }, + "location included": { + Head: func() *Head { + h := head.Copy() + h.jsonOptions = JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Head: true, + }, + }, + } + return h + }(), + ExpectedJSON: `{"name":"allow","value":{"type":"boolean","value":true},"ref":[{"type":"var","value":"allow"}],"location":{"file":"example.rego","row":6,"col":2}}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Head) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestExpr_MarshalJSON(t *testing.T) { + rawModule := ` + package foo + + # comment + + allow { true } + ` + + module, err := ParseModuleWithOpts("example.rego", rawModule, ParserOptions{}) + if err != nil { + t.Fatal(err) + } + + expr := module.Rules[0].Body[0] + + testCases := map[string]struct { + Expr *Expr + ExpectedJSON string + }{ + "base case": { + Expr: expr.Copy(), + ExpectedJSON: `{"index":0,"terms":{"type":"boolean","value":true}}`, + }, + "location excluded": { + Expr: func() *Expr { + e 
:= expr.Copy() + e.jsonOptions = JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Expr: false, + }, + }, + } + + return e + }(), + ExpectedJSON: `{"index":0,"terms":{"type":"boolean","value":true}}`, + }, + "location included": { + Expr: func() *Expr { + e := expr.Copy() + e.jsonOptions = JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Expr: true, + }, + }, + } + return e + }(), + ExpectedJSON: `{"index":0,"location":{"file":"example.rego","row":6,"col":10},"terms":{"type":"boolean","value":true}}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Expr) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestExpr_UnmarshalJSON(t *testing.T) { + rawModule := ` + package foo + + # comment + + allow { true } + ` + + module, err := ParseModuleWithOpts("example.rego", rawModule, ParserOptions{}) + if err != nil { + t.Fatal(err) + } + + expr := module.Rules[0].Body[0] + // text is not marshalled to JSON so we just drop it in our examples + expr.Location.Text = nil + + testCases := map[string]struct { + JSON string + ExpectedExpr *Expr + }{ + "base case": { + JSON: `{"index":0,"terms":{"type":"boolean","value":true}}`, + ExpectedExpr: func() *Expr { + e := expr.Copy() + e.Location = nil + return e + }(), + }, + "location case": { + JSON: `{"index":0,"location":{"file":"example.rego","row":6,"col":10},"terms":{"type":"boolean","value":true}}`, + ExpectedExpr: expr, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + var expr Expr + err := json.Unmarshal([]byte(data.JSON), &expr) + if err != nil { + t.Fatal(err) + } + + if !expr.Equal(data.ExpectedExpr) { + t.Fatalf("expected:\n%#v got\n%#v", data.ExpectedExpr, expr) + } + if data.ExpectedExpr.Location != nil { + if !expr.Location.Equal(data.ExpectedExpr.Location) { + t.Fatalf("expected location:\n%#v got\n%#v", data.ExpectedExpr.Location, expr.Location) + } + } + }) + } +} + +func TestSomeDecl_MarshalJSON(t *testing.T) { + v, _ := InterfaceToValue("example") + term := &Term{ + Value: v, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + } + + testCases := map[string]struct { + SomeDecl *SomeDecl + ExpectedJSON string + }{ + "base case": { + SomeDecl: &SomeDecl{ + Symbols: []*Term{term}, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + }, + ExpectedJSON: `{"symbols":[{"type":"string","value":"example"}]}`, + }, + "location excluded": { + SomeDecl: &SomeDecl{ + Symbols: []*Term{term}, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{MarshalOptions: JSONMarshalOptions{IncludeLocation: NodeToggle{SomeDecl: false}}}, + }, + ExpectedJSON: `{"symbols":[{"type":"string","value":"example"}]}`, + }, + "location included": { + SomeDecl: &SomeDecl{ + Symbols: []*Term{term}, + Location: NewLocation([]byte{}, "example.rego", 1, 2), + jsonOptions: JSONOptions{MarshalOptions: JSONMarshalOptions{IncludeLocation: NodeToggle{SomeDecl: true}}}, + }, + ExpectedJSON: `{"location":{"file":"example.rego","row":1,"col":2},"symbols":[{"type":"string","value":"example"}]}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.SomeDecl) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestEvery_MarshalJSON(t 
*testing.T) { + + rawModule := ` +package foo + +import future.keywords.every + +allow { + every e in [1,2,3] { + e == 1 + } +} +` + + module, err := ParseModuleWithOpts("example.rego", rawModule, ParserOptions{}) + if err != nil { + t.Fatal(err) + } + + every, ok := module.Rules[0].Body[0].Terms.(*Every) + if !ok { + t.Fatal("expected every term") + } + + testCases := map[string]struct { + Every *Every + ExpectedJSON string + }{ + "base case": { + Every: every, + ExpectedJSON: `{"body":[{"index":0,"terms":[{"type":"ref","value":[{"type":"var","value":"equal"}]},{"type":"var","value":"e"},{"type":"number","value":1}]}],"domain":{"type":"array","value":[{"type":"number","value":1},{"type":"number","value":2},{"type":"number","value":3}]},"key":null,"value":{"type":"var","value":"e"}}`, + }, + "location excluded": { + Every: func() *Every { + e := every.Copy() + e.jsonOptions = JSONOptions{MarshalOptions: JSONMarshalOptions{IncludeLocation: NodeToggle{Every: false}}} + return e + }(), + ExpectedJSON: `{"body":[{"index":0,"terms":[{"type":"ref","value":[{"type":"var","value":"equal"}]},{"type":"var","value":"e"},{"type":"number","value":1}]}],"domain":{"type":"array","value":[{"type":"number","value":1},{"type":"number","value":2},{"type":"number","value":3}]},"key":null,"value":{"type":"var","value":"e"}}`, + }, + "location included": { + Every: func() *Every { + e := every.Copy() + e.jsonOptions = JSONOptions{MarshalOptions: JSONMarshalOptions{IncludeLocation: NodeToggle{Every: true}}} + return e + }(), + ExpectedJSON: `{"body":[{"index":0,"terms":[{"type":"ref","value":[{"type":"var","value":"equal"}]},{"type":"var","value":"e"},{"type":"number","value":1}]}],"domain":{"type":"array","value":[{"type":"number","value":1},{"type":"number","value":2},{"type":"number","value":3}]},"key":null,"location":{"file":"example.rego","row":7,"col":2},"value":{"type":"var","value":"e"}}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Every) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestWith_MarshalJSON(t *testing.T) { + + rawModule := ` +package foo + +a {input} + +b { + a with input as 1 +} +` + + module, err := ParseModuleWithOpts("example.rego", rawModule, ParserOptions{}) + if err != nil { + t.Fatal(err) + } + + with := module.Rules[1].Body[0].With[0] + + testCases := map[string]struct { + With *With + ExpectedJSON string + }{ + "base case": { + With: with, + ExpectedJSON: `{"target":{"type":"ref","value":[{"type":"var","value":"input"}]},"value":{"type":"number","value":1}}`, + }, + "location excluded": { + With: func() *With { + w := with.Copy() + w.jsonOptions = JSONOptions{MarshalOptions: JSONMarshalOptions{IncludeLocation: NodeToggle{With: false}}} + return w + }(), + ExpectedJSON: `{"target":{"type":"ref","value":[{"type":"var","value":"input"}]},"value":{"type":"number","value":1}}`, + }, + "location included": { + With: func() *With { + w := with.Copy() + w.jsonOptions = JSONOptions{MarshalOptions: JSONMarshalOptions{IncludeLocation: NodeToggle{With: true}}} + return w + }(), + ExpectedJSON: `{"location":{"file":"example.rego","row":7,"col":4},"target":{"type":"ref","value":[{"type":"var","value":"input"}]},"value":{"type":"number","value":1}}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.With) + got := string(bs) + exp := data.ExpectedJSON + + if got != 
exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestAnnotations_MarshalJSON(t *testing.T) { + + testCases := map[string]struct { + Annotations *Annotations + ExpectedJSON string + }{ + "base case": { + Annotations: &Annotations{ + Scope: "rule", + Title: "My rule", + Entrypoint: true, + Organizations: []string{"org1"}, + Description: "My desc", + Custom: map[string]interface{}{ + "foo": "bar", + }, + Location: NewLocation([]byte{}, "example.rego", 1, 4), + }, + ExpectedJSON: `{"custom":{"foo":"bar"},"description":"My desc","entrypoint":true,"organizations":["org1"],"scope":"rule","title":"My rule"}`, + }, + "location excluded": { + Annotations: &Annotations{ + Scope: "rule", + Title: "My rule", + Entrypoint: true, + Organizations: []string{"org1"}, + Description: "My desc", + Custom: map[string]interface{}{ + "foo": "bar", + }, + Location: NewLocation([]byte{}, "example.rego", 1, 4), + + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{Annotations: false}, + }, + }, + }, + ExpectedJSON: `{"custom":{"foo":"bar"},"description":"My desc","entrypoint":true,"organizations":["org1"],"scope":"rule","title":"My rule"}`, + }, + "location included": { + Annotations: &Annotations{ + Scope: "rule", + Title: "My rule", + Entrypoint: true, + Organizations: []string{"org1"}, + Description: "My desc", + Custom: map[string]interface{}{ + "foo": "bar", + }, + Location: NewLocation([]byte{}, "example.rego", 1, 4), + + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{Annotations: true}, + }, + }, + }, + ExpectedJSON: `{"custom":{"foo":"bar"},"description":"My desc","entrypoint":true,"location":{"file":"example.rego","row":1,"col":4},"organizations":["org1"],"scope":"rule","title":"My rule"}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.Annotations) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestAnnotationsRef_MarshalJSON(t *testing.T) { + + testCases := map[string]struct { + AnnotationsRef *AnnotationsRef + ExpectedJSON string + }{ + "base case": { + AnnotationsRef: &AnnotationsRef{ + Path: []*Term{}, + // using an empty annotations object here since Annotations marshalling is tested separately + Annotations: &Annotations{}, + Location: NewLocation([]byte{}, "example.rego", 1, 4), + }, + ExpectedJSON: `{"annotations":{"scope":""},"path":[]}`, + }, + "location excluded": { + AnnotationsRef: &AnnotationsRef{ + Path: []*Term{}, + Annotations: &Annotations{}, + Location: NewLocation([]byte{}, "example.rego", 1, 4), + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{AnnotationsRef: false}, + }, + }, + }, + ExpectedJSON: `{"annotations":{"scope":""},"path":[]}`, + }, + "location included": { + AnnotationsRef: &AnnotationsRef{ + Path: []*Term{}, + Annotations: &Annotations{}, + Location: NewLocation([]byte{}, "example.rego", 1, 4), + + jsonOptions: JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{AnnotationsRef: true}, + }, + }, + }, + ExpectedJSON: `{"annotations":{"scope":""},"location":{"file":"example.rego","row":1,"col":4},"path":[]}`, + }, + } + + for name, data := range testCases { + t.Run(name, func(t *testing.T) { + bs := util.MustMarshalJSON(data.AnnotationsRef) + got := string(bs) + exp := data.ExpectedJSON + + if got != exp { + 
t.Fatalf("expected:\n%s got\n%s", exp, got) + } + }) + } +} + +func TestNewAnnotationsRef_JSONOptions(t *testing.T) { + tests := []struct { + note string + module string + expected []string + options ParserOptions + }{ + { + note: "all JSON marshaller options set to true", + module: `# METADATA +# title: pkg +# description: pkg +# organizations: +# - pkg +# related_resources: +# - https://pkg +# authors: +# - pkg +# schemas: +# - input.foo: {"type": "boolean"} +# custom: +# pkg: pkg +package test + +# METADATA +# scope: document +# title: doc +# description: doc +# organizations: +# - doc +# related_resources: +# - https://doc +# authors: +# - doc +# schemas: +# - input.bar: {"type": "integer"} +# custom: +# doc: doc + +# METADATA +# title: rule +# description: rule +# organizations: +# - rule +# related_resources: +# - https://rule +# authors: +# - rule +# schemas: +# - input.baz: {"type": "string"} +# custom: +# rule: rule +p = 1`, + options: ParserOptions{ + ProcessAnnotation: true, + JSONOptions: &JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Term: true, + Package: true, + Comment: true, + Import: true, + Rule: true, + Head: true, + Expr: true, + SomeDecl: true, + Every: true, + With: true, + Annotations: true, + AnnotationsRef: true, + }, + }, + }, + }, + expected: []string{ + `{"annotations":{"authors":[{"name":"pkg"}],"custom":{"pkg":"pkg"},"description":"pkg","location":{"file":"","row":1,"col":1},"organizations":["pkg"],"related_resources":[{"ref":"https://pkg"}],"schemas":[{"path":[{"type":"var","value":"input"},{"type":"string","value":"foo"}],"definition":{"type":"boolean"}}],"scope":"package","title":"pkg"},"location":{"file":"","row":14,"col":1},"path":[{"location":{"file":"","row":14,"col":9},"type":"var","value":"data"},{"location":{"file":"","row":14,"col":9},"type":"string","value":"test"}]}`, + `{"annotations":{"authors":[{"name":"doc"}],"custom":{"doc":"doc"},"description":"doc","location":{"file":"","row":16,"col":1},"organizations":["doc"],"related_resources":[{"ref":"https://doc"}],"schemas":[{"path":[{"type":"var","value":"input"},{"type":"string","value":"bar"}],"definition":{"type":"integer"}}],"scope":"document","title":"doc"},"location":{"file":"","row":44,"col":1},"path":[{"location":{"file":"","row":14,"col":9},"type":"var","value":"data"},{"location":{"file":"","row":14,"col":9},"type":"string","value":"test"},{"type":"string","value":"p"}]}`, + `{"annotations":{"authors":[{"name":"rule"}],"custom":{"rule":"rule"},"description":"rule","location":{"file":"","row":31,"col":1},"organizations":["rule"],"related_resources":[{"ref":"https://rule"}],"schemas":[{"path":[{"type":"var","value":"input"},{"type":"string","value":"baz"}],"definition":{"type":"string"}}],"scope":"rule","title":"rule"},"location":{"file":"","row":44,"col":1},"path":[{"location":{"file":"","row":14,"col":9},"type":"var","value":"data"},{"location":{"file":"","row":14,"col":9},"type":"string","value":"test"},{"type":"string","value":"p"}]}`, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + module := MustParseModuleWithOpts(tc.module, tc.options) + + if len(tc.expected) != len(module.Annotations) { + t.Fatalf("expected %d annotations got %d", len(tc.expected), len(module.Annotations)) + } + + for i, a := range module.Annotations { + ref := NewAnnotationsRef(a) + + bytes, err := json.Marshal(ref) + if err != nil { + t.Fatal(err) + } + + got := string(bytes) + expected := tc.expected[i] + + if got != expected { + 
t.Fatalf("expected:\n%s got\n%s", expected, got) + } + } + + }) + } +} diff --git a/ast/parser.go b/ast/parser.go index 0990a6a539..58e9e73c8a 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -10,11 +10,13 @@ import ( "fmt" "io" "math/big" + "net/url" "regexp" + "sort" "strconv" "strings" + "unicode/utf8" - "github.com/pkg/errors" "gopkg.in/yaml.v2" "github.com/open-policy-agent/opa/ast/internal/scanner" @@ -94,10 +96,42 @@ func (e *parsedTermCacheItem) String() string { // ParserOptions defines the options for parsing Rego statements. type ParserOptions struct { - Capabilities *Capabilities - ProcessAnnotation bool - AllFutureKeywords bool - FutureKeywords []string + Capabilities *Capabilities + ProcessAnnotation bool + AllFutureKeywords bool + FutureKeywords []string + SkipRules bool + JSONOptions *JSONOptions + unreleasedKeywords bool // TODO(sr): cleanup +} + +// JSONOptions defines the options for JSON operations, +// currently only marshaling can be configured +type JSONOptions struct { + MarshalOptions JSONMarshalOptions +} + +// JSONMarshalOptions defines the options for JSON marshaling, +// currently only toggling the marshaling of location information is supported +type JSONMarshalOptions struct { + IncludeLocation NodeToggle +} + +// NodeToggle is a generic struct to allow the toggling of +// settings for different ast node types +type NodeToggle struct { + Term bool + Package bool + Comment bool + Import bool + Rule bool + Head bool + Expr bool + SomeDecl bool + Every bool + With bool + Annotations bool + AnnotationsRef bool } // NewParser creates and initializes a Parser. @@ -133,12 +167,12 @@ func (p *Parser) WithProcessAnnotation(processAnnotation bool) *Parser { // WithFutureKeywords enables "future" keywords, i.e., keywords that can // be imported via // -// import future.keywords.kw -// import future.keywords.other +// import future.keywords.kw +// import future.keywords.other // // but in a more direct way. The equivalent of this import would be // -// WithFutureKeywords("kw", "other") +// WithFutureKeywords("kw", "other") func (p *Parser) WithFutureKeywords(kws ...string) *Parser { p.po.FutureKeywords = kws return p @@ -147,25 +181,38 @@ func (p *Parser) WithFutureKeywords(kws ...string) *Parser { // WithAllFutureKeywords enables all "future" keywords, i.e., the // ParserOption equivalent of // -// import future.keywords +// import future.keywords func (p *Parser) WithAllFutureKeywords(yes bool) *Parser { p.po.AllFutureKeywords = yes return p } +// withUnreleasedKeywords allows using keywords that haven't surfaced +// as future keywords (see above) yet, but have tests that require +// them to be parsed +func (p *Parser) withUnreleasedKeywords(yes bool) *Parser { + p.po.unreleasedKeywords = yes + return p +} + // WithCapabilities sets the capabilities structure on the parser. func (p *Parser) WithCapabilities(c *Capabilities) *Parser { p.po.Capabilities = c return p } -const ( - annotationScopePackage = "package" - annotationScopeImport = "import" - annotationScopeRule = "rule" - annotationScopeDocument = "document" - annotationScopeSubpackages = "subpackages" -) +// WithSkipRules instructs the parser not to attempt to parse Rule statements. +func (p *Parser) WithSkipRules(skip bool) *Parser { + p.po.SkipRules = skip + return p +} + +// WithJSONOptions sets the JSONOptions which will be set on nodes to configure +// their JSON marshaling behavior. 
+func (p *Parser) WithJSONOptions(jsonOptions *JSONOptions) *Parser { + p.po.JSONOptions = jsonOptions + return p +} func (p *Parser) parsedTermCacheLookup() (*Term, *state) { l := p.s.loc.Offset @@ -206,6 +253,24 @@ func (p *Parser) futureParser() *Parser { return &q } +// presentParser returns a shallow copy of `p` with an empty +// cache, and a scanner that knows none of the future keywords. +// It is used to successfully parse keyword imports, like +// +// import future.keywords.in +// +// even when the parser has already been informed about the +// future keyword "in". This parser won't error out because +// "in" is an identifier. +func (p *Parser) presentParser() (*Parser, map[string]tokens.Token) { + var cpy map[string]tokens.Token + q := *p + q.s = p.save() + q.s.s, cpy = p.s.s.WithoutKeywords(futureKeywords) + q.cache = parsedTermCache{} + return &q, cpy +} + // Parse will read the Rego source and parse statements and // comments as they are found. Any errors encountered while // parsing will be accumulated and returned as a list of Errors. @@ -301,18 +366,21 @@ func (p *Parser) Parse() ([]Statement, []*Comment, Errors) { } p.restore(s) - s = p.save() - if rules := p.parseRules(); rules != nil { - for i := range rules { - stmts = append(stmts, rules[i]) + if !p.po.SkipRules { + s = p.save() + + if rules := p.parseRules(); rules != nil { + for i := range rules { + stmts = append(stmts, rules[i]) + } + continue + } else if len(p.s.errors) > 0 { + break } - continue - } else if len(p.s.errors) > 0 { - break - } - p.restore(s) + p.restore(s) + } if body := p.parseQuery(true, tokens.EOF); body != nil { stmts = append(stmts, body) @@ -326,39 +394,72 @@ func (p *Parser) Parse() ([]Statement, []*Comment, Errors) { stmts = p.parseAnnotations(stmts) } + if p.po.JSONOptions != nil { + for i := range stmts { + vis := NewGenericVisitor(func(x interface{}) bool { + if x, ok := x.(customJSON); ok { + x.setJSONOptions(*p.po.JSONOptions) + } + return false + }) + + vis.Walk(stmts[i]) + } + } + return stmts, p.s.comments, p.s.errors } func (p *Parser) parseAnnotations(stmts []Statement) []Statement { + annotStmts, errs := parseAnnotations(p.s.comments) + for _, err := range errs { + p.error(err.Location, err.Message) + } + + for _, annotStmt := range annotStmts { + stmts = append(stmts, annotStmt) + } + + return stmts +} + +func parseAnnotations(comments []*Comment) ([]*Annotations, Errors) { + var hint = []byte("METADATA") var curr *metadataParser var blocks []*metadataParser - for i := 0; i < len(p.s.comments); i++ { + for i := 0; i < len(comments); i++ { if curr != nil { - if p.s.comments[i].Location.Row == p.s.comments[i-1].Location.Row+1 && p.s.comments[i].Location.Col == 1 { - curr.Append(p.s.comments[i]) + if comments[i].Location.Row == comments[i-1].Location.Row+1 && comments[i].Location.Col == 1 { + curr.Append(comments[i]) continue } curr = nil } - if bytes.HasPrefix(bytes.TrimSpace(p.s.comments[i].Text), hint) { - curr = newMetadataParser(p.s.comments[i].Location) + if bytes.HasPrefix(bytes.TrimSpace(comments[i].Text), hint) { + curr = newMetadataParser(comments[i].Location) blocks = append(blocks, curr) } } + var stmts []*Annotations + var errs Errors for _, b := range blocks { a, err := b.Parse() if err != nil { - p.error(b.loc, err.Error()) + errs = append(errs, &Error{ + Code: ParseErr, + Message: err.Error(), + Location: b.loc, + }) } else { stmts = append(stmts, a) } } - return stmts + return stmts, errs } func (p *Parser) parsePackage() *Package { @@ -433,8 +534,8 @@ func (p 
*Parser) parseImport() *Import {
 		p.error(p.s.Loc(), "expected ident")
 		return nil
 	}
-
-	term := p.parseTerm()
+	q, prev := p.presentParser()
+	term := q.parseTerm()
 	if term != nil {
 		switch v := term.Value.(type) {
 		case Var:
@@ -449,6 +550,9 @@ func (p *Parser) parseImport() *Import {
 			imp.Path = term
 		}
 	}
+	// keep advanced parser state, reset known keywords
+	p.s = q.s
+	p.s.s = q.s.s.WithKeywords(prev)
 
 	if imp.Path == nil {
 		p.error(p.s.Loc(), "expected path")
@@ -500,7 +604,8 @@ func (p *Parser) parseRules() []*Rule {
 		return nil
 	}
 
-	if rule.Head = p.parseHead(rule.Default); rule.Head == nil {
+	usesContains := false
+	if rule.Head, usesContains = p.parseHead(rule.Default); rule.Head == nil {
 		return nil
 	}
@@ -513,25 +618,89 @@
 		return []*Rule{&rule}
 	}
 
-	if p.s.tok == tokens.LBrace {
+	if usesContains && !rule.Head.Reference.IsGround() {
+		p.error(p.s.Loc(), "multi-value rules need ground refs")
+		return nil
+	}
+
+	// back-compat with `p[x] { ... }`
+	hasIf := p.s.tok == tokens.If
+
+	// p[x] if ... becomes a single-value rule p[x]
+	if hasIf && !usesContains && len(rule.Head.Ref()) == 2 {
+		if rule.Head.Value == nil {
+			rule.Head.Value = BooleanTerm(true).SetLocation(rule.Head.Location)
+		} else {
+			// p[x] = y if becomes a single-value rule p[x] with value y, but needs name for compat
+			v, ok := rule.Head.Ref()[0].Value.(Var)
+			if !ok {
+				return nil
+			}
+			rule.Head.Name = v
+		}
+	}
+
+	// p[x] becomes a multi-value rule p
+	if !hasIf && !usesContains &&
+		len(rule.Head.Args) == 0 && // not a function
+		len(rule.Head.Ref()) == 2 { // ref like 'p[x]'
+		v, ok := rule.Head.Ref()[0].Value.(Var)
+		if !ok {
+			return nil
+		}
+		rule.Head.Name = v
+		rule.Head.Key = rule.Head.Ref()[1]
+		if rule.Head.Value == nil {
+			rule.Head.SetRef(rule.Head.Ref()[:len(rule.Head.Ref())-1])
+		}
+	}
+
+	switch {
+	case hasIf:
+		p.scan()
+		s := p.save()
+		if expr := p.parseLiteral(); expr != nil {
+			// NOTE(sr): set literals are never false or undefined, so parsing this as
+			//  p if { true }
+			//       ^^^^^^^^ set of one element, `true`
+			// isn't valid.
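The net effect of the back-compat handling above, sketched with the ParseRuleWithOpts helper that this same diff adds to parser_ext.go (illustrative):

	// Without `if` or `contains`, `p[x]` keeps its old partial-set reading:
	r1, _ := ParseRuleWithOpts(`p[x] { x := 1 }`, ParserOptions{})
	// r1.Head.Name == "p", r1.Head.Key == x (a multi-value rule)

	// A ground ref head parses as a single-value rule:
	r2, _ := ParseRuleWithOpts(`p.q.r = 1 { true }`, ParserOptions{})
	// r2.Head.Ref() is p.q.r, r2.Head.Value is 1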
+ isSetLiteral := false + if t, ok := expr.Terms.(*Term); ok { + _, isSetLiteral = t.Value.(Set) + } + // expr.Term is []*Term or Every + if !isSetLiteral { + rule.Body.Append(expr) + break + } + } + + // parsing as literal didn't work out, expect '{ BODY }' + p.restore(s) + fallthrough + + case p.s.tok == tokens.LBrace: p.scan() if rule.Body = p.parseBody(tokens.RBrace); rule.Body == nil { return nil } p.scan() - } else { + + case usesContains: + rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location)) + return []*Rule{&rule} + + default: return nil } if p.s.tok == tokens.Else { - - if rule.Head.Assign { - p.error(p.s.Loc(), "else keyword cannot be used on rule declared with := operator") + if r := rule.Head.Ref(); len(r) > 1 && !r[len(r)-1].Value.IsGround() { + p.error(p.s.Loc(), "else keyword cannot be used on rules with variables in head") return nil } - if rule.Head.Key != nil { - p.error(p.s.Loc(), "else keyword cannot be used on partial rules") + p.error(p.s.Loc(), "else keyword cannot be used on multi-value rules") return nil } @@ -542,9 +711,7 @@ func (p *Parser) parseRules() []*Rule { rule.Location.Text = p.s.Text(rule.Location.Offset, p.s.lastEnd) - var rules []*Rule - - rules = append(rules, &rule) + rules := []*Rule{&rule} for p.s.tok == tokens.LBrace { @@ -570,6 +737,11 @@ func (p *Parser) parseRules() []*Rule { // rule's head AST but have their location // set to the rule body. next.Head = rule.Head.Copy() + for i := range next.Head.Args { + if v, ok := next.Head.Args[i].Value.(Var); ok && v.IsWildcard() { + next.Head.Args[i].Value = Var(p.genwildcard()) + } + } setLocRecursive(next.Head, loc) rules = append(rules, &next) @@ -584,6 +756,11 @@ func (p *Parser) parseElse(head *Head) *Rule { rule.SetLoc(p.s.Loc()) rule.Head = head.Copy() + for i := range rule.Head.Args { + if v, ok := rule.Head.Args[i].Value.(Var); ok && v.IsWildcard() { + rule.Head.Args[i].Value = Var(p.genwildcard()) + } + } rule.Head.SetLoc(p.s.Loc()) defer func() { @@ -593,9 +770,10 @@ func (p *Parser) parseElse(head *Head) *Rule { p.scan() switch p.s.tok { - case tokens.LBrace: + case tokens.LBrace, tokens.If: // no value, but a body follows directly rule.Head.Value = BooleanTerm(true) - case tokens.Unify: + case tokens.Assign, tokens.Unify: + rule.Head.Assign = tokens.Assign == p.s.tok p.scan() rule.Head.Value = p.parseTermInfixCall() if rule.Head.Value == nil { @@ -607,6 +785,30 @@ func (p *Parser) parseElse(head *Head) *Rule { return nil } + hasIf := p.s.tok == tokens.If + + if hasIf { + p.scan() + s := p.save() + if expr := p.parseLiteral(); expr != nil { + // NOTE(sr): set literals are never false or undefined, so parsing this as + // p if false else if { true } + // ^^^^^^^^ set of one element, `true` + // isn't valid. 
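One consequence of the relaxed `else` handling above: assignment in else branches now parses, where it was previously rejected outright (a sketch, mirrored by the "else assignment" test case below):

	rule, err := ParseRule(`p := 1 { false } else := 2`)
	// err == nil; rule.Else.Head.Assign == true and rule.Else.Head.Value is 2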
+			isSetLiteral := false
+			if t, ok := expr.Terms.(*Term); ok {
+				_, isSetLiteral = t.Value.(Set)
+			}
+			// expr.Term is []*Term or Every
+			if !isSetLiteral {
+				rule.Body.Append(expr)
+				setLocRecursive(rule.Body, rule.Location)
+				return &rule
+			}
+		}
+		p.restore(s)
+	}
+
 	if p.s.tok != tokens.LBrace {
 		rule.Body = NewBody(NewExpr(BooleanTerm(true)))
 		setLocRecursive(rule.Body, rule.Location)
@@ -629,85 +831,101 @@
 	return &rule
 }
 
-func (p *Parser) parseHead(defaultRule bool) *Head {
-
-	var head Head
-	head.SetLoc(p.s.Loc())
+func (p *Parser) parseHead(defaultRule bool) (*Head, bool) {
+	head := &Head{}
+	loc := p.s.Loc()
 
 	defer func() {
-		head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd)
+		if head != nil {
+			head.SetLoc(loc)
+			head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd)
+		}
 	}()
 
-	if term := p.parseVar(); term != nil {
-		head.Name = term.Value.(Var)
-	} else {
-		p.illegal("expected rule head name")
+	term := p.parseVar()
+	if term == nil {
+		return nil, false
 	}
 
-	p.scan()
-
-	if p.s.tok == tokens.LParen {
-		p.scan()
-		if p.s.tok != tokens.RParen {
-			head.Args = p.parseTermList(tokens.RParen, nil)
-			if head.Args == nil {
-				return nil
+	ref := p.parseTermFinish(term, true)
+	if ref == nil {
+		p.illegal("expected rule head name")
+		return nil, false
+	}
+
+	switch x := ref.Value.(type) {
+	case Var:
+		head = NewHead(x)
+	case Ref:
+		head = RefHead(x)
+	case Call:
+		op, args := x[0], x[1:]
+		var ref Ref
+		switch y := op.Value.(type) {
+		case Var:
+			ref = Ref{op}
+		case Ref:
+			if _, ok := y[0].Value.(Var); !ok {
+				p.illegal("rule head ref %v invalid", y)
+				return nil, false
 			}
+			ref = y
 		}
-		p.scan()
+		head = RefHead(ref)
+		head.Args = append([]*Term{}, args...)
 
-		if p.s.tok == tokens.LBrack {
-			return nil
-		}
+	default:
+		return nil, false
 	}
 
-	if p.s.tok == tokens.LBrack {
+	name := head.Ref().String()
+
+	switch p.s.tok {
+	case tokens.Contains: // NOTE: no Value for `contains` heads, we return here
+		// Catch error case of using 'contains' with a function definition rule head.
+		if head.Args != nil {
+			p.illegal("the contains keyword can only be used with multi-value rule definitions (e.g., %s contains <VALUE> { ... })", name)
+		}
 		p.scan()
 		head.Key = p.parseTermInfixCall()
 		if head.Key == nil {
-			p.illegal("expected rule key term (e.g., %s[<VALUE>] { ... })", head.Name)
-		}
-		if p.s.tok != tokens.RBrack {
-			if _, ok := futureKeywords[head.Name.String()]; ok {
-				p.hint("`import future.keywords.%[1]s` for '%[1]s' keyword", head.Name.String())
-			}
-			p.illegal("non-terminated rule key")
+			p.illegal("expected rule key term (e.g., %s contains <VALUE> { ... })", name)
 		}
-		p.scan()
-	}
+		return head, true
 
-	if p.s.tok == tokens.Unify {
+	case tokens.Unify:
 		p.scan()
 		head.Value = p.parseTermInfixCall()
 		if head.Value == nil {
-			p.illegal("expected rule value term (e.g., %s[%s] = <VALUE> { ... })", head.Name, head.Key)
-		}
-	} else if p.s.tok == tokens.Assign {
-
-		if defaultRule {
-			p.error(p.s.Loc(), "default rules must use = operator (not := operator)")
-			return nil
-		} else if head.Key != nil {
-			p.error(p.s.Loc(), "partial rules must use = operator (not := operator)")
-			return nil
-		} else if len(head.Args) > 0 {
-			p.error(p.s.Loc(), "functions must use = operator (not := operator)")
-			return nil
+			// FIX HEAD.String()
+			p.illegal("expected rule value term (e.g., %s[%s] = <VALUE> { ... })", name, head.Key)
 		}
-
+	case tokens.Assign:
+		s := p.save()
 		p.scan()
 		head.Assign = true
 		head.Value = p.parseTermInfixCall()
 		if head.Value == nil {
-			p.illegal("expected rule value term (e.g., %s := <VALUE> { ... })", head.Name)
+			p.restore(s)
+			switch {
+			case len(head.Args) > 0:
+				p.illegal("expected function value term (e.g., %s(...) := <VALUE> { ... })", name)
+			case head.Key != nil:
+				p.illegal("expected partial rule value term (e.g., %s[...] := <VALUE> { ... })", name)
+			case defaultRule:
+				p.illegal("expected default rule value term (e.g., default %s := <VALUE>)", name)
+			default:
+				p.illegal("expected rule value term (e.g., %s := <VALUE> { ... })", name)
			}
 		}
 	}
 
 	if head.Value == nil && head.Key == nil {
-		head.Value = BooleanTerm(true).SetLocation(head.Location)
+		if len(head.Ref()) != 2 || len(head.Args) > 0 {
+			head.Value = BooleanTerm(true).SetLocation(head.Location)
+		}
 	}
-
-	return &head
+	return head, false
 }
 
 func (p *Parser) parseBody(end tokens.Token) Body {
@@ -723,7 +941,6 @@ func (p *Parser) parseQuery(requireSemi bool, end tokens.Token) Body {
 	}
 
 	for {
-
 		expr := p.parseLiteral()
 		if expr == nil {
 			return nil
@@ -763,14 +980,26 @@ func (p *Parser) parseLiteral() (expr *Expr) {
 	}()
 
 	var negated bool
+	if p.s.tok == tokens.Not {
+		p.scan()
+		negated = true
+	}
+
 	switch p.s.tok {
 	case tokens.Some:
+		if negated {
+			p.illegal("illegal negation of 'some'")
+			return nil
+		}
 		return p.parseSome()
-	case tokens.Not:
-		p.scan()
-		negated = true
-		fallthrough
+	case tokens.Every:
+		if negated {
+			p.illegal("illegal negation of 'every'")
+			return nil
+		}
+		return p.parseEvery()
 	default:
+		s := p.save()
 		expr := p.parseExpr()
 		if expr != nil {
 			expr.Negated = negated
@@ -779,6 +1008,20 @@
 			return nil
 		}
 	}
+		// If we find a plain `every` identifier, attempt to parse an every expression,
+		// add hint if it succeeds.
+		if term, ok := expr.Terms.(*Term); ok && Var("every").Equal(term.Value) {
+			var hint bool
+			t := p.save()
+			p.restore(s)
+			if expr := p.futureParser().parseEvery(); expr != nil {
+				_, hint = expr.Terms.(*Every)
+			}
+			p.restore(t)
+			if hint {
+				p.hint("`import future.keywords.every` for `every x in xs { ... }` expressions")
+			}
+		}
 		return expr
 	}
 	return nil
@@ -801,7 +1044,8 @@ func (p *Parser) parseWith() []*With {
 		return nil
 	}
 
-	if with.Target = p.parseTerm(); with.Target == nil {
+	with.Target = p.parseTerm()
+	if with.Target == nil {
 		return nil
 	}
@@ -847,14 +1091,29 @@ func (p *Parser) parseSome() *Expr {
 	if term := p.parseTermInfixCall(); term != nil {
 		if call, ok := term.Value.(Call); ok {
 			switch call[0].String() {
-			case Member.Name, MemberWithKey.Name: // OK
+			case Member.Name:
+				if len(call) != 3 {
+					p.illegal("illegal domain")
+					return nil
+				}
+			case MemberWithKey.Name:
+				if len(call) != 4 {
+					p.illegal("illegal domain")
+					return nil
+				}
 			default:
 				p.illegal("expected `x in xs` or `x, y in xs` expression")
 				return nil
 			}
 
 			decl.Symbols = []*Term{term}
-			return NewExpr(decl).SetLocation(decl.Location)
+			expr := NewExpr(decl).SetLocation(decl.Location)
+			if p.s.tok == tokens.With {
+				if expr.With = p.parseWith(); expr.With == nil {
+					return nil
+				}
+			}
+			return expr
 		}
 	}
@@ -898,6 +1157,72 @@ func (p *Parser) parseSome() *Expr {
 	return NewExpr(decl).SetLocation(decl.Location)
 }
 
+func (p *Parser) parseEvery() *Expr {
+	qb := &Every{}
+	qb.SetLoc(p.s.Loc())
+
+	// TODO(sr): We'd get more accurate error messages if we didn't rely on
+	// parseTermInfixCall here, but parsed "var [, var] in term" manually.
+ p.scan() + term := p.parseTermInfixCall() + if term == nil { + return nil + } + call, ok := term.Value.(Call) + if !ok { + p.illegal("expected `x[, y] in xs { ... }` expression") + return nil + } + switch call[0].String() { + case Member.Name: // x in xs + if len(call) != 3 { + p.illegal("illegal domain") + return nil + } + qb.Value = call[1] + qb.Domain = call[2] + case MemberWithKey.Name: // k, v in xs + if len(call) != 4 { + p.illegal("illegal domain") + return nil + } + qb.Key = call[1] + qb.Value = call[2] + qb.Domain = call[3] + if _, ok := qb.Key.Value.(Var); !ok { + p.illegal("expected key to be a variable") + return nil + } + default: + p.illegal("expected `x[, y] in xs { ... }` expression") + return nil + } + if _, ok := qb.Value.Value.(Var); !ok { + p.illegal("expected value to be a variable") + return nil + } + if p.s.tok == tokens.LBrace { // every x in xs { ... } + p.scan() + body := p.parseBody(tokens.RBrace) + if body == nil { + return nil + } + p.scan() + qb.Body = body + expr := NewExpr(qb).SetLocation(qb.Location) + + if p.s.tok == tokens.With { + if expr.With = p.parseWith(); expr.With == nil { + return nil + } + } + return expr + } + + p.illegal("missing body") + return nil +} + func (p *Parser) parseExpr() *Expr { lhs := p.parseTermInfixCall() @@ -961,7 +1286,6 @@ func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term { } } p.restore(s) - return nil } if op := p.parseTermOpName(Member.Ref(), tokens.In); op != nil { if rhs := p.parseTermRelation(nil, p.s.loc.Offset); rhs != nil { @@ -1099,7 +1423,7 @@ func (p *Parser) parseTerm() *Term { term = p.parseNumber() case tokens.String: term = p.parseString() - case tokens.Ident: + case tokens.Ident, tokens.Contains: // NOTE(sr): contains anywhere BUT in rule heads gets no special treatment term = p.parseVar() case tokens.LBrack: term = p.parseArray() @@ -1120,17 +1444,18 @@ func (p *Parser) parseTerm() *Term { p.illegalToken() } - term = p.parseTermFinish(term) + term = p.parseTermFinish(term, false) p.parsedTermCachePush(term, s0) return term } -func (p *Parser) parseTermFinish(head *Term) *Term { +func (p *Parser) parseTermFinish(head *Term, skipws bool) *Term { if head == nil { return nil } offset := p.s.loc.Offset - p.scanWS() + p.doScan(skipws) + switch p.s.tok { case tokens.LParen, tokens.Dot, tokens.LBrack: return p.parseRef(head, offset) @@ -1717,7 +2042,10 @@ func (p *Parser) illegal(note string, a ...interface{}) { } tokType := "token" - if p.s.tok >= tokens.Package && p.s.tok <= tokens.False { + if tokens.IsKeyword(p.s.tok) { + tokType = "keyword" + } + if _, ok := futureKeywords[p.s.tok.String()]; ok { tokType = "keyword" } @@ -1848,9 +2176,18 @@ func (p *Parser) validateDefaultRuleValue(rule *Rule) bool { return valid } +// We explicitly use yaml unmarshalling, to accommodate for the '_' in 'related_resources', +// which isn't handled properly by json for some reason. 
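The struct that follows captures the accepted metadata shape; to make it concrete, a block like the following (modeled on the fixtures in parser_test.go above) exercises most of the fields:

	module := MustParseModuleWithOpts(`# METADATA
# title: example
# description: example annotation block
# organizations:
# - ACME Corp.
# related_resources:
# - https://example.com/docs
# authors:
# - Jane Doe <jane@example.com>
# custom:
#   severity: HIGH
package example`, ParserOptions{ProcessAnnotation: true})
	// module.Annotations[0].Title == "example", etc.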
type rawAnnotation struct { - Scope string `json:"scope"` - Schemas []rawSchemaAnnotation `json:"schemas"` + Scope string `yaml:"scope"` + Title string `yaml:"title"` + Entrypoint bool `yaml:"entrypoint"` + Description string `yaml:"description"` + Organizations []string `yaml:"organizations"` + RelatedResources []interface{} `yaml:"related_resources"` + Authors []interface{} `yaml:"authors"` + Schemas []rawSchemaAnnotation `yaml:"schemas"` + Custom map[string]interface{} `yaml:"custom"` } type rawSchemaAnnotation map[string]interface{} @@ -1871,7 +2208,7 @@ func (b *metadataParser) Append(c *Comment) { b.comments = append(b.comments, c) } -var yamlLineErrRegex = regexp.MustCompile(`^yaml: line ([[:digit:]]+):`) +var yamlLineErrRegex = regexp.MustCompile(`^yaml:(?: unmarshal errors:[\n\s]*)? line ([[:digit:]]+):`) func (b *metadataParser) Parse() (*Annotations, error) { @@ -1882,29 +2219,41 @@ func (b *metadataParser) Parse() (*Annotations, error) { } if err := yaml.Unmarshal(b.buf.Bytes(), &raw); err != nil { + var comment *Comment match := yamlLineErrRegex.FindStringSubmatch(err.Error()) if len(match) == 2 { n, err2 := strconv.Atoi(match[1]) if err2 == nil { index := n - 1 // line numbering is 1-based so subtract one from row if index >= len(b.comments) { - b.loc = b.comments[len(b.comments)-1].Location + comment = b.comments[len(b.comments)-1] } else { - b.loc = b.comments[index].Location + comment = b.comments[index] } + b.loc = comment.Location } } - return nil, err + return nil, augmentYamlError(err, b.comments) } var result Annotations + result.comments = b.comments result.Scope = raw.Scope + result.Entrypoint = raw.Entrypoint + result.Title = raw.Title + result.Description = raw.Description + result.Organizations = raw.Organizations - for _, pair := range raw.Schemas { - var k string - var v interface{} - for k, v = range pair { + for _, v := range raw.RelatedResources { + rr, err := parseRelatedResource(v) + if err != nil { + return nil, fmt.Errorf("invalid related-resource definition %s: %w", v, err) } + result.RelatedResources = append(result.RelatedResources, rr) + } + + for _, pair := range raw.Schemas { + k, v := unwrapPair(pair) var a SchemaAnnotation var err error @@ -1923,7 +2272,7 @@ func (b *metadataParser) Parse() (*Annotations, error) { case map[interface{}]interface{}: w, err := convertYAMLMapKeyTypes(v, nil) if err != nil { - return nil, errors.Wrap(err, "invalid schema definition") + return nil, fmt.Errorf("invalid schema definition: %w", err) } a.Definition = &w default: @@ -1933,10 +2282,67 @@ func (b *metadataParser) Parse() (*Annotations, error) { result.Schemas = append(result.Schemas, &a) } + for _, v := range raw.Authors { + author, err := parseAuthor(v) + if err != nil { + return nil, fmt.Errorf("invalid author definition %s: %w", v, err) + } + result.Authors = append(result.Authors, author) + } + + result.Custom = make(map[string]interface{}) + for k, v := range raw.Custom { + val, err := convertYAMLMapKeyTypes(v, nil) + if err != nil { + return nil, err + } + result.Custom[k] = val + } + result.Location = b.loc return &result, nil } +// augmentYamlError augments a YAML error with hints intended to help the user figure out the cause of an otherwise cryptic error. +// These are hints, instead of proper errors, because they are educated guesses, and aren't guaranteed to be correct. 
+func augmentYamlError(err error, comments []*Comment) error {
+	// Add hints for cases where the key/value ':' separator isn't followed by a legal YAML space symbol
+	for _, comment := range comments {
+		txt := string(comment.Text)
+		parts := strings.Split(txt, ":")
+		if len(parts) > 1 {
+			parts = parts[1:]
+			var invalidSpaces []string
+			for partIndex, part := range parts {
+				if len(part) == 0 && partIndex == len(parts)-1 {
+					invalidSpaces = []string{}
+					break
+				}
+
+				r, _ := utf8.DecodeRuneInString(part)
+				if r == ' ' || r == '\t' {
+					invalidSpaces = []string{}
+					break
+				}
+
+				invalidSpaces = append(invalidSpaces, fmt.Sprintf("%+q", r))
+			}
+			if len(invalidSpaces) > 0 {
+				err = fmt.Errorf(
+					"%s\n  Hint: on line %d, symbol(s) %v immediately following a key/value separator ':' is not a legal yaml space character",
+					err.Error(), comment.Location.Row, invalidSpaces)
+			}
+		}
+	}
+	return err
+}
+
+func unwrapPair(pair map[string]interface{}) (k string, v interface{}) {
+	for k, v = range pair {
+	}
+	return
+}
+
 var errInvalidSchemaRef = fmt.Errorf("invalid schema reference")
 
 // NOTE(tsandall): 'schema' is not registered as a root because it's not
@@ -1961,6 +2367,96 @@ func parseSchemaRef(s string) (Ref, error) {
 	return nil, errInvalidSchemaRef
 }
 
+func parseRelatedResource(rr interface{}) (*RelatedResourceAnnotation, error) {
+	rr, err := convertYAMLMapKeyTypes(rr, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	switch rr := rr.(type) {
+	case string:
+		if len(rr) > 0 {
+			u, err := url.Parse(rr)
+			if err != nil {
+				return nil, err
+			}
+			return &RelatedResourceAnnotation{Ref: *u}, nil
+		}
+		return nil, fmt.Errorf("ref URL may not be empty string")
+	case map[string]interface{}:
+		description := strings.TrimSpace(getSafeString(rr, "description"))
+		ref := strings.TrimSpace(getSafeString(rr, "ref"))
+		if len(ref) > 0 {
+			u, err := url.Parse(ref)
+			if err != nil {
+				return nil, err
+			}
+			return &RelatedResourceAnnotation{Description: description, Ref: *u}, nil
+		}
+		return nil, fmt.Errorf("'ref' value required in object")
+	}
+
+	return nil, fmt.Errorf("invalid value type, must be string or map")
+}
+
+func parseAuthor(a interface{}) (*AuthorAnnotation, error) {
+	a, err := convertYAMLMapKeyTypes(a, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	switch a := a.(type) {
+	case string:
+		return parseAuthorString(a)
+	case map[string]interface{}:
+		name := strings.TrimSpace(getSafeString(a, "name"))
+		email := strings.TrimSpace(getSafeString(a, "email"))
+		if len(name) > 0 || len(email) > 0 {
+			return &AuthorAnnotation{name, email}, nil
+		}
+		return nil, fmt.Errorf("'name' and/or 'email' values required in object")
+	}
+
+	return nil, fmt.Errorf("invalid value type, must be string or map")
+}
+
+func getSafeString(m map[string]interface{}, k string) string {
+	if v, found := m[k]; found {
+		if s, ok := v.(string); ok {
+			return s
+		}
+	}
+	return ""
+}
+
+const emailPrefix = "<"
+const emailSuffix = ">"
+
+// parseAuthorString parses a string into an AuthorAnnotation. If the last word of the input string is enclosed within <>,
+// it is extracted as the author's email. The email may not contain whitespace, as it would then be interpreted as
+// multiple words.
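For example, the two accepted author spellings (a sketch; parseAuthorString is defined next):

	a1, _ := parseAuthorString("Jane Doe <jane@example.com>")
	// a1.Name == "Jane Doe", a1.Email == "jane@example.com"

	a2, _ := parseAuthor(map[string]interface{}{"name": "Jane Doe"})
	// a2.Name == "Jane Doe", a2.Email == ""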
+func parseAuthorString(s string) (*AuthorAnnotation, error) { + parts := strings.Fields(s) + + if len(parts) == 0 { + return nil, fmt.Errorf("author is an empty string") + } + + namePartCount := len(parts) + trailing := parts[namePartCount-1] + var email string + if len(trailing) >= len(emailPrefix)+len(emailSuffix) && strings.HasPrefix(trailing, emailPrefix) && + strings.HasSuffix(trailing, emailSuffix) { + email = trailing[len(emailPrefix):] + email = email[0 : len(email)-len(emailSuffix)] + namePartCount = namePartCount - 1 + } + + name := strings.Join(parts[0:namePartCount], " ") + + return &AuthorAnnotation{Name: name, Email: email}, nil +} + func convertYAMLMapKeyTypes(x interface{}, path []string) (interface{}, error) { var err error switch x := x.(type) { @@ -1993,23 +2489,22 @@ func convertYAMLMapKeyTypes(x interface{}, path []string) (interface{}, error) { // futureKeywords is the source of truth for future keywords that will // eventually become standard keywords inside of Rego. var futureKeywords = map[string]tokens.Token{ - "in": tokens.In, + "in": tokens.In, + "every": tokens.Every, + "contains": tokens.Contains, + "if": tokens.If, } func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]tokens.Token) { path := imp.Path.Value.(Ref) - if len(path) == 1 { - p.errorf(imp.Path.Location, "invalid import, use `import future.keywords` or `import.future.keywords.in`") - return - } - if !path[1].Equal(StringTerm("keywords")) { + if len(path) == 1 || !path[1].Equal(StringTerm("keywords")) { p.errorf(imp.Path.Location, "invalid import, must be `future.keywords`") return } if imp.Alias != "" { - p.errorf(imp.Path.Location, "future keyword imports cannot be aliased") + p.errorf(imp.Path.Location, "`future` imports cannot be aliased") return } @@ -2017,6 +2512,7 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke for k := range allowedFutureKeywords { kwds = append(kwds, k) } + switch len(path) { case 2: // all keywords imported, nothing to do case 3: // one keyword imported @@ -2028,6 +2524,7 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke keyword := string(kw) _, ok = allowedFutureKeywords[keyword] if !ok { + sort.Strings(kwds) // so the error message is stable p.errorf(imp.Path.Location, "unexpected keyword, must be one of %v", kwds) return } diff --git a/ast/parser_ext.go b/ast/parser_ext.go index 749dbd62e8..99430a8fc5 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -12,17 +12,22 @@ package ast import ( "bytes" + "errors" "fmt" "strings" "unicode" - - "github.com/pkg/errors" ) // MustParseBody returns a parsed body. // If an error occurs during parsing, panic. func MustParseBody(input string) Body { - parsed, err := ParseBody(input) + return MustParseBodyWithOpts(input, ParserOptions{}) +} + +// MustParseBodyWithOpts returns a parsed body. +// If an error occurs during parsing, panic. +func MustParseBodyWithOpts(input string, opts ParserOptions) Body { + parsed, err := ParseBodyWithOpts(input, opts) if err != nil { panic(err) } @@ -52,7 +57,13 @@ func MustParseImports(input string) []*Import { // MustParseModule returns a parsed module. // If an error occurs during parsing, panic. func MustParseModule(input string) *Module { - parsed, err := ParseModule("", input) + return MustParseModuleWithOpts(input, ParserOptions{}) +} + +// MustParseModuleWithOpts returns a parsed module. +// If an error occurs during parsing, panic. 
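A usage sketch for the new *WithOpts variants, opting into the `in` future keyword explicitly:

	body := MustParseBodyWithOpts(`x in {1, 2, 3}`, ParserOptions{FutureKeywords: []string{"in"}})
	// body[0] is the membership call internal.member_2(x, {1, 2, 3})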
+func MustParseModuleWithOpts(input string, opts ParserOptions) *Module { + parsed, err := ParseModuleWithOpts("", input, opts) if err != nil { panic(err) } @@ -143,12 +154,15 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { } if _, ok := expr.Terms.(*SomeDecl); ok { - return nil, errors.New("some declarations cannot be used for rule head") + return nil, errors.New("'some' declarations cannot be used for rule head") } if term, ok := expr.Terms.(*Term); ok { switch v := term.Value.(type) { case Ref: + if len(v) > 2 { // 2+ dots + return ParseCompleteDocRuleWithDotsFromTerm(module, term) + } return ParsePartialSetDocRuleFromTerm(module, term) default: return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(v)) @@ -161,30 +175,15 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { return nil, errors.New("expression cannot be used for rule head") } - if expr.IsAssignment() { - - lhs, rhs := expr.Operand(0), expr.Operand(1) - if lhs == nil || rhs == nil { - return nil, errors.New("assignment requires two operands") - } - - rule, err := ParseCompleteDocRuleFromAssignmentExpr(module, lhs, rhs) - - if err == nil { - rule.Location = expr.Location - rule.Head.Location = expr.Location - return rule, nil - } else if _, ok := lhs.Value.(Call); ok { - return nil, errFunctionAssignOperator - } else if _, ok := lhs.Value.(Ref); ok { - return nil, errPartialRuleAssignOperator - } - - return nil, errTermAssignOperator(lhs.Value) - } - if expr.IsEquality() { return parseCompleteRuleFromEq(module, expr) + } else if expr.IsAssignment() { + rule, err := parseCompleteRuleFromEq(module, expr) + if err != nil { + return nil, err + } + rule.Head.Assign = true + return rule, nil } if _, ok := BuiltinMap[expr.Operator().String()]; ok { @@ -211,18 +210,17 @@ func parseCompleteRuleFromEq(module *Module, expr *Expr) (rule *Rule, err error) return nil, errors.New("assignment requires two operands") } - rule, err = ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) - + rule, err = ParseRuleFromCallEqExpr(module, lhs, rhs) if err == nil { return rule, nil } - rule, err = ParseRuleFromCallEqExpr(module, lhs, rhs) + rule, err = ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) if err == nil { return rule, nil } - return ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) + return ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) } // ParseCompleteDocRuleFromAssignmentExpr returns a rule if the expression can @@ -243,61 +241,92 @@ func ParseCompleteDocRuleFromAssignmentExpr(module *Module, lhs, rhs *Term) (*Ru // ParseCompleteDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a complete document definition. func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { + var head *Head - var name Var - - if RootDocumentRefs.Contains(lhs) { - name = lhs.Value.(Ref)[0].Value.(Var) - } else if v, ok := lhs.Value.(Var); ok { - name = v + if v, ok := lhs.Value.(Var); ok { + head = NewHead(v) + } else if r, ok := lhs.Value.(Ref); ok { // groundness ? 
+ if _, ok := r[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", r) + } + head = RefHead(r) + if len(r) > 1 && !r[len(r)-1].IsGround() { + return nil, fmt.Errorf("ref not ground") + } } else { return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(lhs.Value)) } + head.Value = rhs + head.Location = lhs.Location + head.setJSONOptions(lhs.jsonOptions) - rule := &Rule{ - Location: lhs.Location, - Head: &Head{ - Location: lhs.Location, - Name: name, - Value: rhs, - }, - Body: NewBody( - NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location), - ), - Module: module, + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) + setJSONOptions(body, &rhs.jsonOptions) + + return &Rule{ + Location: lhs.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: lhs.jsonOptions, + }, nil +} + +func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, error) { + ref, ok := term.Value.(Ref) + if !ok { + return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(term.Value)) } - return rule, nil + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) + } + head := RefHead(ref, BooleanTerm(true).SetLocation(term.Location)) + head.Location = term.Location + head.jsonOptions = term.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location)) + setJSONOptions(body, &term.jsonOptions) + + return &Rule{ + Location: term.Location, + Head: head, + Body: body, + Module: module, + + jsonOptions: term.jsonOptions, + }, nil } // ParsePartialObjectDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a partial object document definition. 
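A sketch of the back-compat branch handled below (illustrative; assumes some enclosing *Module value named module):

	lhs, rhs := MustParseTerm(`p[k]`), MustParseTerm(`v`)
	rule, _ := ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs)
	// for the two-element ref: rule.Head.Name == "p", rule.Head.Key is k, value is v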
func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - ref, ok := lhs.Value.(Ref) - if !ok || len(ref) != 2 { + if !ok { return nil, fmt.Errorf("%v cannot be used as rule name", TypeName(lhs.Value)) } if _, ok := ref[0].Value.(Var); !ok { - return nil, fmt.Errorf("%vs cannot be used as rule name", TypeName(ref[0].Value)) + return nil, fmt.Errorf("invalid rule head: %v", ref) } - name := ref[0].Value.(Var) - key := ref[1] + head := RefHead(ref, rhs) + if len(ref) == 2 { // backcompat for naked `foo.bar = "baz"` statements + head.Name = ref[0].Value.(Var) + head.Key = ref[1] + } + head.Location = rhs.Location + head.jsonOptions = rhs.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) + setJSONOptions(body, &rhs.jsonOptions) rule := &Rule{ - Location: rhs.Location, - Head: &Head{ - Location: rhs.Location, - Name: name, - Key: key, - Value: rhs, - }, - Body: NewBody( - NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location), - ), - Module: module, + Location: rhs.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: rhs.jsonOptions, } return rule, nil @@ -308,30 +337,34 @@ func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { ref, ok := term.Value.(Ref) - if !ok { + if !ok || len(ref) == 1 { return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) } - - if len(ref) != 2 { - return nil, fmt.Errorf("refs cannot be used for rule") + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) } - name, ok := ref[0].Value.(Var) - if !ok { - return nil, fmt.Errorf("%vs cannot be used as rule name", TypeName(ref[0].Value)) + head := RefHead(ref) + if len(ref) == 2 { + v, ok := ref[0].Value.(Var) + if !ok { + return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + } + head = NewHead(v) + head.Key = ref[1] } + head.Location = term.Location + head.jsonOptions = term.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location)) + setJSONOptions(body, &term.jsonOptions) rule := &Rule{ - Location: term.Location, - Head: &Head{ - Location: term.Location, - Name: name, - Key: ref[1], - }, - Body: NewBody( - NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location), - ), - Module: module, + Location: term.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: term.jsonOptions, } return rule, nil @@ -350,22 +383,24 @@ func ParseRuleFromCallEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { if !ok { return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(call[0].Value)) } - - name, ok := ref[0].Value.(Var) - if !ok { - return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(ref[0].Value)) + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) } + head := RefHead(ref, rhs) + head.Location = lhs.Location + head.Args = Args(call[1:]) + head.jsonOptions = lhs.jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)) + setJSONOptions(body, &rhs.jsonOptions) + rule := &Rule{ - Location: lhs.Location, - Head: &Head{ - Location: lhs.Location, - Name: name, - Args: Args(call[1:]), - Value: rhs, - }, - Body: 
NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)), - Module: module, + Location: lhs.Location, + Head: head, + Body: body, + Module: module, + jsonOptions: lhs.jsonOptions, } return rule, nil @@ -380,19 +415,24 @@ func ParseRuleFromCallExpr(module *Module, terms []*Term) (*Rule, error) { } loc := terms[0].Location - args := terms[1:] - value := BooleanTerm(true).SetLocation(loc) + ref := terms[0].Value.(Ref) + if _, ok := ref[0].Value.(Var); !ok { + return nil, fmt.Errorf("invalid rule head: %v", ref) + } + head := RefHead(ref, BooleanTerm(true).SetLocation(loc)) + head.Location = loc + head.Args = terms[1:] + head.jsonOptions = terms[0].jsonOptions + + body := NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)) + setJSONOptions(body, &terms[0].jsonOptions) rule := &Rule{ - Location: loc, - Head: &Head{ - Location: loc, - Name: Var(terms[0].String()), - Args: args, - Value: value, - }, - Module: module, - Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)), + Location: loc, + Head: head, + Module: module, + Body: body, + jsonOptions: terms[0].jsonOptions, } return rule, nil } @@ -435,10 +475,13 @@ func ParseModuleWithOpts(filename, input string, popts ParserOptions) (*Module, // ParseBody returns exactly one body. // If multiple bodies are parsed, an error is returned. func ParseBody(input string) (Body, error) { - return ParseBodyWithOpts(input, ParserOptions{}) + return ParseBodyWithOpts(input, ParserOptions{SkipRules: true}) } +// ParseBodyWithOpts returns exactly one body. It does _not_ set SkipRules: true on its own, +// but respects whatever ParserOptions it's been given. func ParseBodyWithOpts(input string, popts ParserOptions) (Body, error) { + stmts, _, err := ParseStatementsWithOpts("", input, popts) if err != nil { return nil, err @@ -467,7 +510,7 @@ func ParseBodyWithOpts(input string, popts ParserOptions) (Body, error) { func ParseExpr(input string) (*Expr, error) { body, err := ParseBody(input) if err != nil { - return nil, errors.Wrap(err, "failed to parse expression") + return nil, fmt.Errorf("failed to parse expression: %w", err) } if len(body) != 1 { return nil, fmt.Errorf("expected exactly one expression but got: %v", body) @@ -494,7 +537,7 @@ func ParsePackage(input string) (*Package, error) { func ParseTerm(input string) (*Term, error) { body, err := ParseBody(input) if err != nil { - return nil, errors.Wrap(err, "failed to parse term") + return nil, fmt.Errorf("failed to parse term: %w", err) } if len(body) != 1 { return nil, fmt.Errorf("expected exactly one term but got: %v", body) @@ -510,7 +553,7 @@ func ParseTerm(input string) (*Term, error) { func ParseRef(input string) (Ref, error) { term, err := ParseTerm(input) if err != nil { - return nil, errors.Wrap(err, "failed to parse ref") + return nil, fmt.Errorf("failed to parse ref: %w", err) } ref, ok := term.Value.(Ref) if !ok { @@ -519,15 +562,15 @@ func ParseRef(input string) (Ref, error) { return ref, nil } -// ParseRule returns exactly one rule. +// ParseRuleWithOpts returns exactly one rule. // If multiple rules are parsed, an error is returned. 
-func ParseRule(input string) (*Rule, error) {
-	stmts, _, err := ParseStatements("", input)
+func ParseRuleWithOpts(input string, opts ParserOptions) (*Rule, error) {
+	stmts, _, err := ParseStatementsWithOpts("", input, opts)
 	if err != nil {
 		return nil, err
 	}
 	if len(stmts) != 1 {
-		return nil, fmt.Errorf("expected exactly one statement (rule)")
+		return nil, fmt.Errorf("expected exactly one statement (rule), got %d: %v", len(stmts), stmts)
 	}
 	rule, ok := stmts[0].(*Rule)
 	if !ok {
@@ -536,6 +579,12 @@
 	return rule, nil
 }
 
+// ParseRule returns exactly one rule.
+// If multiple rules are parsed, an error is returned.
+func ParseRule(input string) (*Rule, error) {
+	return ParseRuleWithOpts(input, ParserOptions{})
+}
+
 // ParseStatement returns exactly one statement.
 // A statement might be a term, expression, rule, etc. Regardless,
 // this function expects *exactly* one statement. If multiple
@@ -566,7 +615,10 @@ func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Sta
 		WithProcessAnnotation(popts.ProcessAnnotation).
 		WithFutureKeywords(popts.FutureKeywords...).
 		WithAllFutureKeywords(popts.AllFutureKeywords).
-		WithCapabilities(popts.Capabilities)
+		WithCapabilities(popts.Capabilities).
+		WithSkipRules(popts.SkipRules).
+		WithJSONOptions(popts.JSONOptions).
+		withUnreleasedKeywords(popts.unreleasedKeywords)
 
 	stmts, comments, errs := parser.Parse()
 
@@ -585,14 +637,15 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu
 
 	var errs Errors
 
-	_package, ok := stmts[0].(*Package)
+	pkg, ok := stmts[0].(*Package)
 	if !ok {
 		loc := stmts[0].Loc()
 		errs = append(errs, NewError(ParseErr, loc, "package expected"))
 	}
 
 	mod := &Module{
-		Package: _package,
+		Package: pkg,
+		stmts:   stmts,
 	}
 
 	// The comments slice only holds comments that were not their own statements.
@@ -609,14 +662,14 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu
 		rule, err := ParseRuleFromBody(mod, stmt)
 		if err != nil {
 			errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error()))
-		} else {
-			mod.Rules = append(mod.Rules, rule)
-
-			// NOTE(tsandall): the statement should now be interpreted as a
-			// rule so update the statement list. This is important for the
-			// logic below that associates annotations with statements.
-			stmts[i+1] = rule
+			continue
 		}
+		mod.Rules = append(mod.Rules, rule)
+
+		// NOTE(tsandall): the statement should now be interpreted as a
+		// rule so update the statement list. This is important for the
+		// logic below that associates annotations with statements.
+		stmts[i+1] = rule
 	case *Package:
 		errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected package"))
 	case *Annotations:
@@ -632,34 +685,7 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu
 		return nil, errs
 	}
 
-	// Find first non-annotation statement following each annotation and attach
-	// the annotation to that statement.
-	for _, a := range mod.Annotations {
-		for _, stmt := range stmts {
-			_, ok := stmt.(*Annotations)
-			if !ok {
-				if stmt.Loc().Row > a.Location.Row {
-					a.node = stmt
-					break
-				}
-			}
-		}
-
-		if a.Scope == "" {
-			switch a.node.(type) {
-			case *Rule:
-				a.Scope = annotationScopeRule
-			case *Package:
-				a.Scope = annotationScopePackage
-			case *Import:
-				a.Scope = annotationScopeImport
-			}
-		}
-
-		if err := validateAnnotationScopeAttachment(a); err != nil {
-			errs = append(errs, err)
-		}
-	}
+	errs = append(errs, attachAnnotationsNodes(mod)...)
if len(errs) > 0 { return nil, errs @@ -668,24 +694,6 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu return mod, nil } -func validateAnnotationScopeAttachment(a *Annotations) *Error { - - switch a.Scope { - case annotationScopeRule, annotationScopeDocument: - if _, ok := a.node.(*Rule); ok { - return nil - } - return newScopeAttachmentErr(a, "rule") - case annotationScopePackage, annotationScopeSubpackages: - if _, ok := a.node.(*Package); ok { - return nil - } - return newScopeAttachmentErr(a, "package") - } - - return NewError(ParseErr, a.Loc(), "invalid annotation scope '%v'", a.Scope) -} - func newScopeAttachmentErr(a *Annotations, want string) *Error { var have string if a.node != nil { @@ -701,6 +709,16 @@ func setRuleModule(rule *Rule, module *Module) { } } +func setJSONOptions(x interface{}, jsonOptions *JSONOptions) { + vis := NewGenericVisitor(func(x interface{}) bool { + if x, ok := x.(customJSON); ok { + x.setJSONOptions(*jsonOptions) + } + return false + }) + vis.Walk(x) +} + // ParserErrorDetail holds additional details for parser errors. type ParserErrorDetail struct { Line string `json:"line"` diff --git a/ast/parser_test.go b/ast/parser_test.go index 56fc775427..5c8422cc7d 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -7,6 +7,7 @@ package ast import ( "bytes" "encoding/json" + "errors" "fmt" "reflect" "strings" @@ -692,6 +693,10 @@ func TestSomeDeclExpr(t *testing.T) { }, }, opts) + assertParseErrorContains(t, "not some", "not some x, y in xs", + "unexpected some keyword: illegal negation of 'some'", + opts) + assertParseErrorContains(t, "some + function call", "some f(x)", "expected `x in xs` or `x, y in xs` expression") @@ -767,6 +772,92 @@ func TestSomeDeclExpr(t *testing.T) { NewExpr(VarTerm("x")), ), }) + + assertParseOneExpr(t, "with modifier on expr", "some x, y in input with input as []", + &Expr{ + Terms: &SomeDecl{ + Symbols: []*Term{ + MemberWithKey.Call( + VarTerm("x"), + VarTerm("y"), + NewTerm(MustParseRef("input")), + ), + }, + }, + With: []*With{{Value: ArrayTerm(), Target: NewTerm(MustParseRef("input"))}}, + }, opts) + + assertParseErrorContains(t, "invalid domain (internal.member_2)", "some internal.member_2()", "illegal domain", opts) + assertParseErrorContains(t, "invalid domain (internal.member_3)", "some internal.member_3()", "illegal domain", opts) + +} + +func TestEvery(t *testing.T) { + opts := ParserOptions{unreleasedKeywords: true, FutureKeywords: []string{"every"}} + assertParseOneExpr(t, "simple", "every x in xs { true }", + &Expr{ + Terms: &Every{ + Value: VarTerm("x"), + Domain: VarTerm("xs"), + Body: []*Expr{ + NewExpr(BooleanTerm(true)), + }, + }, + }, + opts) + + assertParseOneExpr(t, "with key", "every k, v in [1,2] { true }", + &Expr{ + Terms: &Every{ + Key: VarTerm("k"), + Value: VarTerm("v"), + Domain: ArrayTerm(IntNumberTerm(1), IntNumberTerm(2)), + Body: []*Expr{ + NewExpr(BooleanTerm(true)), + }, + }, + }, opts) + + assertParseErrorContains(t, "arbitrary term", "every 10", "expected `x[, y] in xs { ... }` expression", opts) + assertParseErrorContains(t, "non-var value", "every 10 in xs { true }", "unexpected { token: expected value to be a variable", opts) + assertParseErrorContains(t, "non-var key", "every 10, x in xs { true }", "unexpected { token: expected key to be a variable", opts) + assertParseErrorContains(t, "arbitrary call", "every f(10)", "expected `x[, y] in xs { ... 
}` expression", opts) + assertParseErrorContains(t, "no body", "every x in xs", "missing body", opts) + assertParseErrorContains(t, "invalid body", "every x in xs { + }", "unexpected plus token", opts) + assertParseErrorContains(t, "not every", "not every x in xs { true }", "unexpected every keyword: illegal negation of 'every'", opts) + + assertParseOneExpr(t, `"every" kw implies "in" kw`, "x in xs", Member.Expr( + VarTerm("x"), + VarTerm("xs"), + ), opts) + + assertParseOneExpr(t, "with modifier on expr", "every x in input { x } with input as []", + &Expr{ + Terms: &Every{ + Value: VarTerm("x"), + Domain: NewTerm(MustParseRef("input")), + Body: []*Expr{ + NewExpr(VarTerm("x")), + }, + }, + With: []*With{{Value: ArrayTerm(), Target: NewTerm(MustParseRef("input"))}}, + }, opts) + + assertParseErrorContains(t, "every x, y in ... usage is hinted properly", ` + p { + every x, y in {"foo": "bar"} { is_string(x); is_string(y) } + }`, + "unexpected ident token: expected \\n or ; or } (hint: `import future.keywords.every` for `every x in xs { ... }` expressions)") + + assertParseErrorContains(t, "not every 'every' gets a hint", ` + p { + every x + }`, + "unexpected ident token: expected \\n or ; or }\n\tevery x\n", // this asserts that the tail of the error message doesn't contain a hint + ) + + assertParseErrorContains(t, "invalid domain (internal.member_2)", "every internal.member_2()", "illegal domain", opts) + assertParseErrorContains(t, "invalid domain (internal.member_3)", "every internal.member_3()", "illegal domain", opts) } func TestNestedExpressions(t *testing.T) { @@ -916,71 +1007,102 @@ func TestChainedCall(t *testing.T) { } func TestMultiLineBody(t *testing.T) { - - input1 := ` - x = 1 - y = 2 - z = [ i | [x,y] = arr - arr[_] = i] - ` - - body1, err := ParseBody(input1) - if err != nil { - t.Fatalf("Unexpected parse error on enclosed body: %v", err) - } - - expected1 := MustParseBody(`x = 1; y = 2; z = [i | [x,y] = arr; arr[_] = i]`) - - if !body1.Equal(expected1) { - t.Errorf("Expected enclosed body to equal %v but got: %v", expected1, body1) - } - - // Check that parser can handle multiple expressions w/o enclosing braces. 
- input2 := ` - x = 1 ; # comment after semicolon - y = 2 # comment without semicolon - z = [ i | [x,y] = arr # comment in comprehension - arr[_] = i] - ` - - body2, err := ParseBody(input2) - if err != nil { - t.Fatalf("Unexpected parse error on enclosed body: %v", err) + tests := []struct { + note string + input string + exp Body + }{ + { + note: "three definitions", + input: ` +x = 1 +y = 2 +z = [ i | [x,y] = arr + arr[_] = i] +`, + exp: MustParseBody(`x = 1; y = 2; z = [i | [x,y] = arr; arr[_] = i]`), + }, + { + note: "three definitions, with comments and w/o enclosing braces", + input: ` +x = 1 ; # comment after semicolon +y = 2 # comment without semicolon +z = [ i | [x,y] = arr # comment in comprehension + arr[_] = i] +`, + exp: MustParseBody(`x = 1; y = 2; z = [i | [x,y] = arr; arr[_] = i]`), + }, + { + note: "array following call w/ whitespace", + input: "f(x)\n [1]", + exp: NewBody( + NewExpr([]*Term{RefTerm(VarTerm("f")), VarTerm("x")}), + NewExpr(ArrayTerm(IntNumberTerm(1))), + ), + }, + { + note: "set following call w/ semicolon", + input: "f(x);{1}", + exp: NewBody( + NewExpr([]*Term{RefTerm(VarTerm("f")), VarTerm("x")}), + NewExpr(SetTerm(IntNumberTerm(1))), + ), + }, + { + note: "array following array w/ whitespace", + input: "[1]\n [2]", + exp: NewBody( + NewExpr(ArrayTerm(IntNumberTerm(1))), + NewExpr(ArrayTerm(IntNumberTerm(2))), + ), + }, + { + note: "array following set w/ whitespace", + input: "{1}\n [2]", + exp: NewBody( + NewExpr(SetTerm(IntNumberTerm(1))), + NewExpr(ArrayTerm(IntNumberTerm(2))), + ), + }, + { + note: "set following call w/ whitespace", + input: "f(x)\n {1}", + exp: NewBody( + NewExpr([]*Term{RefTerm(VarTerm("f")), VarTerm("x")}), + NewExpr(SetTerm(IntNumberTerm(1))), + ), + }, + { + note: "set following ref w/ whitespace", + input: "data.p.q\n {1}", + exp: NewBody( + NewExpr(&Term{Value: MustParseRef("data.p.q")}), + NewExpr(SetTerm(IntNumberTerm(1))), + ), + }, + { + note: "set following variable w/ whitespace", + input: "input\n {1}", + exp: NewBody( + NewExpr(&Term{Value: MustParseRef("input")}), + NewExpr(SetTerm(IntNumberTerm(1))), + ), + }, + { + note: "set following equality w/ whitespace", + input: "input = 2 \n {1}", + exp: NewBody( + Equality.Expr(&Term{Value: MustParseRef("input")}, IntNumberTerm(2)), + NewExpr(SetTerm(IntNumberTerm(1))), + ), + }, } - if !body2.Equal(expected1) { - t.Errorf("Expected unenclosed body to equal %v but got: %v", expected1, body1) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + assertParseOneBody(t, tc.note, tc.input, tc.exp) + }) } - - assertParseOneBody(t, "whitespace following call", "f(x)\t\n [1]", NewBody( - NewExpr( - []*Term{ - RefTerm(VarTerm("f")), - VarTerm("x"), - }, - ), - NewExpr( - ArrayTerm(IntNumberTerm(1)), - ), - )) - - assertParseOneBody(t, "whitespace following array", "[1]\t\n [2]", NewBody( - NewExpr( - ArrayTerm(IntNumberTerm(1)), - ), - NewExpr( - ArrayTerm(IntNumberTerm(2)), - ), - )) - - assertParseOneBody(t, "whitespace following set", "{1}\t\n {2}", NewBody( - NewExpr( - SetTerm(IntNumberTerm(1)), - ), - NewExpr( - SetTerm(IntNumberTerm(2)), - ), - )) } func TestBitwiseOrVsComprehension(t *testing.T) { @@ -1108,12 +1230,38 @@ func TestImport(t *testing.T) { } func TestFutureImports(t *testing.T) { - assertParseErrorContains(t, "future", "import future", "invalid import, use `import future.keywords` or `import.future.keywords.in`") + assertParseErrorContains(t, "future", "import future", "invalid import, must be `future.keywords`") assertParseErrorContains(t, 
"future.a", "import future.a", "invalid import, must be `future.keywords`") - assertParseErrorContains(t, "unknown keyword", "import future.keywords.xyz", "unexpected keyword, must be one of [in]") - assertParseErrorContains(t, "all keyword import + alias", "import future.keywords as xyz", "future keyword imports cannot be aliased") - assertParseErrorContains(t, "keyword import + alias", "import future.keywords.in as xyz", "future keyword imports cannot be aliased") + assertParseErrorContains(t, "unknown keyword", "import future.keywords.xyz", "unexpected keyword, must be one of [contains every if in]") + assertParseErrorContains(t, "all keyword import + alias", "import future.keywords as xyz", "`future` imports cannot be aliased") + assertParseErrorContains(t, "keyword import + alias", "import future.keywords.in as xyz", "`future` imports cannot be aliased") + + assertParseImport(t, "import kw with kw in options", + "import future.keywords.in", &Import{Path: RefTerm(VarTerm("future"), StringTerm("keywords"), StringTerm("in"))}, + ParserOptions{FutureKeywords: []string{"in"}}) + assertParseImport(t, "import kw with all kw in options", + "import future.keywords.in", &Import{Path: RefTerm(VarTerm("future"), StringTerm("keywords"), StringTerm("in"))}, + ParserOptions{AllFutureKeywords: true}) + + mod := ` + package p + import future.keywords + import future.keywords.in + ` + parsed := Module{ + Package: MustParseStatement(`package p`).(*Package), + Imports: []*Import{ + MustParseStatement("import future.keywords").(*Import), + MustParseStatement("import future.keywords.in").(*Import), + }, + } + assertParseModule(t, "multiple imports, all kw in options", mod, &parsed, ParserOptions{AllFutureKeywords: true}) + assertParseModule(t, "multiple imports, single in options", mod, &parsed, ParserOptions{FutureKeywords: []string{"in"}}) +} +func TestFutureImportsExtraction(t *testing.T) { + // These tests assert that "import future..." statements in policies cause + // the proper keywords to be added to the parser's list of known keywords. 
tests := []struct { note, imp string exp map[string]tokens.Token @@ -1128,6 +1276,13 @@ func TestFutureImports(t *testing.T) { imp: "import future.keywords", exp: map[string]tokens.Token{"in": tokens.In}, }, + { + note: "all keywords + single keyword imported", + imp: ` + import future.keywords + import future.keywords.in`, + exp: map[string]tokens.Token{"in": tokens.In}, + }, } for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { @@ -1256,6 +1411,17 @@ func TestRule(t *testing.T) { Body: NewBody(NewExpr(BooleanTerm(true))), }) + assertParseRule(t, "default w/ assignment", `default allow := false`, &Rule{ + Default: true, + Head: &Head{ + Name: "allow", + Reference: Ref{VarTerm("allow")}, + Value: BooleanTerm(false), + Assign: true, + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }) + assertParseRule(t, "default w/ comprehension", `default widgets = [x | x = data.fooz[_]]`, &Rule{ Default: true, Head: NewHead(Var("widgets"), nil, MustParseTerm(`[x | x = data.fooz[_]]`)), @@ -1275,9 +1441,10 @@ func TestRule(t *testing.T) { }) fxy := &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: VarTerm("y"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: VarTerm("y"), } assertParseRule(t, "identity", `f(x) = y { y = x }`, &Rule{ @@ -1289,9 +1456,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "composite arg", `f([x, y]) = z { split(x, y, z) }`, &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{ArrayTerm(VarTerm("x"), VarTerm("y"))}, - Value: VarTerm("z"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{ArrayTerm(VarTerm("x"), VarTerm("y"))}, + Value: VarTerm("z"), }, Body: NewBody( Split.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z")), @@ -1300,9 +1468,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "composite result", `f(1) = [x, y] { split("foo.bar", x, y) }`, &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{IntNumberTerm(1)}, - Value: ArrayTerm(VarTerm("x"), VarTerm("y")), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{IntNumberTerm(1)}, + Value: ArrayTerm(VarTerm("x"), VarTerm("y")), }, Body: NewBody( Split.Expr(StringTerm("foo.bar"), VarTerm("x"), VarTerm("y")), @@ -1311,7 +1480,8 @@ func TestRule(t *testing.T) { assertParseRule(t, "expr terms: key", `p[f(x) + g(x)] { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Key: Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), CallTerm(RefTerm(VarTerm("g")), VarTerm("x")), @@ -1322,7 +1492,8 @@ func TestRule(t *testing.T) { assertParseRule(t, "expr terms: value", `p = f(x) + g(x) { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Value: Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), CallTerm(RefTerm(VarTerm("g")), VarTerm("x")), @@ -1333,7 +1504,8 @@ func TestRule(t *testing.T) { assertParseRule(t, "expr terms: args", `p(f(x) + g(x)) { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Args: Args{ Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), @@ -1347,8 +1519,52 @@ func TestRule(t *testing.T) { assertParseRule(t, "assignment operator", `x := 1 { true }`, &Rule{ Head: &Head{ - Name: Var("x"), - Value: IntNumberTerm(1), + Name: Var("x"), + Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(1), + Assign: true, + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }) + + assertParseRule(t, "else assignment", `x := 1 { false } else := 2`, &Rule{ + Head: &Head{ + Name: "x", // 
ha! clever! + Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(1), + Assign: true, + }, + Body: NewBody(NewExpr(BooleanTerm(false))), + Else: &Rule{ + Head: &Head{ + Name: "x", + Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(2), + Assign: true, + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }) + + assertParseRule(t, "partial assignment", `p[x] := y { true }`, &Rule{ + Head: &Head{ + Name: "p", + Reference: MustParseRef("p[x]"), + Value: VarTerm("y"), + Key: VarTerm("x"), + Assign: true, + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }) + + assertParseRule(t, "function assignment", `f(x) := y { true }`, &Rule{ + Head: &Head{ + Name: "f", + Reference: Ref{VarTerm("f")}, + Value: VarTerm("y"), + Args: Args{ + VarTerm("x"), + }, Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), @@ -1359,33 +1575,33 @@ func TestRule(t *testing.T) { assertParseErrorContains(t, "empty rule body", "p {}", "rego_parse_error: found empty body") assertParseErrorContains(t, "unmatched braces", `f(x) = y { trim(x, ".", y) `, `rego_parse_error: unexpected eof token: expected \n or ; or }`) - // TODO: how to highlight that assignment is incorrect here? assertParseErrorContains(t, "no output", `f(_) = { "foo" = "bar" }`, "rego_parse_error: unexpected eq token: expected rule value term") + assertParseErrorContains(t, "no output", `f(_) := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected function value term") + assertParseErrorContains(t, "no output", `f := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected rule value term") + assertParseErrorContains(t, "no output", `f[_] := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected rule value term") + assertParseErrorContains(t, "no output", `default f :=`, "rego_parse_error: unexpected assign token: expected default rule value term") // TODO(tsandall): improve error checking here. This is a common mistake // and the current error message is not very good. Need to investigate if the // parser can be improved. 
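	// (Concretely: the mistake in question is a trailing ";" before "}", as in
	// the input asserted below; today it fails with a generic unexpected-token
	// message rather than one that calls out the stray semicolon.)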
assertParseError(t, "dangling semicolon", "p { true; false; }") - assertParseErrorContains(t, "default assignment", "default p := 1", `default rules must use = operator (not := operator)`) - assertParseErrorContains(t, "partial assignment", `p[x] := y { true }`, "partial rules must use = operator (not := operator)") - assertParseErrorContains(t, "function assignment", `f(x) := y { true }`, "functions must use = operator (not := operator)") - assertParseErrorContains(t, "else assignment", `p := y { true } else = 2 { true } `, "else keyword cannot be used on rule declared with := operator") - assertParseErrorContains(t, "default invalid rule name", `default 0[0`, "unexpected default keyword") - assertParseErrorContains(t, "default invalid rule value", `default a[0`, "illegal default rule (must have a value)") + assertParseErrorContains(t, "default invalid rule value", `default a[0]`, "illegal default rule (must have a value)") assertParseRule(t, "default missing value", `default a`, &Rule{ Default: true, Head: &Head{ - Name: Var("a"), - Value: BooleanTerm(true), + Name: Var("a"), + Reference: Ref{VarTerm("a")}, + Value: BooleanTerm(true), }, Body: NewBody(NewExpr(BooleanTerm(true))), }) assertParseRule(t, "empty arguments", `f() { x := 1 }`, &Rule{ Head: &Head{ - Name: "f", - Value: BooleanTerm(true), + Name: "f", + Reference: Ref{VarTerm("f")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x := 1`), }) @@ -1396,190 +1612,797 @@ func TestRule(t *testing.T) { assertParseErrorContains(t, "default invalid rule head call", `default a = b`, "illegal default rule (value cannot contain var)") assertParseError(t, "extra braces", `{ a := 1 }`) - assertParseError(t, "invalid rule name dots", `a.b = x { x := 1 }`) - assertParseError(t, "invalid rule name dots and call", `a.b(x) { x := 1 }`) assertParseError(t, "invalid rule name hyphen", `a-b = x { x := 1 }`) assertParseRule(t, "wildcard name", `_ { x == 1 }`, &Rule{ Head: &Head{ - Name: "$0", - Value: BooleanTerm(true), + Name: "$0", + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x == 1`), }) assertParseRule(t, "partial object array key", `p[[a, 1, 2]] = x { a := 1; x := "foo" }`, &Rule{ Head: &Head{ - Name: "p", - Key: ArrayTerm(VarTerm("a"), NumberTerm("1"), NumberTerm("2")), - Value: VarTerm("x"), + Name: "p", + Reference: MustParseRef("p[[a,1,2]]"), + Key: ArrayTerm(VarTerm("a"), NumberTerm("1"), NumberTerm("2")), + Value: VarTerm("x"), }, Body: MustParseBody(`a := 1; x := "foo"`), }) assertParseError(t, "invalid rule body no separator", `p { a = "foo"bar }`) assertParseError(t, "invalid rule body no newline", `p { a b c }`) -} - -func TestRuleElseKeyword(t *testing.T) { - mod := `package test - - p { - "p0" - } - - p { - "p1" - } else { - "p1_e1" - } else = [null] { - "p1_e2" - } else = x { - x = "p1_e3" - } - - p { - "p2" - } - - f(x) { - x < 100 - } else = false { - x > 200 - } else { - x != 150 - } - - _ { - x > 0 - } else { - x == -1 - } else { - x > -100 - } - - nobody = 1 { - false - } else = 7 - - nobody_f(x) = 1 { - false - } else = 7 - ` - parsed, err := ParseModule("", mod) - if err != nil { - t.Fatalf("Unexpected parse error: %v", err) - } + assertParseRule(t, "wildcard in else args", `f(_) { true } else := false`, &Rule{ + Head: &Head{ + Name: "f", + Reference: Ref{VarTerm("f")}, + Args: Args{ + VarTerm("$0"), + }, + Value: BooleanTerm(true), + }, + Body: MustParseBody(`true`), + Else: &Rule{ + Head: &Head{ + Name: "f", + Assign: true, + Reference: Ref{VarTerm("f")}, + Args: Args{ + 
VarTerm("$1"), + }, + Value: BooleanTerm(false), + }, + Body: MustParseBody(`true`), + }, + }) - name := Var("p") + name := Var("f") + ref := Ref{VarTerm("f")} tr := BooleanTerm(true) - head := &Head{Name: name, Value: tr} - - expected := &Module{ + head := func(v string) *Head { return &Head{Name: name, Reference: ref, Value: tr, Args: []*Term{VarTerm(v)}} } + assertParseModule(t, "wildcard in chained function heads", `package test + f(_) { true } { true } +`, &Module{ Package: MustParsePackage(`package test`), Rules: []*Rule{ { - Head: head, - Body: MustParseBody(`"p0"`), + Head: head("$0"), + Body: MustParseBody("true"), }, { - Head: head, - Body: MustParseBody(`"p1"`), - Else: &Rule{ - Head: head, - Body: MustParseBody(`"p1_e1"`), - Else: &Rule{ - Head: &Head{ - Name: Var("p"), - Value: ArrayTerm(NullTerm()), - }, - Body: MustParseBody(`"p1_e2"`), - Else: &Rule{ - Head: &Head{ - Name: name, - Value: VarTerm("x"), - }, - Body: MustParseBody(`x = "p1_e3"`), - }, - }, - }, + Head: head("$1"), + Body: MustParseBody("true"), }, - { - Head: head, - Body: MustParseBody(`"p2"`), + }, + }) +} + +func TestRuleContains(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"contains", "if"}} + + tests := []struct { + note string + rule string + exp *Rule + }{ + { + note: "simple", + rule: `p contains "x" { true }`, + exp: &Rule{ + Head: NewHead(Var("p"), StringTerm("x")), + Body: NewBody(NewExpr(BooleanTerm(true))), }, - { - Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(true), - }, - Body: MustParseBody(`x < 100`), - Else: &Rule{ - Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(false), - }, - Body: MustParseBody(`x > 200`), - Else: &Rule{ - Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(true), - }, - Body: MustParseBody(`x != 150`), - }, - }, + }, + { + note: "no body", + rule: `p contains "x"`, + exp: &Rule{ + Head: NewHead(Var("p"), StringTerm("x")), + Body: NewBody(NewExpr(BooleanTerm(true))), }, - - { + }, + { + note: "ref head, no body", + rule: `p.q contains "x"`, + exp: &Rule{ Head: &Head{ - Name: Var("$0"), - Value: BooleanTerm(true), - }, - Body: MustParseBody(`x > 0`), - Else: &Rule{ - Head: &Head{ - Name: Var("$0"), - Value: BooleanTerm(true), - }, - Body: MustParseBody(`x == -1`), - Else: &Rule{ - Head: &Head{ - Name: Var("$0"), - Value: BooleanTerm(true), - }, - Body: MustParseBody(`x > -100`), - }, + Reference: MustParseRef("p.q"), + Key: StringTerm("x"), }, + Body: NewBody(NewExpr(BooleanTerm(true))), }, - { + }, + { + note: "ref head", + rule: `p.q contains "x" { true }`, + exp: &Rule{ Head: &Head{ - Name: Var("nobody"), - Value: IntNumberTerm(1), + Reference: MustParseRef("p.q"), + Key: StringTerm("x"), }, - Body: MustParseBody("false"), - Else: &Rule{ - Head: &Head{ - Name: Var("nobody"), - Value: IntNumberTerm(7), - }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "set with var element", + rule: `deny contains msg { msg := "nonono" }`, + exp: &Rule{ + Head: NewHead(Var("deny"), VarTerm("msg")), + Body: MustParseBody(`msg := "nonono"`), + }, + }, + { + note: "set with object elem", + rule: `deny contains {"allow": false, "msg": msg} { msg := "nonono" }`, + exp: &Rule{ + Head: NewHead(Var("deny"), MustParseTerm(`{"allow": false, "msg": msg}`)), + Body: MustParseBody(`msg := "nonono"`), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + assertParseRule(t, tc.note, tc.rule, tc.exp, opts) + }) + } +} + +func 
TestRuleContainsFail(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"contains", "if", "every"}} + + tests := []struct { + note string + rule string + expected string + }{ + { + note: "contains used with a 1+ argument function", + rule: "p(a) contains x { x := a }", + expected: "the contains keyword can only be used with multi-value rule definitions (e.g., p contains { ... })", + }, + { + note: "contains used with a 0 argument function", + rule: "p() contains x { x := 1 }", + expected: "the contains keyword can only be used with multi-value rule definitions (e.g., p contains { ... })", + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + assertParseErrorContains(t, tc.note, tc.rule, tc.expected, opts) + }) + } +} + +func TestRuleIf(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"contains", "if", "every"}} + + tests := []struct { + note string + rule string + exp *Rule + }{ + { + note: "complete", + rule: `p if { true }`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "else", + rule: `p if { true } else if { true }`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody(NewExpr(BooleanTerm(true))), + Else: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + }, + { + note: "ref head, complete", + rule: `p.q if { true }`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q"), + Value: BooleanTerm(true), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "complete, normal body", + rule: `p if { x := 10; x > y }`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: MustParseBody(`x := 10; x > y`), + }, + }, + { + note: "complete+else, normal bodies, assign", + rule: `p := "yes" if { 10 > y } else := "no" { 10 <= y }`, + exp: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("yes"), + Assign: true, + }, + Body: MustParseBody(`10 > y`), + Else: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("no"), + Assign: true, + }, + Body: MustParseBody(`10 <= y`), + }, + }, + }, + { + note: "complete+else, normal bodies, assign; if", + rule: `p := "yes" if { 10 > y } else := "no" if { 10 <= y }`, + exp: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("yes"), + Assign: true, + }, + Body: MustParseBody(`10 > y`), + Else: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("no"), + Assign: true, + }, + Body: MustParseBody(`10 <= y`), + }, + }, + }, + { + note: "complete, shorthand", + rule: `p if true`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "complete, else, shorthand", + rule: `p if true else if true`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody(NewExpr(BooleanTerm(true))), + Else: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + }, + { + note: "complete, else, assignment+shorthand", + rule: `p if true else := 3 if 2 < 1`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody(NewExpr(BooleanTerm(true))), + Else: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: NumberTerm("3"), + Assign: true, + }, + Body: 
NewBody(LessThan.Expr(IntNumberTerm(2), IntNumberTerm(1))), + }, + }, + }, + { + note: "complete+not, shorthand", + rule: `p if not q`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: MustParseBody(`not q`), + }, + }, + { + note: "complete+else, shorthand", + rule: `p if 1 > 2 else = 42 { 2 > 1 }`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: MustParseBody(`1 > 2`), + Else: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: NumberTerm("42"), + }, + Body: MustParseBody(`2 > 1`), + }, + }, + }, + { + note: "complete+call, shorthand", + rule: `p if count(q) > 0`, + exp: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: MustParseBody(`count(q) > 0`), + }, + }, + { + note: "function, shorthand", + rule: `f(x) = y if y := x + 1`, + exp: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("f")}, + Name: Var("f"), + Args: []*Term{VarTerm("x")}, + Value: VarTerm("y"), + }, + Body: MustParseBody(`y := x + 1`), + }, + }, + { + note: "function+every, shorthand", + rule: `f(xs) if every x in xs { x != 0 }`, + exp: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("f")}, + Name: Var("f"), + Args: []*Term{VarTerm("xs")}, + Value: BooleanTerm(true), + }, + Body: MustParseBodyWithOpts(`every x in xs { x != 0 }`, opts), + }, + }, + { + note: "object", + rule: `p["foo"] = "bar" if { true }`, + exp: &Rule{ + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p.foo"), + Value: StringTerm("bar"), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "object, shorthand", + rule: `p["foo"] = "bar" if true`, + exp: &Rule{ + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p.foo"), + Value: StringTerm("bar"), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "object with vars", + rule: `p[x] = y if { + x := "foo" + y := "bar" + }`, + exp: &Rule{ + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p[x]"), + Value: VarTerm("y"), + }, + Body: MustParseBody(`x := "foo"; y := "bar"`), + }, + }, + { + note: "set", + rule: `p contains "foo" if { true }`, + exp: &Rule{ + Head: NewHead(Var("p"), StringTerm("foo")), + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "set, shorthand", + rule: `p contains "foo" if true`, + exp: &Rule{ + Head: NewHead(Var("p"), StringTerm("foo")), + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "set+var+shorthand", + rule: `p contains x if { x := "foo" }`, + exp: &Rule{ + Head: NewHead(Var("p"), VarTerm("x")), + Body: MustParseBody(`x := "foo"`), + }, + }, + { + note: "partial set+if, shorthand", // these are now Head.Ref rules, previously forbidden + rule: `p[x] if x := 1`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p[x]"), + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x := 1`), + }, + }, + { + note: "partial set+if", // these are now Head.Ref rules, previously forbidden + rule: `p[x] if { x := 1 }`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p[x]"), + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x := 1`), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + assertParseRule(t, tc.note, tc.rule, tc.exp, opts) + }) + } +} + +func TestRuleRefHeads(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"contains", "if", "every"}} + trueBody := NewBody(NewExpr(BooleanTerm(true))) + + tests := []struct { + note string + rule string + exp *Rule + }{ + { + note: "single-value rule", + rule: "p.q.r = 1 if true", + exp: &Rule{ + 
Head: &Head{ + Reference: MustParseRef("p.q.r"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, string key", + rule: `p.q["r"] = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, number key", + rule: `p.q[2] = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[2]"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, no value", + rule: `p.q[2] if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[2]"), + Value: BooleanTerm(true), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, var key", + rule: `p.q[x] = 1 if x := 2`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[x]"), + Value: IntNumberTerm(1), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "single-value with brackets, var key, no dot", + rule: `p[x] = 1 if x := 2`, + exp: &Rule{ + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p[x]"), + Value: IntNumberTerm(1), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "multi-value, simple", + rule: `p.q.r contains x if x := 2`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Key: VarTerm("x"), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: multi-value, no dot", + rule: `p[x] { x := 2 }`, // no "if", which triggers ref-interpretation + exp: &Rule{ + Head: &Head{ + Name: "p", + Reference: Ref{VarTerm("p")}, // we're defining p as multi-val rule + Key: VarTerm("x"), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: single-value, no dot", + rule: `p[x] = 3 { x := 2 }`, + exp: &Rule{ + Head: &Head{ + Name: "p", + Reference: MustParseRef("p[x]"), + Key: VarTerm("x"), // not used + Value: IntNumberTerm(3), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: single-value, no dot, complex object", + rule: `partialobj[x] = {"foo": y} { y = "bar"; x = y }`, + exp: &Rule{ + Head: &Head{ + Name: "partialobj", + Reference: MustParseRef("partialobj[x]"), + Key: VarTerm("x"), // not used + Value: MustParseTerm(`{"foo": y}`), + }, + Body: MustParseBody(`y = "bar"; x = y`), + }, + }, + { + note: "function, simple", + rule: `p.q.f(x) = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "function, no value", + rule: `p.q.f(x) if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: BooleanTerm(true), + }, + Body: trueBody, + }, + }, + { + note: "function, with value", + rule: `p.q.f(x) = x + 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: Plus.Call(VarTerm("x"), IntNumberTerm(1)), + }, + Body: trueBody, + }, + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + assertParseRule(t, tc.note, tc.rule, tc.exp, opts) + }) + } + + assertParseErrorContains(t, "first ref head term is call", `package p +q(0).r(0) { true }`, + "unexpected { token: rule head ref q(0).r invalid", opts) +} + +func TestRuleElseKeyword(t *testing.T) { + mod := `package test + + p { + "p0" + } + + p { + "p1" + } else { + "p1_e1" + } else = [null] { + "p1_e2" + } else = x { + x = "p1_e3" + } + + p { + "p2" + } + + f(x) { + x < 100 + } 
else = false { + x > 200 + } else { + x != 150 + } + + _ { + x > 0 + } else { + x == -1 + } else { + x > -100 + } + + nobody = 1 { + false + } else = 7 + + nobody_f(x) = 1 { + false + } else = 7 + ` + + parsed, err := ParseModule("", mod) + if err != nil { + t.Fatalf("Unexpected parse error: %v", err) + } + + name := Var("p") + ref := Ref{VarTerm("p")} + tr := BooleanTerm(true) + head := &Head{Name: name, Reference: ref, Value: tr} + + expected := &Module{ + Package: MustParsePackage(`package test`), + Rules: []*Rule{ + { + Head: head, + Body: MustParseBody(`"p0"`), + }, + { + Head: head, + Body: MustParseBody(`"p1"`), + Else: &Rule{ + Head: head, + Body: MustParseBody(`"p1_e1"`), + Else: &Rule{ + Head: &Head{ + Name: name, + Reference: ref, + Value: ArrayTerm(NullTerm()), + }, + Body: MustParseBody(`"p1_e2"`), + Else: &Rule{ + Head: &Head{ + Name: name, + Reference: ref, + Value: VarTerm("x"), + }, + Body: MustParseBody(`x = "p1_e3"`), + }, + }, + }, + }, + { + Head: head, + Body: MustParseBody(`"p2"`), + }, + { + Head: &Head{ + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x < 100`), + Else: &Rule{ + Head: &Head{ + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(false), + }, + Body: MustParseBody(`x > 200`), + Else: &Rule{ + Head: &Head{ + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x != 150`), + }, + }, + }, + + { + Head: &Head{ + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x > 0`), + Else: &Rule{ + Head: &Head{ + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x == -1`), + Else: &Rule{ + Head: &Head{ + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x > -100`), + }, + }, + }, + { + Head: &Head{ + Name: Var("nobody"), + Reference: Ref{VarTerm("nobody")}, + Value: IntNumberTerm(1), + }, + Body: MustParseBody("false"), + Else: &Rule{ + Head: &Head{ + Name: Var("nobody"), + Reference: Ref{VarTerm("nobody")}, + Value: IntNumberTerm(7), + }, Body: MustParseBody("true"), }, }, { Head: &Head{ - Name: Var("nobody_f"), - Args: Args{VarTerm("x")}, - Value: IntNumberTerm(1), + Name: Var("nobody_f"), + Reference: Ref{VarTerm("nobody_f")}, + Args: Args{VarTerm("x")}, + Value: IntNumberTerm(1), }, Body: MustParseBody("false"), Else: &Rule{ Head: &Head{ - Name: Var("nobody_f"), - Args: Args{VarTerm("x")}, - Value: IntNumberTerm(7), + Name: Var("nobody_f"), + Reference: Ref{VarTerm("nobody_f")}, + Args: Args{VarTerm("x")}, + Value: IntNumberTerm(7), }, Body: MustParseBody("true"), }, @@ -1606,14 +2429,16 @@ func TestRuleElseKeyword(t *testing.T) { Body: MustParseBody(`"p1_e1"`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: ArrayTerm(NullTerm()), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, + Value: ArrayTerm(NullTerm()), }, Body: MustParseBody(`"p1_e2"`), Else: &Rule{ Head: &Head{ - Name: name, - Value: VarTerm("x"), + Name: name, + Reference: ref, + Value: VarTerm("x"), }, Body: MustParseBody(`x = "p1_e4"`), }, @@ -1636,7 +2461,7 @@ func TestRuleElseKeyword(t *testing.T) { p[1] { false } else { true } `) - if err == nil || !strings.Contains(err.Error(), "else keyword cannot be used on partial rules") { + if err == nil || !strings.Contains(err.Error(), "else keyword cannot be used on multi-value 
rules") { t.Fatalf("Expected parse error but got: %v", err) } @@ -1660,6 +2485,97 @@ func TestRuleElseKeyword(t *testing.T) { } +func TestRuleElseRefHeads(t *testing.T) { + tests := []struct { + note string + rule string + exp *Rule + err string + }{ + { + note: "simple ref head", + rule: ` +a.b.c := 1 if false +else := 2 +`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("a.b.c"), + Value: NumberTerm("1"), + Assign: true, + }, + Body: MustParseBody("false"), + Else: &Rule{ + Head: &Head{ + Reference: MustParseRef("a.b.c"), + Value: NumberTerm("2"), + Assign: true, + }, + Body: MustParseBody("true"), + }, + }, + }, + { + note: "multi-value ref head", + rule: ` +a.b.c contains 1 if false +else := 2 +`, + err: "else keyword cannot be used on multi-value rules", + }, + { + note: "single-value ref head with var", + rule: ` +a.b[x] := 1 if false +else := 2 +`, + err: "else keyword cannot be used on rules with variables in head", + }, + { + note: "single-value ref head with length 1 (last is var)", + rule: ` +a := 1 if false +else := 2 +`, + exp: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("a")}, + Name: Var("a"), + Value: NumberTerm("1"), + Assign: true, + }, + Body: MustParseBody("false"), + Else: &Rule{ + Head: &Head{ + Reference: Ref{VarTerm("a")}, + Name: Var("a"), + Value: NumberTerm("2"), + Assign: true, + }, + Body: MustParseBody("true"), + }, + }, + }, + } + + opts := ParserOptions{FutureKeywords: []string{"if", "contains"}} + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + if tc.err != "" { + assertParseErrorContains(t, tc.note, tc.rule, tc.err, opts) + return + } + if tc.exp != nil { + testModule := "package test\n" + tc.rule + assertParseModule(t, tc.note, testModule, &Module{ + Package: MustParseStatement(`package test`).(*Package), + Rules: []*Rule{tc.exp}, + }, opts) + } + }) + } +} + func TestMultipleEnclosedBodies(t *testing.T) { result, err := ParseModule("", `package ex @@ -1713,7 +2629,6 @@ func TestEmptyModule(t *testing.T) { } func TestComments(t *testing.T) { - testModule := `package a.b.c import input.e.f as g # end of line @@ -1929,52 +2844,245 @@ func TestLocation(t *testing.T) { } } -func TestRuleFromBody(t *testing.T) { - testModule := `package a.b.c - -pi = 3.14159 -p[x] { x = 1 } -greeting = "hello" -cores = [{0: 1}, {1: 2}] -wrapper = cores[0][1] -pi = [3, 1, 4, x, y, z] -foo["bar"] = "buz" -foo["9"] = "10" -foo.buz = "bar" -bar[1] -bar[[{"foo":"baz"}]] -bar.qux -input = 1 -data = 2 -f(1) = 2 -f(1) -d1 := 1234 -` +func TestRuleFromBodyRefs(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"if", "contains"}} - assertParseModule(t, "rules from bodies", testModule, &Module{ - Package: MustParseStatement(`package a.b.c`).(*Package), - Rules: []*Rule{ - MustParseRule(`pi = 3.14159 { true }`), - MustParseRule(`p[x] { x = 1 }`), - MustParseRule(`greeting = "hello" { true }`), - MustParseRule(`cores = [{0: 1}, {1: 2}] { true }`), - MustParseRule(`wrapper = cores[0][1] { true }`), - MustParseRule(`pi = [3, 1, 4, x, y, z] { true }`), - MustParseRule(`foo["bar"] = "buz" { true }`), - MustParseRule(`foo["9"] = "10" { true }`), - MustParseRule(`foo["buz"] = "bar" { true }`), - MustParseRule(`bar[1] { true }`), - MustParseRule(`bar[[{"foo":"baz"}]] { true }`), - MustParseRule(`bar["qux"] { true }`), - MustParseRule(`input = 1 { true }`), - MustParseRule(`data = 2 { true }`), - MustParseRule(`f(1) = 2 { true }`), - MustParseRule(`f(1) = true { true }`), - MustParseRule("d1 := 1234 { true }"), + // NOTE(sr): These tests assert that the 
other code path, parsing a module, and + // then interpreting naked expressions into (shortcut) rule definitions, works + // the same as parsing the string as a Rule directly. Without also passing + // TestRuleRefHeads, these tests are not to be trusted -- if changing something, + // start with getting TestRuleRefHeads to PASS. + tests := []struct { + note string + rule string + exp string + }{ + { + note: "no dots: single-value rule (complete doc)", + rule: `foo["bar"] = 12`, + exp: `foo["bar"] = 12 { true }`, + }, + { + note: "no dots: partial set of numbers", + rule: `foo[1]`, + exp: `foo[1] { true }`, + }, + { + note: "no dots: shorthand set of strings", // back compat + rule: `foo.one`, + exp: `foo["one"] { true }`, + }, + { + note: "no dots: partial set", + rule: `foo[x] { x = 1 }`, + exp: `foo[x] { x = 1 }`, + }, + { + note: "no dots + if: complete doc", + rule: `foo[x] if x := 1`, + exp: `foo[x] if x := 1`, + }, + { + note: "no dots: function", + rule: `foo(x)`, + exp: `foo(x) { true }`, + }, + { + note: "no dots: function with value", + rule: `foo(x) = y`, + exp: `foo(x) = y { true }`, + }, + { + note: "no dots: partial set, ref element", + rule: `test[arr[0]]`, + exp: `test[arr[0]] { true }`, + }, + { + note: "one dot: complete rule shorthand", + rule: `foo.bar = "buz"`, + exp: `foo.bar = "buz" { true }`, + }, + { + note: "one dot, bracket with var: partial object", + rule: `foo.bar[x] = "buz"`, + exp: `foo.bar[x] = "buz" { true }`, + }, + { + note: "one dot, bracket with var: partial set", + rule: `foo.bar[x] { x = 1 }`, + exp: `foo.bar[x] { x = 1 }`, + }, + { + note: "one dot, bracket with string: complete doc", + rule: `foo.bar["baz"] = "buz"`, + exp: `foo.bar.baz = "buz" { true }`, + }, + { + note: "one dot, bracket with var, rule body: partial object", + rule: `foo.bar[x] = "buz" { x = 1 }`, + exp: `foo.bar[x] = "buz" { x = 1 }`, + }, + { + note: "one dot: function", + rule: `foo.bar(x)`, + exp: `foo.bar(x) { true }`, + }, + { + note: "one dot: function with value", + rule: `foo.bar(x) = y`, + exp: `foo.bar(x) = y { true }`, + }, + { + note: "two dots, bracket with var: partial object", + rule: `foo.bar.baz[x] = "buz" { x = 1 }`, + exp: `foo.bar.baz[x] = "buz" { x = 1 }`, + }, + { + note: "two dots, bracket with var: partial set", + rule: `foo.bar.baz[x] { x = 1 }`, + exp: `foo.bar.baz[x] { x = 1 }`, + }, + { + note: "one dot, bracket with string, no key: complete doc", + rule: `foo.bar["baz"]`, + exp: `foo.bar.baz { true }`, + }, + { + note: "two dots: function", + rule: `foo.bar("baz")`, + exp: `foo.bar("baz") { true }`, + }, + { + note: "two dots: function with value", + rule: `foo.bar("baz") = y`, + exp: `foo.bar("baz") = y { true }`, + }, + { + note: "non-ground ref: complete doc", + rule: `foo.bar[i].baz { i := 1 }`, + exp: `foo.bar[i].baz { i := 1 }`, + }, + { + note: "non-ground ref: partial set", + rule: `foo.bar[i].baz[x] { i := 1; x := 2 }`, + exp: `foo.bar[i].baz[x] { i := 1; x := 2 }`, + }, + { + note: "non-ground ref: partial object", + rule: `foo.bar[i].baz[x] = 3 { i := 1; x := 2 }`, + exp: `foo.bar[i].baz[x] = 3 { i := 1; x := 2 }`, + }, + { + note: "non-ground ref: function", + rule: `foo.bar[i].baz(x) = 3 { i := 1 }`, + exp: `foo.bar[i].baz(x) = 3 { i := 1 }`, + }, + { + note: "last term is number: partial set", + rule: `foo.bar.baz[3] { true }`, + exp: `foo.bar.baz[3] { true }`, }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + r, err := ParseRuleWithOpts(tc.exp, opts) + if err != nil { + t.Fatal(err) + } + + testModule := 
"package a.b.c\n" + tc.rule + m, err := ParseModuleWithOpts("", testModule, opts) + if err != nil { + t.Fatal(err) + } + mr := m.Rules[0] + + if r.Head.Name.Compare(mr.Head.Name) != 0 { + t.Errorf("rule.Head.Name differs:\n exp = %#v\nrule = %#v", r.Head.Name, mr.Head.Name) + } + if r.Head.Ref().Compare(mr.Head.Ref()) != 0 { + t.Errorf("rule.Head.Ref() differs:\n exp = %v\nrule = %v", r.Head.Ref(), mr.Head.Ref()) + } + exp, err := ParseRuleWithOpts(tc.exp, opts) + if err != nil { + t.Fatal(err) + } + assertParseModule(t, tc.note, testModule, &Module{ + Package: MustParseStatement(`package a.b.c`).(*Package), + Rules: []*Rule{exp}, + }, opts) + }) + } + + // edge cases + t.Run("errors", func(t *testing.T) { + t.Run("naked 'data' ref", func(t *testing.T) { + _, err := ParseModuleWithOpts("", "package a.b.c\ndata", opts) + assertErrorWithMessage(t, err, "refs cannot be used for rule head") + }) + t.Run("naked 'input' ref", func(t *testing.T) { + _, err := ParseModuleWithOpts("", "package a.b.c\ninput", opts) + assertErrorWithMessage(t, err, "refs cannot be used for rule head") + }) }) +} + +func assertErrorWithMessage(t *testing.T, err error, msg string) { + t.Helper() + var errs Errors + if !errors.As(err, &errs) { + t.Fatalf("expected Errors, got %v %[1]T", err) + } + if exp, act := 1, len(errs); exp != act { + t.Fatalf("expected %d errors, got %d", exp, act) + } + e := errs[0] + if exp, act := msg, e.Message; exp != act { + t.Fatalf("expected error message %q, got %q", exp, act) + } +} + +func TestRuleFromBody(t *testing.T) { + tests := []struct { + input string + exp string + }{ + {`pi = 3.14159`, `pi = 3.14159 { true }`}, + {`p[x] { x = 1 }`, `p[x] { x = 1 }`}, + {`greeting = "hello"`, `greeting = "hello" { true }`}, + {`cores = [{0: 1}, {1: 2}]`, `cores = [{0: 1}, {1: 2}] { true }`}, + {`wrapper = cores[0][1]`, `wrapper = cores[0][1] { true }`}, + {`pi = [3, 1, 4, x, y, z]`, `pi = [3, 1, 4, x, y, z] { true }`}, + {`foo["bar"] = "buz"`, `foo["bar"] = "buz" { true }`}, + {`foo["9"] = "10"`, `foo["9"] = "10" { true }`}, + {`foo.buz = "bar"`, `foo["buz"] = "bar" { true }`}, + {`bar[1]`, `bar[1] { true }`}, + {`bar[[{"foo":"baz"}]]`, `bar[[{"foo":"baz"}]] { true }`}, + {`bar.qux`, `bar["qux"] { true }`}, + {`input = 1`, `input = 1 { true }`}, + {`data = 2`, `data = 2 { true }`}, + {`f(1) = 2`, `f(1) = 2 { true }`}, + {`f(1)`, `f(1) = true { true }`}, + {`d1 := 1234`, "d1 := 1234 { true }"}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + testModule := "package a.b.c\n" + tc.input + assertParseModule(t, tc.input, testModule, &Module{ + Package: MustParseStatement(`package a.b.c`).(*Package), + Rules: []*Rule{ + MustParseRule(tc.exp), + }, + }) + }) + } // Verify the rule and rule and rule head col/loc values + testModule := "package a.b.c\n\n" + for _, tc := range tests { + testModule += tc.input + "\n" + } module, err := ParseModule("test.rego", testModule) if err != nil { t.Fatal(err) @@ -1983,19 +3091,19 @@ d1 := 1234 for i := range module.Rules { col := module.Rules[i].Location.Col if col != 1 { - t.Fatalf("expected rule %v column to be 1 but got %v", module.Rules[i].Head.Name, col) + t.Errorf("expected rule %v column to be 1 but got %v", module.Rules[i].Head.Name, col) } row := module.Rules[i].Location.Row - if row != 3+i { // 'pi' rule stats on row 3 - t.Fatalf("expected rule %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) + if row != 3+i { // 'pi' rule starts on row 3 + t.Errorf("expected rule %v row to be %v but got %v", 
module.Rules[i].Head.Name, 3+i, row) } col = module.Rules[i].Head.Location.Col if col != 1 { - t.Fatalf("expected rule head %v column to be 1 but got %v", module.Rules[i].Head.Name, col) + t.Errorf("expected rule head %v column to be 1 but got %v", module.Rules[i].Head.Name, col) } row = module.Rules[i].Head.Location.Row - if row != 3+i { // 'pi' rule stats on row 3 - t.Fatalf("expected rule head %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) + if row != 3+i { // 'pi' rule starts on row 3 + t.Errorf("expected rule head %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) } } @@ -2036,16 +3144,6 @@ data = {"bar": 2}` foo = input with input as 1 ` - badRefLen1 := ` - package a.b.c - - p["x"].y = 1` - - badRefLen2 := ` - package a.b.c - - p["x"].y` - negated := ` package a.b.c @@ -2104,8 +3202,6 @@ data = {"bar": 2}` assertParseModuleError(t, "non-equality", nonEquality) assertParseModuleError(t, "non-var name", nonVarName) assertParseModuleError(t, "with expr", withExpr) - assertParseModuleError(t, "bad ref (too long)", badRefLen1) - assertParseModuleError(t, "bad ref (too long)", badRefLen2) assertParseModuleError(t, "negated", negated) assertParseModuleError(t, "non ref term", nonRefTerm) assertParseModuleError(t, "zero args", zeroArgs) @@ -2169,7 +3265,8 @@ func TestWildcards(t *testing.T) { assertParseRule(t, "functions", `f(_) = y { true }`, &Rule{ Head: &Head{ - Name: Var("f"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, Args: Args{ VarTerm("$0"), }, @@ -2179,6 +3276,56 @@ func TestWildcards(t *testing.T) { }) } +func TestRuleFromBodyJSONOptions(t *testing.T) { + tests := []string{ + `pi = 3.14159`, + `p[x] { x = 1 }`, + `greeting = "hello"`, + `cores = [{0: 1}, {1: 2}]`, + `wrapper = cores[0][1]`, + `pi = [3, 1, 4, x, y, z]`, + `foo["bar"] = "buz"`, + `foo["9"] = "10"`, + `foo.buz = "bar"`, + `foo.fizz.buzz`, + `bar[1]`, + `bar[[{"foo":"baz"}]]`, + `bar.qux`, + `input = 1`, + `data = 2`, + `f(1) = 2`, + `f(1)`, + `d1 := 1234`, + } + + parserOpts := ParserOptions{ProcessAnnotation: true} + parserOpts.JSONOptions = &JSONOptions{ + MarshalOptions: JSONMarshalOptions{ + IncludeLocation: NodeToggle{ + Term: true, + Package: true, + Comment: true, + Import: true, + Rule: true, + Head: true, + Expr: true, + SomeDecl: true, + Every: true, + With: true, + Annotations: true, + AnnotationsRef: true, + }, + }, + } + + for _, tc := range tests { + t.Run(tc, func(t *testing.T) { + testModule := "package a.b.c\n" + tc + assertParseModuleJSONOptions(t, tc, testModule, parserOpts) + }) + } +} + func TestRuleModulePtr(t *testing.T) { mod := `package test @@ -2229,6 +3376,97 @@ func TestNoMatchError(t *testing.T) { } } +func TestBraceBracketParenMatchingErrors(t *testing.T) { + // Checks to prevent regression on issue #4672. + // Error location is important here, which is why we check + // the error strings directly. 
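	// Each expected string below pins down four things at once: the file:row
	// prefix, the rego_parse_error message, the echoed source line, and the
	// caret marking the error column. Any drift in the parser's location
	// tracking therefore fails these comparisons loudly.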
+ tests := []struct { + note string + err string + input string + }{ + { + note: "Unmatched ')' case", + err: `1 error occurred: test.rego:4: rego_parse_error: unexpected , token: expected \n or ; or } + y := contains("a"), "b") + ^`, + input: `package test +p { + x := 5 + y := contains("a"), "b") +}`, + }, + { + note: "Unmatched '}' case", + err: `1 error occurred: test.rego:4: rego_parse_error: unexpected , token: expected \n or ; or } + y := {"a", "b", "c"}, "a"} + ^`, + input: `package test +p { + x := 5 + y := {"a", "b", "c"}, "a"} +}`, + }, + { + note: "Unmatched ']' case", + err: `1 error occurred: test.rego:4: rego_parse_error: unexpected , token: expected \n or ; or } + y := ["a", "b", "c"], "a"] + ^`, + input: `package test +p { + x := 5 + y := ["a", "b", "c"], "a"] +}`, + }, + { + note: "Unmatched '(' case", + err: `1 error occurred: test.rego:5: rego_parse_error: unexpected } token: expected "," or ")" + } + ^`, + input: `package test +p { + x := 5 + y := contains("a", "b" +}`, + }, + { + note: "Unmatched '{' case", + + err: `1 error occurred: test.rego:5: rego_parse_error: unexpected eof token: expected \n or ; or } + } + ^`, + input: `package test +p { + x := 5 + y := {{"a", "b", "c"}, "a" +}`, + }, + { + note: "Unmatched '[' case", + err: `1 error occurred: test.rego:5: rego_parse_error: unexpected } token: expected "," or "]" + } + ^`, + input: `package test +p { + x := 5 + y := [["a", "b", "c"], "a" +}`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + _, err := ParseModule("test.rego", tc.input) + if err == nil { + t.Fatal("Expected error") + } + if tc.err != "" && tc.err != err.Error() { + t.Fatalf("Expected error string %q but got: %q", tc.err, err.Error()) + } + }) + } +} + func TestParseErrorDetails(t *testing.T) { tests := []struct { @@ -2799,6 +4037,7 @@ func TestAnnotations(t *testing.T) { expNumComments int expAnnotations []*Annotations expError string + expErrorRow int }{ { note: "Single valid annotation", @@ -3012,6 +4251,14 @@ public_servers[server] { }`, expNumComments: 5, expError: "rego_parse_error: yaml: unmarshal errors:\n line 3: cannot unmarshal !!str", + expErrorRow: 11, + }, + { + note: "Ill-structured (invalid) annotation with control character (vertical tab)", + module: "# METADATA\n" + + "# title: foo\vbar\n" + + "package opa.examples\n", + expError: "rego_parse_error: yaml: control characters are not allowed", }, { note: "Indentation error in yaml", @@ -3104,20 +4351,42 @@ public_servers_1[server] { expNumComments: 9, expAnnotations: []*Annotations{ { - Schemas: []*SchemaAnnotation{ - {Path: dataServers, Schema: schemaServers}, - }, + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + }, + Scope: annotationScopeRule, + node: MustParseRule(`public_servers[server] { server = servers[i] }`), + }, + { + Schemas: []*SchemaAnnotation{ + + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + node: MustParseRule(`public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`), + }, + }, + }, + { + note: "multiple metadata blocks on a single rule", + module: `package test + +# METADATA +# title: My rule + +# METADATA +# title: My rule 2 +p { input = "str" }`, + expNumComments: 4, + expAnnotations: []*Annotations{ + { Scope: annotationScopeRule, - node: MustParseRule(`public_servers[server] { server = servers[i] }`), + Title: "My rule", }, { - Schemas: []*SchemaAnnotation{ - - {Path: dataNetworks, 
Schema: schemaNetworks}, - {Path: dataPorts, Schema: schemaPorts}, - }, Scope: annotationScopeRule, - node: MustParseRule(`public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`), + Title: "My rule 2", }, }, }, @@ -3144,6 +4413,33 @@ p := 7`, {Scope: annotationScopeRule}, }, }, + { + note: "annotation on package", + module: `# METADATA +# title: My package +package test + +p { input = "str" }`, + expNumComments: 2, + expAnnotations: []*Annotations{ + { + Scope: annotationScopePackage, + Title: "My package", + }, + }, + }, + { + note: "annotation on import", + module: `package test + +# METADATA +# title: My import +import input.foo + +p { input = "str" }`, + expNumComments: 2, + expError: "1 error occurred: test.rego:3: rego_parse_error: invalid annotation scope 'import'", + }, { note: "Default rule scope", module: ` @@ -3221,6 +4517,87 @@ p { input = "str" }`, }, }, }, + { + note: "Rich meta", + module: `package test + +# METADATA +# title: My rule +# description: | +# My rule has a +# multiline description. +# organizations: +# - Acme Corp. +# - Soylent Corp. +# - Tyrell Corp. +# related_resources: +# - https://example.com +# - +# ref: http://john:123@do.re/mi?foo=bar#baz +# description: foo bar +# authors: +# - John Doe +# - name: Jane Doe +# email: jane@example.com +# custom: +# list: +# - a +# - b +# map: +# a: 1 +# b: 2.2 +# c: +# "3": d +# "4": e +# number: 42 +# string: foo bar baz +# flag: +p { input = "str" }`, + expNumComments: 31, + expAnnotations: []*Annotations{ + { + Scope: annotationScopeRule, + Title: "My rule", + Description: "My rule has a\nmultiline description.\n", + Organizations: []string{"Acme Corp.", "Soylent Corp.", "Tyrell Corp."}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://example.com"), + }, + { + Ref: mustParseURL("http://john:123@do.re/mi?foo=bar#baz"), + Description: "foo bar", + }, + }, + Authors: []*AuthorAnnotation{ + { + Name: "John Doe", + Email: "john@example.com", + }, + { + Name: "Jane Doe", + Email: "jane@example.com", + }, + }, + Custom: map[string]interface{}{ + "list": []interface{}{ + "a", "b", + }, + "map": map[string]interface{}{ + "a": 1, + "b": 2.2, + "c": map[string]interface{}{ + "3": "d", + "4": "e", + }, + }, + "number": 42, + "string": "foo bar baz", + "flag": nil, + }, + }, + }, + }, } for _, tc := range tests { @@ -3232,6 +4609,15 @@ p { input = "str" }`, if tc.expError == "" || !strings.Contains(err.Error(), tc.expError) { t.Fatalf("Unexpected parse error when getting annotations: %v", err) } + if tc.expErrorRow != 0 { + if errs, ok := err.(Errors); !ok { + t.Fatalf("expected ast.Errors, got %v", err) + } else if len(errs) != 1 { + t.Fatalf("expected exactly one ast.Error, got %v: %v", len(errs), errs) + } else if loc := errs[0].Location; tc.expErrorRow != loc.Row { + t.Fatalf("expected error location row %v, got %v", tc.expErrorRow, loc.Row) + } + } return } else if tc.expError != "" { t.Fatalf("Expected err: %v but no error from parse module", tc.expError) @@ -3248,6 +4634,398 @@ p { input = "str" }`, } } +func TestAnnotationsAugmentedError(t *testing.T) { + tests := []struct { + note string + module string + expAnnotations []*Annotations + expErrorHint string + expErrorRow int + }{ + { + note: "no whitespace after key/value separator", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# description:p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 4, symbol(s) ['p'] immediately following a key/value separator ':' is not a 
legal yaml space character", + expErrorRow: 4, + }, + { + note: "non-breaking whitespace (\\u00A0) after key/value separator", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# description:\u00A0p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 4, symbol(s) ['\\u00a0'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 4, + }, + { + note: "non-breaking whitespace (\\u00A0) after key/value separator (different line)", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# title: P\n" + + "# description:\u00A0p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 5, symbol(s) ['\\u00a0'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 5, + }, + { + note: "non-breaking whitespace (\\u00A0) after key/value separator (different line)", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# title:\n" + + "# P\n" + + "# description:\u00A0p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 6, symbol(s) ['\\u00a0'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 6, + }, + { + note: "non-breaking whitespace (\\u00A0) after key/value separator (different line)", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# title:\u00A0P\n" + + "# description: p is true\n" + + "# scope: rule\n" + + "p := true\n", + expErrorHint: "Hint: on line 4, symbol(s) ['\\u00a0'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 5, + }, + { + note: "thin whitespace (\\u2009) after key/value separator", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# description:\u2009p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 4, symbol(s) ['\\u2009'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 4, + }, + { + note: "ideographic whitespace (\\u3000) after key/value separator", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# description:\u3000p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 4, symbol(s) ['\\u3000'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 4, + }, + { + note: "several offending runes after key/value separator on single line", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# descr:iption:\u3000p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 4, symbol(s) ['i' '\\u3000'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 4, + }, + { + note: "several offending runes after key/value separator on single line", + module: "package opa.examples\n" + + "\n" + + "# METADATA\n" + + "# title:\u3000p\n" + + "# scope: rule\n" + + "# description:\u2009p is true\n" + + "p := true\n", + expErrorHint: "Hint: on line 4, symbol(s) ['\\u3000'] immediately following a key/value separator ':' is not a legal yaml space character\n" + + " Hint: on line 6, symbol(s) ['\\u2009'] immediately following a key/value separator ':' is not a legal yaml space character", + expErrorRow: 5, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + _, err := ParseModuleWithOpts("test.rego", tc.module, ParserOptions{ + ProcessAnnotation: true, + }) + + if err == nil { + t.Fatalf("Expected err with hint: %v but no error from parse module", tc.expErrorHint) + } + + if 
!strings.Contains(err.Error(), tc.expErrorHint) { + t.Fatalf("Unexpected parse error when getting annotations: %v", err) + } + + if errs, ok := err.(Errors); !ok { + t.Fatalf("expected ast.Errors, got %v", err) + } else if len(errs) != 1 { + t.Fatalf("expected exactly one ast.Error, got %v: %v", len(errs), errs) + } else if loc := errs[0].Location; tc.expErrorRow != loc.Row { + t.Fatalf("expected error location row %v, got %v", tc.expErrorRow, loc.Row) + } + }) + } +} + +func TestAuthorAnnotation(t *testing.T) { + tests := []struct { + note string + raw interface{} + expected interface{} + }{ + { + note: "no name", + raw: "", + expected: fmt.Errorf("author is an empty string"), + }, + { + note: "only whitespaces", + raw: " \t", + expected: fmt.Errorf("author is an empty string"), + }, + { + note: "one name only", + raw: "John", + expected: AuthorAnnotation{Name: "John"}, + }, + { + note: "multiple names", + raw: "John Jr.\tDoe", + expected: AuthorAnnotation{Name: "John Jr. Doe"}, + }, + { + note: "email only", + raw: "", + expected: AuthorAnnotation{Email: "john@example.com"}, + }, + { + note: "name and email", + raw: "John Doe ", + expected: AuthorAnnotation{Name: "John Doe", Email: "john@example.com"}, + }, + { + note: "empty email", + raw: "John Doe <>", + expected: AuthorAnnotation{Name: "John Doe"}, + }, + { + note: "name with reserved characters", + raw: "John Doe < >", + expected: AuthorAnnotation{Name: "John Doe < >"}, + }, + { + note: "name with reserved characters (email with space)", + raw: "", + expected: AuthorAnnotation{Name: ""}, + }, + { + note: "map with name", + raw: map[string]interface{}{ + "name": "John Doe", + }, + expected: AuthorAnnotation{Name: "John Doe"}, + }, + { + note: "map with email", + raw: map[string]interface{}{ + "email": "john@example.com", + }, + expected: AuthorAnnotation{Email: "john@example.com"}, + }, + { + note: "map with name and email", + raw: map[string]interface{}{ + "name": "John Doe", + "email": "john@example.com", + }, + expected: AuthorAnnotation{Name: "John Doe", Email: "john@example.com"}, + }, + { + note: "map with extra entry", + raw: map[string]interface{}{ + "name": "John Doe", + "email": "john@example.com", + "foo": "bar", + }, + expected: AuthorAnnotation{Name: "John Doe", Email: "john@example.com"}, + }, + { + note: "empty map", + raw: map[string]interface{}{}, + expected: fmt.Errorf("'name' and/or 'email' values required in object"), + }, + { + note: "map with empty name", + raw: map[string]interface{}{ + "name": "", + }, + expected: fmt.Errorf("'name' and/or 'email' values required in object"), + }, + { + note: "map with email and empty name", + raw: map[string]interface{}{ + "name": "", + "email": "john@example.com", + }, + expected: AuthorAnnotation{Email: "john@example.com"}, + }, + { + note: "map with empty email", + raw: map[string]interface{}{ + "email": "", + }, + expected: fmt.Errorf("'name' and/or 'email' values required in object"), + }, + { + note: "map with name and empty email", + raw: map[string]interface{}{ + "name": "John Doe", + "email": "", + }, + expected: AuthorAnnotation{Name: "John Doe"}, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + parsed, err := parseAuthor(tc.raw) + + switch expected := tc.expected.(type) { + case AuthorAnnotation: + if err != nil { + t.Fatal(err) + } + + if parsed.Compare(&expected) != 0 { + t.Fatalf("expected %v but got %v", tc.expected, parsed) + } + case error: + if err == nil { + t.Fatalf("expected '%v' error but got %v", tc.expected, parsed) + } + 
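				// Note: the comparison below demands the exact error wording,
				// not just a substring, so any change to parseAuthor's
				// messages will surface here.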
+ if strings.Compare(expected.Error(), err.Error()) != 0 { + t.Fatalf("expected %v but got %v", tc.expected, err) + } + default: + t.Fatalf("Unexpected result type: %T", expected) + } + }) + } +} + +func TestRelatedResourceAnnotation(t *testing.T) { + tests := []struct { + note string + raw interface{} + expected interface{} + }{ + { + note: "empty ref URL", + raw: "", + expected: fmt.Errorf("ref URL may not be empty string"), + }, + { + note: "only whitespaces in ref URL", + raw: " \t", + expected: fmt.Errorf("parse \" \\t\": net/url: invalid control character in URL"), + }, + { + note: "invalid ref URL", + raw: "https://foo:bar", + expected: fmt.Errorf("parse \"https://foo:bar\": invalid port \":bar\" after host"), + }, + { + note: "ref URL as string", + raw: "https://example.com/foo?bar#baz", + expected: RelatedResourceAnnotation{Ref: mustParseURL("https://example.com/foo?bar#baz")}, + }, + { + note: "map with only ref", + raw: map[string]interface{}{ + "ref": "https://example.com/foo?bar#baz", + }, + expected: RelatedResourceAnnotation{Ref: mustParseURL("https://example.com/foo?bar#baz")}, + }, + { + note: "map with only description", + raw: map[string]interface{}{ + "description": "foo bar", + }, + expected: fmt.Errorf("'ref' value required in object"), + }, + { + note: "map with ref and description", + raw: map[string]interface{}{ + "ref": "https://example.com/foo?bar#baz", + "description": "foo bar", + }, + expected: RelatedResourceAnnotation{ + Ref: mustParseURL("https://example.com/foo?bar#baz"), + Description: "foo bar", + }, + }, + { + note: "map with ref and description", + raw: map[string]interface{}{ + "ref": "https://example.com/foo?bar#baz", + "description": "foo bar", + "foo": "bar", + }, + expected: RelatedResourceAnnotation{ + Ref: mustParseURL("https://example.com/foo?bar#baz"), + Description: "foo bar", + }, + }, + { + note: "empty map", + raw: map[string]interface{}{}, + expected: fmt.Errorf("'ref' value required in object"), + }, + { + note: "map with empty ref", + raw: map[string]interface{}{ + "ref": "", + }, + expected: fmt.Errorf("'ref' value required in object"), + }, + { + note: "map with only whitespace in ref", + raw: map[string]interface{}{ + "ref": " \t", + }, + expected: fmt.Errorf("'ref' value required in object"), + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + parsed, err := parseRelatedResource(tc.raw) + + switch expected := tc.expected.(type) { + case RelatedResourceAnnotation: + if err != nil { + t.Fatal(err) + } + + if parsed.Compare(&expected) != 0 { + t.Fatalf("expected %v but got %v", tc.expected, parsed) + } + case error: + if err == nil { + t.Fatalf("expected '%v' error but got %v", tc.expected, parsed) + } + + if strings.Compare(expected.Error(), err.Error()) != 0 { + t.Fatalf("expected %v but got %v", tc.expected, err) + } + default: + t.Fatalf("Unexpected result type: %T", expected) + } + }) + } +} + func assertLocationText(t *testing.T, expected string, actual *Location) { t.Helper() if actual == nil || actual.Text == nil { @@ -3266,21 +5044,28 @@ func assertParseError(t *testing.T, msg string, input string) { }) } -func assertParseErrorContains(t *testing.T, msg string, input string, expected string) { +func assertParseErrorContains(t *testing.T, msg string, input string, expected string, opts ...ParserOptions) { t.Helper() assertParseErrorFunc(t, msg, input, func(result string) { t.Helper() if !strings.Contains(result, expected) { t.Errorf("Error on test \"%s\": expected parse error to contain:\n\n%v\n\nbut 
got:\n\n%v", msg, expected, result) } - }) + }, opts...) } -func assertParseErrorFunc(t *testing.T, msg string, input string, f func(string)) { +func assertParseErrorFunc(t *testing.T, msg string, input string, f func(string), opts ...ParserOptions) { t.Helper() - p, err := ParseStatement(input) + opt := ParserOptions{} + if len(opts) == 1 { + opt = opts[0] + } + stmts, _, err := ParseStatementsWithOpts("", input, opt) + if err == nil && len(stmts) != 1 { + err = fmt.Errorf("expected exactly one statement") + } if err == nil { - t.Errorf("Error on test \"%s\": expected parse error but parsed successfully:\n\n%v\n\n(parsed)", msg, p) + t.Errorf("Error on test \"%s\": expected parse error on %s: expected no statements, got %d: %v", msg, input, len(stmts), stmts) return } result := err.Error() @@ -3290,7 +5075,7 @@ func assertParseErrorFunc(t *testing.T, msg string, input string, f func(string) f(result) } -func assertParseImport(t *testing.T, msg string, input string, correct *Import) { +func assertParseImport(t *testing.T, msg string, input string, correct *Import, opts ...ParserOptions) { t.Helper() assertParseOne(t, msg, input, func(parsed interface{}) { t.Helper() @@ -3298,12 +5083,15 @@ func assertParseImport(t *testing.T, msg string, input string, correct *Import) if !imp.Equal(correct) { t.Errorf("Error on test \"%s\": imports not equal: %v (parsed), %v (correct)", msg, imp, correct) } - }) + }, opts...) } -func assertParseModule(t *testing.T, msg string, input string, correct *Module) { - - m, err := ParseModule("", input) +func assertParseModule(t *testing.T, msg string, input string, correct *Module, opts ...ParserOptions) { + opt := ParserOptions{} + if len(opts) == 1 { + opt = opts[0] + } + m, err := ParseModuleWithOpts("", input, opt) if err != nil { t.Errorf("Error on test \"%s\": parse error on %s: %s", msg, input, err) return @@ -3315,6 +5103,53 @@ func assertParseModule(t *testing.T, msg string, input string, correct *Module) } +func assertParseModuleJSONOptions(t *testing.T, msg string, input string, opts ...ParserOptions) { + opt := ParserOptions{} + if len(opts) == 1 { + opt = opts[0] + } + m, err := ParseModuleWithOpts("", input, opt) + if err != nil { + t.Errorf("Error on test \"%s\": parse error on %s: %s", msg, input, err) + return + } + + if len(m.Rules) != 1 { + t.Fatalf("Error on test \"%s\": expected 1 rule but got %d", msg, len(m.Rules)) + } + + rule := m.Rules[0] + if rule.Head.jsonOptions != *opt.JSONOptions { + t.Fatalf("Error on test \"%s\": expected rule Head JSONOptions\n%v\n, got\n%v", msg, *opt.JSONOptions, rule.Head.jsonOptions) + } + if rule.Body[0].jsonOptions != *opt.JSONOptions { + t.Fatalf("Error on test \"%s\": expected rule Body JSONOptions\n%v\n, got\n%v", msg, *opt.JSONOptions, rule.Body[0].jsonOptions) + } + switch terms := rule.Body[0].Terms.(type) { + case []*Term: + for _, term := range terms { + if term.jsonOptions != *opt.JSONOptions { + t.Fatalf("Error on test \"%s\": expected body Term JSONOptions\n%v\n, got\n%v", msg, *opt.JSONOptions, term.jsonOptions) + } + } + case *SomeDecl: + if terms.jsonOptions != *opt.JSONOptions { + t.Fatalf("Error on test \"%s\": expected body Term JSONOptions\n%v\n, got\n%v", msg, *opt.JSONOptions, terms.jsonOptions) + } + case *Every: + if terms.jsonOptions != *opt.JSONOptions { + t.Fatalf("Error on test \"%s\": expected body Term JSONOptions\n%v\n, got\n%v", msg, *opt.JSONOptions, terms.jsonOptions) + } + case *Term: + if terms.jsonOptions != *opt.JSONOptions { + t.Fatalf("Error on test \"%s\": expected 
body Term JSONOptions\n%v\n, got\n%v", msg, *opt.JSONOptions, terms.jsonOptions) + } + } + if rule.jsonOptions != *opt.JSONOptions { + t.Fatalf("Error on test \"%s\": expected rule JSONOptions\n%v\n, got\n%v", msg, *opt.JSONOptions, rule.jsonOptions) + } +} + func assertParseModuleError(t *testing.T, msg, input string) { m, err := ParseModule("", input) if err == nil { @@ -3338,14 +5173,14 @@ func assertParseOne(t *testing.T, msg string, input string, correct func(interfa opt = opts[0] } stmts, _, err := ParseStatementsWithOpts("", input, opt) - if len(stmts) != 1 { - t.Errorf("Error on test \"%s\": parse error on %s: expected exactly one statement, got %d", msg, input, len(stmts)) - return - } if err != nil { t.Errorf("Error on test \"%s\": parse error on %s: %s", msg, input, err) return } + if len(stmts) != 1 { + t.Errorf("Error on test \"%s\": parse error on %s: expected exactly one statement, got %d: %v", msg, input, len(stmts), stmts) + return + } correct(stmts[0]) } @@ -3393,13 +5228,23 @@ func assertParseOneTermNegated(t *testing.T, msg string, input string, correct * assertParseOneExprNegated(t, msg, input, &Expr{Terms: correct}) } -func assertParseRule(t *testing.T, msg string, input string, correct *Rule) { +func assertParseRule(t *testing.T, msg string, input string, correct *Rule, opts ...ParserOptions) { t.Helper() assertParseOne(t, msg, input, func(parsed interface{}) { t.Helper() rule := parsed.(*Rule) + if rule.Head.Name != correct.Head.Name { + t.Errorf("Error on test \"%s\": rule heads not equal: name = %v (parsed), name = %v (correct)", msg, rule.Head.Name, correct.Head.Name) + } + if !rule.Head.Ref().Equal(correct.Head.Ref()) { + t.Errorf("Error on test \"%s\": rule heads not equal: ref = %v (parsed), ref = %v (correct)", msg, rule.Head.Ref(), correct.Head.Ref()) + } + if !rule.Head.Equal(correct.Head) { + t.Errorf("Error on test \"%s\": rule heads not equal: %v (parsed), %v (correct)", msg, rule.Head, correct.Head) + } if !rule.Equal(correct) { t.Errorf("Error on test \"%s\": rules not equal: %v (parsed), %v (correct)", msg, rule, correct) } - }) + }, + opts...) } diff --git a/ast/policy.go b/ast/policy.go index 26b6d78e0d..caf756f6aa 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -144,68 +144,66 @@ type ( Annotations []*Annotations `json:"annotations,omitempty"` Rules []*Rule `json:"rules,omitempty"` Comments []*Comment `json:"comments,omitempty"` + stmts []Statement } // Comment contains the raw text from the comment in the definition. Comment struct { + // TODO: these fields have inconsistent JSON keys with other structs in this package. Text []byte Location *Location - } - - // Annotations represents metadata attached to other AST nodes such as rules. - Annotations struct { - Location *Location `json:"-"` - Scope string `json:"scope"` - Schemas []*SchemaAnnotation `json:"schemas,omitempty"` - node Node - } - // SchemaAnnotation contains a schema declaration for the document identified by the path. - SchemaAnnotation struct { - Path Ref `json:"path"` - Schema Ref `json:"schema,omitempty"` - Definition *interface{} `json:"definition,omitempty"` + jsonOptions JSONOptions } // Package represents the namespace of the documents produced // by rules inside the module. Package struct { - Location *Location `json:"-"` Path Ref `json:"path"` + Location *Location `json:"location,omitempty"` + + jsonOptions JSONOptions } // Import represents a dependency on a document outside of the policy // namespace. Imports are optional. 
Import struct { - Location *Location `json:"-"` Path *Term `json:"path"` Alias Var `json:"alias,omitempty"` + Location *Location `json:"location,omitempty"` + + jsonOptions JSONOptions } // Rule represents a rule as defined in the language. Rules define the // content of documents that represent policy decisions. Rule struct { - Location *Location `json:"-"` Default bool `json:"default,omitempty"` Head *Head `json:"head"` Body Body `json:"body"` Else *Rule `json:"else,omitempty"` + Location *Location `json:"location,omitempty"` // Module is a pointer to the module containing this rule. If the rule // was NOT created while parsing/constructing a module, this should be // left unset. The pointer is not included in any standard operations // on the rule (e.g., printing, comparison, visiting, etc.) Module *Module `json:"-"` + + jsonOptions JSONOptions } // Head represents the head of a rule. Head struct { - Location *Location `json:"-"` - Name Var `json:"name"` - Args Args `json:"args,omitempty"` - Key *Term `json:"key,omitempty"` - Value *Term `json:"value,omitempty"` - Assign bool `json:"assign,omitempty"` + Name Var `json:"name,omitempty"` + Reference Ref `json:"ref,omitempty"` + Args Args `json:"args,omitempty"` + Key *Term `json:"key,omitempty"` + Value *Term `json:"value,omitempty"` + Assign bool `json:"assign,omitempty"` + Location *Location `json:"location,omitempty"` + + jsonOptions JSONOptions } // Args represents zero or more arguments to a rule. @@ -219,141 +217,41 @@ type ( Expr struct { With []*With `json:"with,omitempty"` Terms interface{} `json:"terms"` - Location *Location `json:"-"` Index int `json:"index"` Generated bool `json:"generated,omitempty"` Negated bool `json:"negated,omitempty"` + Location *Location `json:"location,omitempty"` + + jsonOptions JSONOptions } // SomeDecl represents a variable declaration statement. The symbols are variables. SomeDecl struct { - Location *Location `json:"-"` Symbols []*Term `json:"symbols"` - } - - // With represents a modifier on an expression. - With struct { - Location *Location `json:"-"` - Target *Term `json:"target"` - Value *Term `json:"value"` - } -) - -func (s *Annotations) String() string { - bs, _ := json.Marshal(s) - return string(bs) -} - -// Loc returns the location of this annotation. -func (s *Annotations) Loc() *Location { - return s.Location -} - -// SetLoc updates the location of this annotation. -func (s *Annotations) SetLoc(l *Location) { - s.Location = l -} - -// Compare returns an integer indicating if s is less than, equal to, or greater -// than other. -func (s *Annotations) Compare(other *Annotations) int { - - if cmp := scopeCompare(s.Scope, other.Scope); cmp != 0 { - return cmp - } - - max := len(s.Schemas) - if len(other.Schemas) < max { - max = len(other.Schemas) - } - - for i := 0; i < max; i++ { - if cmp := s.Schemas[i].Compare(other.Schemas[i]); cmp != 0 { - return cmp - } - } - - if len(s.Schemas) > len(other.Schemas) { - return 1 - } else if len(s.Schemas) < len(other.Schemas) { - return -1 - } + Location *Location `json:"location,omitempty"` - return 0 -} - -// Copy returns a deep copy of s. -func (s *Annotations) Copy(node Node) *Annotations { - cpy := *s - cpy.Schemas = make([]*SchemaAnnotation, len(s.Schemas)) - for i := range cpy.Schemas { - cpy.Schemas[i] = s.Schemas[i].Copy() + jsonOptions JSONOptions } - cpy.node = node - return &cpy -} -// Copy returns a deep copy of s. 
-func (s *SchemaAnnotation) Copy() *SchemaAnnotation { - cpy := *s - return &cpy -} - -// Compare returns an integer indicating if s is less than, equal to, or greater -// than other. -func (s *SchemaAnnotation) Compare(other *SchemaAnnotation) int { - - if cmp := s.Path.Compare(other.Path); cmp != 0 { - return cmp - } - - if cmp := s.Schema.Compare(other.Schema); cmp != 0 { - return cmp - } - - if s.Definition != nil && other.Definition == nil { - return -1 - } else if s.Definition == nil && other.Definition != nil { - return 1 - } else if s.Definition != nil && other.Definition != nil { - return util.Compare(*s.Definition, *other.Definition) - } - - return 0 -} - -func (s *SchemaAnnotation) String() string { - bs, _ := json.Marshal(s) - return string(bs) -} - -func scopeCompare(s1, s2 string) int { - - o1 := scopeOrder(s1) - o2 := scopeOrder(s2) - - if o2 < o1 { - return 1 - } else if o2 > o1 { - return -1 - } + Every struct { + Key *Term `json:"key"` + Value *Term `json:"value"` + Domain *Term `json:"domain"` + Body Body `json:"body"` + Location *Location `json:"location,omitempty"` - if s1 < s2 { - return -1 - } else if s2 < s1 { - return 1 + jsonOptions JSONOptions } - return 0 -} + // With represents a modifier on an expression. + With struct { + Target *Term `json:"target"` + Value *Term `json:"value"` + Location *Location `json:"location,omitempty"` -func scopeOrder(s string) int { - switch s { - case annotationScopeRule: - return 1 + jsonOptions JSONOptions } - return 0 -} +) // Compare returns an integer indicating whether mod is less than, equal to, // or greater than other. @@ -383,32 +281,22 @@ func (mod *Module) Copy() *Module { cpy := *mod cpy.Rules = make([]*Rule, len(mod.Rules)) - var nodes map[Node]Node - - if len(mod.Annotations) > 0 { - nodes = make(map[Node]Node) - } + nodes := make(map[Node]Node, len(mod.Rules)+len(mod.Imports)+1 /* package */) for i := range mod.Rules { cpy.Rules[i] = mod.Rules[i].Copy() cpy.Rules[i].Module = &cpy - if nodes != nil { - nodes[mod.Rules[i]] = cpy.Rules[i] - } + nodes[mod.Rules[i]] = cpy.Rules[i] } cpy.Imports = make([]*Import, len(mod.Imports)) for i := range mod.Imports { cpy.Imports[i] = mod.Imports[i].Copy() - if nodes != nil { - nodes[mod.Imports[i]] = cpy.Imports[i] - } + nodes[mod.Imports[i]] = cpy.Imports[i] } cpy.Package = mod.Package.Copy() - if nodes != nil { - nodes[mod.Package] = cpy.Package - } + nodes[mod.Package] = cpy.Package cpy.Annotations = make([]*Annotations, len(mod.Annotations)) for i := range mod.Annotations { @@ -420,6 +308,11 @@ func (mod *Module) Copy() *Module { cpy.Comments[i] = mod.Comments[i].Copy() } + cpy.stmts = make([]Statement, len(mod.stmts)) + for i := range mod.stmts { + cpy.stmts[i] = nodes[mod.stmts[i]] + } + return &cpy } @@ -535,6 +428,12 @@ func (c *Comment) Equal(other *Comment) bool { return c.Location.Equal(other.Location) && bytes.Equal(c.Text, other.Text) } +func (c *Comment) setJSONOptions(opts JSONOptions) { + // Note: this is not used for location since Comments use default JSON marshaling + // behavior with struct field names in JSON. + c.jsonOptions = opts +} + // Compare returns an integer indicating whether pkg is less than, equal to, // or greater than other. 
 func (pkg *Package) Compare(other *Package) int {
@@ -579,6 +478,24 @@ func (pkg *Package) String() string {
 	return fmt.Sprintf("package %v", path)
 }
 
+func (pkg *Package) setJSONOptions(opts JSONOptions) {
+	pkg.jsonOptions = opts
+}
+
+func (pkg *Package) MarshalJSON() ([]byte, error) {
+	data := map[string]interface{}{
+		"path": pkg.Path,
+	}
+
+	if pkg.jsonOptions.MarshalOptions.IncludeLocation.Package {
+		if pkg.Location != nil {
+			data["location"] = pkg.Location
+		}
+	}
+
+	return json.Marshal(data)
+}
+
 // IsValidImportPath returns an error indicating if the import path is invalid.
 // If the import path is valid, err is nil.
 func IsValidImportPath(v Value) (err error) {
@@ -671,6 +588,28 @@ func (imp *Import) String() string {
 	return strings.Join(buf, " ")
 }
 
+func (imp *Import) setJSONOptions(opts JSONOptions) {
+	imp.jsonOptions = opts
+}
+
+func (imp *Import) MarshalJSON() ([]byte, error) {
+	data := map[string]interface{}{
+		"path": imp.Path,
+	}
+
+	if len(imp.Alias) != 0 {
+		data["alias"] = imp.Alias
+	}
+
+	if imp.jsonOptions.MarshalOptions.IncludeLocation.Import {
+		if imp.Location != nil {
+			data["location"] = imp.Location
+		}
+	}
+
+	return json.Marshal(data)
+}
+
 // Compare returns an integer indicating whether rule is less than, equal to,
 // or greater than other.
 func (rule *Rule) Compare(other *Rule) int {
@@ -725,11 +664,22 @@ func (rule *Rule) SetLoc(loc *Location) {
 
 // Path returns a ref referring to the document produced by this rule. If rule
 // is not contained in a module, this function panics.
+// Deprecated: Poor handling of ref rules. Use `(*Rule).Ref()` instead.
 func (rule *Rule) Path() Ref {
 	if rule.Module == nil {
 		panic("assertion failed")
 	}
-	return rule.Module.Package.Path.Append(StringTerm(string(rule.Head.Name)))
+	return rule.Module.Package.Path.Extend(rule.Head.Ref().GroundPrefix())
+}
+
+// Ref returns a ref referring to the document produced by this rule. If rule
+// is not contained in a module, this function panics. The returned ref may
+// contain variables in the last position.
+func (rule *Rule) Ref() Ref {
+	if rule.Module == nil {
+		panic("assertion failed")
+	}
+	return rule.Module.Package.Path.Extend(rule.Head.Ref())
 }
 
 func (rule *Rule) String() string {
@@ -749,6 +699,33 @@ func (rule *Rule) String() string {
 	return strings.Join(buf, " ")
 }
 
+func (rule *Rule) setJSONOptions(opts JSONOptions) {
+	rule.jsonOptions = opts
+}
+
+func (rule *Rule) MarshalJSON() ([]byte, error) {
+	data := map[string]interface{}{
+		"head": rule.Head,
+		"body": rule.Body,
+	}
+
+	if rule.Default {
+		data["default"] = true
+	}
+
+	if rule.Else != nil {
+		data["else"] = rule.Else
+	}
+
+	if rule.jsonOptions.MarshalOptions.IncludeLocation.Rule {
+		if rule.Location != nil {
+			data["location"] = rule.Location
+		}
+	}
+
+	return json.Marshal(data)
+}
+
 func (rule *Rule) elseString() string {
 	var buf []string
 
@@ -775,7 +752,8 @@ func (rule *Rule) elseString() string {
 // used for the key and the second will be used for the value.
 func NewHead(name Var, args ...*Term) *Head {
 	head := &Head{
-		Name: name,
+		Name:      name, // backcompat
+		Reference: []*Term{NewTerm(name)},
 	}
 	if len(args) == 0 {
 		return head
@@ -785,6 +763,23 @@ func NewHead(name Var, args ...*Term) *Head {
 		return head
 	}
 	head.Value = args[1]
+	if head.Key != nil && head.Value != nil {
+		head.Reference = head.Reference.Append(args[0])
+	}
 	return head
+}
+
+// RefHead returns a new Head object with the passed Ref. If args are provided,
+// the first will be used for the value.
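+//
+// A minimal usage sketch (illustrative only; MustParseRef and BooleanTerm are
+// existing helpers in this package):
+//
+//	head := RefHead(MustParseRef("p.q.r"), BooleanTerm(true))
+//	head.Ref()   // p.q.r
+//	head.Value   // true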
+func RefHead(ref Ref, args ...*Term) *Head {
+	head := &Head{}
+	head.SetRef(ref)
+	if len(ref) < 2 {
+		head.Name = ref[0].Value.(Var)
+	}
+	if len(args) >= 1 {
+		head.Value = args[0]
+	}
 	return head
 }
 
@@ -800,7 +795,7 @@ const (
 
 	// PartialObjectDoc represents an object document that is partially defined by the rule.
 	PartialObjectDoc
-)
+) // TODO(sr): Deprecate?
 
 // DocKind returns the type of document produced by this rule.
 func (head *Head) DocKind() DocKind {
@@ -813,6 +808,41 @@
 	return CompleteDoc
 }
 
+type RuleKind int
+
+const (
+	SingleValue = iota
+	MultiValue
+)
+
+// RuleKind returns the type of rule this is.
+func (head *Head) RuleKind() RuleKind {
+	// NOTE(sr): This is a bit verbose, since the key is irrelevant for single vs
+	// multi value, but as good a spot as any to assert the invariant.
+	switch {
+	case head.Value != nil:
+		return SingleValue
+	case head.Key != nil:
+		return MultiValue
+	default:
+		panic("unreachable")
+	}
+}
+
+// Ref returns the Ref of the rule. If it doesn't have one, it's filled in
+// via the Head's Name.
+func (head *Head) Ref() Ref {
+	if len(head.Reference) > 0 {
+		return head.Reference
+	}
+	return Ref{&Term{Value: head.Name}}
+}
+
+// SetRef can be used to set a rule head's Reference.
+func (head *Head) SetRef(r Ref) {
+	head.Reference = r
+}
+
 // Compare returns an integer indicating whether head is less than, equal to,
 // or greater than other.
 func (head *Head) Compare(other *Head) int {
@@ -832,6 +862,9 @@
 	if cmp := Compare(head.Args, other.Args); cmp != 0 {
 		return cmp
 	}
+	if cmp := Compare(head.Reference, other.Reference); cmp != 0 {
+		return cmp
+	}
 	if cmp := Compare(head.Name, other.Name); cmp != 0 {
 		return cmp
 	}
@@ -844,6 +877,7 @@
 // Copy returns a deep copy of head.
 func (head *Head) Copy() *Head {
 	cpy := *head
+	cpy.Reference = head.Reference.Copy()
 	cpy.Args = head.Args.Copy()
 	cpy.Key = head.Key.Copy()
 	cpy.Value = head.Value.Copy()
@@ -856,23 +890,57 @@ func (head *Head) Equal(other *Head) bool {
 }
 
 func (head *Head) String() string {
-	var buf []string
-	if len(head.Args) != 0 {
-		buf = append(buf, head.Name.String()+head.Args.String())
-	} else if head.Key != nil {
-		buf = append(buf, head.Name.String()+"["+head.Key.String()+"]")
-	} else {
-		buf = append(buf, head.Name.String())
+	buf := strings.Builder{}
+	buf.WriteString(head.Ref().String())
+
+	switch {
+	case len(head.Args) != 0:
+		buf.WriteString(head.Args.String())
+	case len(head.Reference) == 1 && head.Key != nil:
+		buf.WriteRune('[')
+		buf.WriteString(head.Key.String())
+		buf.WriteRune(']')
 	}
 	if head.Value != nil {
 		if head.Assign {
-			buf = append(buf, ":=")
+			buf.WriteString(" := ")
 		} else {
-			buf = append(buf, "=")
+			buf.WriteString(" = ")
 		}
-		buf = append(buf, head.Value.String())
+		buf.WriteString(head.Value.String())
+	} else if head.Name == "" && head.Key != nil {
+		buf.WriteString(" contains ")
+		buf.WriteString(head.Key.String())
 	}
-	return strings.Join(buf, " ")
+	return buf.String()
+}
+
+func (head *Head) setJSONOptions(opts JSONOptions) {
+	head.jsonOptions = opts
+}
+
+func (head *Head) MarshalJSON() ([]byte, error) {
+	var loc *Location
+	if head.jsonOptions.MarshalOptions.IncludeLocation.Head {
+		if head.Location != nil {
+			loc = head.Location
+		}
+	}
+
+	// NOTE(sr): we do this to override the rendering of `head.Reference`.
+ // It's still what'll be used via the default means of encoding/json + // for unmarshaling a json object into a Head struct! + // NOTE(charlieegan3): we also need to optionally include the location + type h Head + return json.Marshal(struct { + h + Ref Ref `json:"ref"` + Location *Location `json:"location,omitempty"` + }{ + h: h(*head), + Ref: head.Ref(), + Location: loc, + }) } // Vars returns a set of vars found in the head. @@ -888,6 +956,9 @@ func (head *Head) Vars() VarSet { if head.Value != nil { vis.Walk(head.Value) } + if len(head.Reference) > 0 { + vis.Walk(head.Reference[1:]) + } return vis.vars } @@ -914,7 +985,7 @@ func (a Args) Copy() Args { } func (a Args) String() string { - var buf []string + buf := make([]string, 0, len(a)) for _, t := range a { buf = append(buf, t.String()) } @@ -959,7 +1030,8 @@ func (body Body) MarshalJSON() ([]byte, error) { if len(body) == 0 { return []byte(`[]`), nil } - return json.Marshal([]*Expr(body)) + ret, err := json.Marshal([]*Expr(body)) + return ret, err } // Append adds the expr to the body and updates the expr's index accordingly. @@ -1058,7 +1130,7 @@ func (body Body) SetLoc(loc *Location) { } func (body Body) String() string { - var buf []string + buf := make([]string, 0, len(body)) for _, v := range body { buf = append(buf, v.String()) } @@ -1075,6 +1147,11 @@ func (body Body) Vars(params VarVisitorParams) VarSet { // NewExpr returns a new Expr object. func NewExpr(terms interface{}) *Expr { + switch terms.(type) { + case *SomeDecl, *Every, *Term, []*Term: // ok + default: + panic("unreachable") + } return &Expr{ Negated: false, Terms: terms, @@ -1153,6 +1230,10 @@ func (expr *Expr) Compare(other *Expr) int { if cmp := Compare(t, other.Terms.(*SomeDecl)); cmp != 0 { return cmp } + case *Every: + if cmp := Compare(t, other.Terms.(*Every)); cmp != 0 { + return cmp + } } return withSliceCompare(expr.With, other.With) @@ -1166,14 +1247,28 @@ func (expr *Expr) sortOrder() int { return 1 case []*Term: return 2 + case *Every: + return 3 } return -1 } +// CopyWithoutTerms returns a deep copy of expr without its Terms +func (expr *Expr) CopyWithoutTerms() *Expr { + cpy := *expr + + cpy.With = make([]*With, len(expr.With)) + for i := range expr.With { + cpy.With[i] = expr.With[i].Copy() + } + + return &cpy +} + // Copy returns a deep copy of expr. func (expr *Expr) Copy() *Expr { - cpy := *expr + cpy := expr.CopyWithoutTerms() switch ts := expr.Terms.(type) { case *SomeDecl: @@ -1186,14 +1281,11 @@ func (expr *Expr) Copy() *Expr { cpy.Terms = cpyTs case *Term: cpy.Terms = ts.Copy() + case *Every: + cpy.Terms = ts.Copy() } - cpy.With = make([]*With, len(expr.With)) - for i := range expr.With { - cpy.With[i] = expr.With[i].Copy() - } - - return &cpy + return cpy } // Hash returns the hash code of the Expr. @@ -1248,6 +1340,18 @@ func (expr *Expr) IsCall() bool { return ok } +// IsEvery returns true if this expression is an 'every' expression. +func (expr *Expr) IsEvery() bool { + _, ok := expr.Terms.(*Every) + return ok +} + +// IsSome returns true if this expression is a 'some' expression. +func (expr *Expr) IsSome() bool { + _, ok := expr.Terms.(*SomeDecl) + return ok +} + // Operator returns the name of the function or built-in this expression refers // to. If this expression is not a function call, returns nil. 
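+//
+// For example (an illustrative sketch using existing helpers in this package):
+//
+//	MustParseBody(`startswith(x, "a")`)[0].Operator() // startswith
+//	MustParseBody(`x > 7`)[0].Operator()              // gt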
func (expr *Expr) Operator() Ref { @@ -1333,7 +1437,7 @@ func (expr *Expr) SetLoc(loc *Location) { } func (expr *Expr) String() string { - var buf []string + buf := make([]string, 0, 2+len(expr.With)) if expr.Negated { buf = append(buf, "not") } @@ -1344,9 +1448,7 @@ func (expr *Expr) String() string { } else { buf = append(buf, Call(t).String()) } - case *Term: - buf = append(buf, t.String()) - case *SomeDecl: + case fmt.Stringer: buf = append(buf, t.String()) } @@ -1357,6 +1459,37 @@ func (expr *Expr) String() string { return strings.Join(buf, " ") } +func (expr *Expr) setJSONOptions(opts JSONOptions) { + expr.jsonOptions = opts +} + +func (expr *Expr) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "terms": expr.Terms, + "index": expr.Index, + } + + if len(expr.With) > 0 { + data["with"] = expr.With + } + + if expr.Generated { + data["generated"] = true + } + + if expr.Negated { + data["negated"] = true + } + + if expr.jsonOptions.MarshalOptions.IncludeLocation.Expr { + if expr.Location != nil { + data["location"] = expr.Location + } + } + + return json.Marshal(data) +} + // UnmarshalJSON parses the byte array and stores the result in expr. func (expr *Expr) UnmarshalJSON(bs []byte) error { v := map[string]interface{}{} @@ -1422,6 +1555,101 @@ func (d *SomeDecl) Hash() int { return termSliceHash(d.Symbols) } +func (d *SomeDecl) setJSONOptions(opts JSONOptions) { + d.jsonOptions = opts +} + +func (d *SomeDecl) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "symbols": d.Symbols, + } + + if d.jsonOptions.MarshalOptions.IncludeLocation.SomeDecl { + if d.Location != nil { + data["location"] = d.Location + } + } + + return json.Marshal(data) +} + +func (q *Every) String() string { + if q.Key != nil { + return fmt.Sprintf("every %s, %s in %s { %s }", + q.Key, + q.Value, + q.Domain, + q.Body) + } + return fmt.Sprintf("every %s in %s { %s }", + q.Value, + q.Domain, + q.Body) +} + +func (q *Every) Loc() *Location { + return q.Location +} + +func (q *Every) SetLoc(l *Location) { + q.Location = l +} + +// Copy returns a deep copy of d. +func (q *Every) Copy() *Every { + cpy := *q + cpy.Key = q.Key.Copy() + cpy.Value = q.Value.Copy() + cpy.Domain = q.Domain.Copy() + cpy.Body = q.Body.Copy() + return &cpy +} + +func (q *Every) Compare(other *Every) int { + for _, terms := range [][2]*Term{ + {q.Key, other.Key}, + {q.Value, other.Value}, + {q.Domain, other.Domain}, + } { + if d := Compare(terms[0], terms[1]); d != 0 { + return d + } + } + return q.Body.Compare(other.Body) +} + +// KeyValueVars returns the key and val arguments of an `every` +// expression, if they are non-nil and not wildcards. 
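+//
+// For example (illustrative):
+//
+//	every k, v in [1] { ... }  =>  {k, v}
+//	every v in [1] { ... }     =>  {v}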
+func (q *Every) KeyValueVars() VarSet { + vis := &VarVisitor{vars: VarSet{}} + if q.Key != nil { + vis.Walk(q.Key) + } + vis.Walk(q.Value) + return vis.vars +} + +func (q *Every) setJSONOptions(opts JSONOptions) { + q.jsonOptions = opts +} + +func (q *Every) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "key": q.Key, + "value": q.Value, + "domain": q.Domain, + "body": q.Body, + } + + if q.jsonOptions.MarshalOptions.IncludeLocation.Every { + if q.Location != nil { + data["location"] = q.Location + } + } + + return json.Marshal(data) +} + func (w *With) String() string { return "with " + w.Target.String() + " as " + w.Value.String() } @@ -1480,6 +1708,25 @@ func (w *With) SetLoc(loc *Location) { w.Location = loc } +func (w *With) setJSONOptions(opts JSONOptions) { + w.jsonOptions = opts +} + +func (w *With) MarshalJSON() ([]byte, error) { + data := map[string]interface{}{ + "target": w.Target, + "value": w.Value, + } + + if w.jsonOptions.MarshalOptions.IncludeLocation.With { + if w.Location != nil { + data["location"] = w.Location + } + } + + return json.Marshal(data) +} + // Copy returns a deep copy of the AST node x. If x is not an AST node, x is returned unmodified. func Copy(x interface{}) interface{} { switch x := x.(type) { @@ -1503,6 +1750,8 @@ func Copy(x interface{}) interface{} { return x.Copy() case *SomeDecl: return x.Copy() + case *Every: + return x.Copy() case *Term: return x.Copy() case *ArrayComprehension: diff --git a/ast/policy_test.go b/ast/policy_test.go index 5f01f9a3ad..118fc5b99f 100644 --- a/ast/policy_test.go +++ b/ast/policy_test.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/json" "fmt" + "net/url" "reflect" "testing" @@ -19,6 +20,7 @@ func TestModuleJSONRoundTrip(t *testing.T) { mod, err := ParseModuleWithOpts("test.rego", `package a.b.c +import future.keywords import data.x.y as z import data.u.i @@ -41,6 +43,7 @@ a = true { xs = {a: b | input.y[a] = "foo"; b = input.z["bar"]} } b = true { xs = {{"x": a[i].a} | a[i].n = "bob"; b[x]} } call_values { f(x) != g(x) } assigned := 1 +rule.having.ref.head[1] = x if x := 2 # METADATA # scope: rule @@ -376,17 +379,76 @@ func TestExprBadJSON(t *testing.T) { assert(js, exp) } +func TestExprEveryCopy(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true} + newEvery := func() *Expr { + return MustParseBodyWithOpts( + `every k, v in [1,2,3] { true }`, opts, + )[0] + } + e0 := newEvery() + e1 := e0.Copy() + e1.Terms.(*Every).Body = NewBody(NewExpr(BooleanTerm(false))) + if exp := newEvery(); exp.Compare(e0) != 0 { + t.Errorf("expected e0 unchanged (%v), found %v", exp, e0) + } +} + +func TestRuleHeadJSON(t *testing.T) { + // NOTE(sr): we may get to see Rule objects that aren't the result of parsing, but + // fed as-is into the compiler. We need to be able to make sense of their refs, too. 
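+	// Construct a Head with only a Name (no Reference), as a hand-built Rule might.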
+ head := Head{ + Name: Var("allow"), + } + + rule := Rule{ + Head: &head, + } + bs, err := json.Marshal(&rule) + if err != nil { + t.Fatal(err) + } + if exp, act := `{"body":[],"head":{"name":"allow","ref":[{"type":"var","value":"allow"}]}}`, string(bs); act != exp { + t.Errorf("expected %q, got %q", exp, act) + } + + var readRule Rule + if err := json.Unmarshal(bs, &readRule); err != nil { + t.Fatal(err) + } + if exp, act := 1, len(readRule.Head.Reference); act != exp { + t.Errorf("expected unmarshalled rule to have Reference, got %v", readRule.Head.Reference) + } + bs0, err := json.Marshal(&readRule) + if err != nil { + t.Fatal(err) + } + if exp, act := string(bs), string(bs0); exp != act { + t.Errorf("expected json repr to match %q, got %q", exp, act) + } + + var readAgainRule Rule + if err := json.Unmarshal(bs, &readAgainRule); err != nil { + t.Fatal(err) + } + if !readAgainRule.Equal(&readRule) { + t.Errorf("expected roundtripped rule reference to match %v, got %v", readRule.Head.Reference, readAgainRule.Head.Reference) + } +} + func TestRuleHeadEquals(t *testing.T) { assertHeadsEqual(t, &Head{}, &Head{}) - // Same name/key/value + // Same name/ref/key/value assertHeadsEqual(t, &Head{Name: Var("p")}, &Head{Name: Var("p")}) + assertHeadsEqual(t, &Head{Reference: Ref{VarTerm("p"), StringTerm("r")}}, &Head{Reference: Ref{VarTerm("p"), StringTerm("r")}}) // TODO: string for first section assertHeadsEqual(t, &Head{Key: VarTerm("x")}, &Head{Key: VarTerm("x")}) assertHeadsEqual(t, &Head{Value: VarTerm("x")}, &Head{Value: VarTerm("x")}) assertHeadsEqual(t, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}) - // Different name/key/value + // Different name/ref/key/value assertHeadsNotEqual(t, &Head{Name: Var("p")}, &Head{Name: Var("q")}) + assertHeadsNotEqual(t, &Head{Reference: Ref{VarTerm("p")}}, &Head{Reference: Ref{VarTerm("q")}}) // TODO: string for first section assertHeadsNotEqual(t, &Head{Key: VarTerm("x")}, &Head{Key: VarTerm("y")}) assertHeadsNotEqual(t, &Head{Value: VarTerm("x")}, &Head{Value: VarTerm("y")}) assertHeadsNotEqual(t, &Head{Args: []*Term{VarTerm("x"), VarTerm("z")}}, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}) @@ -422,48 +484,139 @@ func TestRuleBodyEquals(t *testing.T) { } func TestRuleString(t *testing.T) { - - rule1 := &Rule{ - Head: NewHead(Var("p"), nil, BooleanTerm(true)), - Body: NewBody( - Equality.Expr(StringTerm("foo"), StringTerm("bar")), - ), - } - - rule2 := &Rule{ - Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), - Body: NewBody( - Equality.Expr(StringTerm("foo"), VarTerm("x")), - &Expr{ - Negated: true, - Terms: RefTerm(VarTerm("a"), StringTerm("b"), VarTerm("x")), + trueBody := NewBody(NewExpr(BooleanTerm(true))) + + tests := []struct { + rule *Rule + exp string + }{ + { + rule: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody( + Equality.Expr(StringTerm("foo"), StringTerm("bar")), + ), + }, + exp: `p = true { "foo" = "bar" }`, + }, + { + rule: &Rule{ + Head: NewHead(Var("p"), VarTerm("x")), + Body: trueBody, + }, + exp: `p[x] { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p[x]"), BooleanTerm(true)), + Body: MustParseBody("x = 1"), + }, + exp: `p[x] = true { x = 1 }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r[x]"), BooleanTerm(true)), + Body: MustParseBody("x = 1"), + }, + exp: `p.q.r[x] = true { x = 1 }`, + }, + { + rule: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Key: VarTerm("1"), + }, + Body: MustParseBody("x = 
1"), + }, + exp: `p.q.r contains 1 { x = 1 }`, + }, + { + rule: &Rule{ + Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), + Body: NewBody( + Equality.Expr(StringTerm("foo"), VarTerm("x")), + &Expr{ + Negated: true, + Terms: RefTerm(VarTerm("a"), StringTerm("b"), VarTerm("x")), + }, + Equality.Expr(StringTerm("b"), VarTerm("y")), + ), + }, + exp: `p[x] = y { "foo" = x; not a.b[x]; "b" = y }`, + }, + { + rule: &Rule{ + Default: true, + Head: NewHead("p", nil, BooleanTerm(true)), + }, + exp: `default p = true`, + }, + { + rule: &Rule{ + Head: &Head{ + Name: Var("f"), + Args: Args{VarTerm("x"), VarTerm("y")}, + Value: VarTerm("z"), + }, + Body: NewBody(Plus.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z"))), + }, + exp: "f(x, y) = z { plus(x, y, z) }", + }, + { + rule: &Rule{ + Head: &Head{ + Name: Var("p"), + Value: BooleanTerm(true), + Assign: true, + }, + Body: NewBody( + Equality.Expr(StringTerm("foo"), StringTerm("bar")), + ), }, - Equality.Expr(StringTerm("b"), VarTerm("y")), - ), + exp: `p := true { "foo" = "bar" }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r")), + Body: trueBody, + }, + exp: `p.q.r { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r"), StringTerm("foo")), + Body: trueBody, + }, + exp: `p.q.r = "foo" { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r[x]"), StringTerm("foo")), + Body: MustParseBody(`x := 1`), + }, + exp: `p.q.r[x] = "foo" { assign(x, 1) }`, + }, } - rule3 := &Rule{ - Default: true, - Head: NewHead("p", nil, BooleanTerm(true)), + for _, tc := range tests { + t.Run(tc.exp, func(t *testing.T) { + assertRuleString(t, tc.rule, tc.exp) + }) } +} - rule4 := &Rule{ - Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x"), VarTerm("y")}, - Value: VarTerm("z"), - }, - Body: NewBody(Plus.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z"))), +func TestRulePath(t *testing.T) { + ruleWithMod := func(r string) Ref { + mod := MustParseModule("package pkg\n" + r) + return mod.Rules[0].Path() + } + if exp, act := MustParseRef("data.pkg.p.q.r"), ruleWithMod("p.q.r { true }"); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) } - rule5 := rule1.Copy() - rule5.Head.Assign = true - - assertRuleString(t, rule1, `p = true { "foo" = "bar" }`) - assertRuleString(t, rule2, `p[x] = y { "foo" = x; not a.b[x]; "b" = y }`) - assertRuleString(t, rule3, `default p = true`) - assertRuleString(t, rule4, "f(x, y) = z { plus(x, y, z) }") - assertRuleString(t, rule5, `p := true { "foo" = "bar" }`) + if exp, act := MustParseRef("data.pkg.p"), ruleWithMod("p { true }"); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) + } } func TestModuleString(t *testing.T) { @@ -562,24 +715,109 @@ func TestSomeDeclString(t *testing.T) { } } +func TestEveryString(t *testing.T) { + tests := []struct { + every Every + exp string + }{ + { + exp: `every x in ["foo", "bar"] { true; true }`, + every: Every{ + Value: VarTerm("x"), + Domain: ArrayTerm(StringTerm("foo"), StringTerm("bar")), + Body: []*Expr{ + { + Terms: BooleanTerm(true), + }, + { + Terms: BooleanTerm(true), + }, + }, + }, + }, + { + exp: `every k, v in ["foo", "bar"] { true; true }`, + every: Every{ + Key: VarTerm("k"), + Value: VarTerm("v"), + Domain: ArrayTerm(StringTerm("foo"), StringTerm("bar")), + Body: []*Expr{ + { + Terms: BooleanTerm(true), + }, + { + Terms: BooleanTerm(true), + }, + }, + }, + }, + } + for _, tc := range tests { + if act := tc.every.String(); act != tc.exp { + t.Errorf("expected %q, got %q", tc.exp, act) + } + } +} + func TestAnnotationsString(t 
*testing.T) { a := &Annotations{ - Scope: "foo", + Scope: "foo", + Title: "bar", + Description: "baz", + Authors: []*AuthorAnnotation{ + { + Name: "John Doe", + Email: "john@example.com", + }, + { + Name: "Jane Doe", + }, + }, + Organizations: []string{"mi", "fa"}, + RelatedResources: []*RelatedResourceAnnotation{ + { + Ref: mustParseURL("https://example.com"), + }, + { + Ref: mustParseURL("https://example.com/2"), + Description: "Some resource", + }, + }, Schemas: []*SchemaAnnotation{ { Path: MustParseRef("data.bar"), Schema: MustParseRef("schema.baz"), }, }, + Custom: map[string]interface{}{ + "list": []int{ + 1, 2, 3, + }, + "map": map[string]interface{}{ + "one": 1, + "two": map[int]interface{}{ + 3: "three", + }, + }, + "flag": true, + }, } // NOTE(tsandall): for now, annotations are represented as JSON objects // which are a subset of YAML. We could improve this in the future. - exp := `{"scope":"foo","schemas":[{"path":[{"type":"var","value":"data"},{"type":"string","value":"bar"}],"schema":[{"type":"var","value":"schema"},{"type":"string","value":"baz"}]}]}` + exp := `{"authors":[{"name":"John Doe","email":"john@example.com"},{"name":"Jane Doe"}],"custom":{"flag":true,"list":[1,2,3],"map":{"one":1,"two":{"3":"three"}}},"description":"baz","organizations":["mi","fa"],"related_resources":[{"ref":"https://example.com"},{"description":"Some resource","ref":"https://example.com/2"}],"schemas":[{"path":[{"type":"var","value":"data"},{"type":"string","value":"bar"}],"schema":[{"type":"var","value":"schema"},{"type":"string","value":"baz"}]}],"scope":"foo","title":"bar"}` - if exp != a.String() { - t.Fatalf("expected %q but got %q", exp, a.String()) + if got := a.String(); exp != got { + t.Fatalf("expected\n%s\nbut got\n%s", exp, got) + } +} + +func mustParseURL(str string) url.URL { + parsed, err := url.Parse(str) + if err != nil { + panic(err) } + return *parsed } func TestModuleStringAnnotations(t *testing.T) { @@ -683,7 +921,7 @@ func assertPackagesNotEqual(t *testing.T, a, b *Package) { func assertRulesEqual(t *testing.T, a, b *Rule) { t.Helper() if !a.Equal(b) { - t.Errorf("Rules are not equal (expected equal): a=%v b=%v", a, b) + t.Errorf("Rules are not equal (expected equal):\na=%v\nb=%v", a, b) } } diff --git a/ast/pretty_test.go b/ast/pretty_test.go index d59f7cb809..50fe97c142 100644 --- a/ast/pretty_test.go +++ b/ast/pretty_test.go @@ -41,8 +41,9 @@ func TestPretty(t *testing.T) { qux rule head - p - x + ref + p + x y body expr index=0 @@ -67,7 +68,8 @@ func TestPretty(t *testing.T) { true rule head - f + ref + f args x call diff --git a/ast/schema.go b/ast/schema.go index 76bd475677..8c96ac624e 100644 --- a/ast/schema.go +++ b/ast/schema.go @@ -54,7 +54,7 @@ func loadSchema(raw interface{}, allowNet []string) (types.Type, error) { return nil, err } - tpe, err := parseSchema(jsonSchema.RootSchema) + tpe, err := newSchemaParser().parseSchema(jsonSchema.RootSchema) if err != nil { return nil, fmt.Errorf("type checking: %w", err) } diff --git a/ast/schema_test.go b/ast/schema_test.go index c9671f1b24..2882829b2e 100644 --- a/ast/schema_test.go +++ b/ast/schema_test.go @@ -1,11 +1,15 @@ +// Copyright 2021 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. 
+ package ast import ( "errors" "fmt" - "io/ioutil" "net/http" "net/http/httptest" + "os" "strings" "testing" @@ -456,7 +460,7 @@ func TestParseSchemaWithSchemaBadSchema(t *testing.T) { if err != nil { t.Fatalf("Unable to compile schema: %v", err) } - newtype, err := parseSchema(jsonSchema) // Did not pass the subschema + newtype, err := newSchemaParser().parseSchema(jsonSchema) // Did not pass the subschema if err == nil { t.Fatalf("Expected parseSchema() = error, got nil") } @@ -562,7 +566,7 @@ func TestAnyOfSchema(t *testing.T) { func kubeSchemaServer(t *testing.T) *httptest.Server { t.Helper() - bs, err := ioutil.ReadFile("testdata/_definitions.json") + bs, err := os.ReadFile("testdata/_definitions.json") if err != nil { t.Fatal(err) } @@ -574,3 +578,1083 @@ func kubeSchemaServer(t *testing.T) *httptest.Server { })) return ts } + +func TestCompilerCheckTypesWithSchema(t *testing.T) { + c := NewCompiler() + var schema interface{} + err := util.Unmarshal([]byte(objectSchema), &schema) + if err != nil { + t.Fatal("Unexpected error:", err) + } + schemaSet := NewSchemaSet() + schemaSet.Put(SchemaRootRef, schema) + c.WithSchemas(schemaSet) + compileStages(c, c.checkTypes) + assertNotFailed(t, c) +} + +func TestCompilerCheckTypesWithRegexPatternInSchema(t *testing.T) { + c := NewCompiler() + var schema interface{} + // Negative lookahead is not supported in the Go regex dialect, but this is still a valid + // JSON schema. Since we don't rely on the "pattern" attribute for type checking, ensure + // that this still works (by being ignored) + err := util.Unmarshal([]byte(`{ + "properties": { + "name": { + "pattern": "^(?!testing:.*)[a-z]+$", + "type": "string" + } + } + }`), &schema) + if err != nil { + t.Fatal("Unexpected error:", err) + } + schemaSet := NewSchemaSet() + schemaSet.Put(SchemaRootRef, schema) + c.WithSchemas(schemaSet) + compileStages(c, c.checkTypes) + assertNotFailed(t, c) +} + +func TestCompilerCheckTypesWithAllOfSchema(t *testing.T) { + + tests := []struct { + note string + schema string + expectedError error + }{ + { + note: "allOf with mergeable Object types in schema", + schema: allOfObjectSchema, + expectedError: nil, + }, + { + note: "allOf with mergeable Array types in schema", + schema: allOfArraySchema, + expectedError: nil, + }, + { + note: "allOf without a parent schema", + schema: allOfSchemaParentVariation, + expectedError: nil, + }, + { + note: "allOf with empty schema", + schema: emptySchema, + expectedError: nil, + }, + { + note: "allOf with mergeable Array of Object types in schema", + schema: allOfArrayOfObjects, + expectedError: nil, + }, + { + note: "allOf with mergeable Object types in schema with type declaration missing", + schema: allOfObjectMissing, + expectedError: nil, + }, + { + note: "allOf with Array of mergeable different types in schema", + schema: allOfArrayDifTypes, + expectedError: nil, + }, + { + note: "allOf with mergeable Object containing Array types in schema", + schema: allOfArrayInsideObject, + expectedError: nil, + }, + { + note: "allOf with mergeable Array types in schema with type declaration missing", + schema: allOfArrayMissing, + expectedError: nil, + }, + { + note: "allOf with mergeable types inside of core schema", + schema: allOfInsideCoreSchema, + expectedError: nil, + }, + { + note: "allOf with mergeable String types in schema", + schema: allOfStringSchema, + expectedError: nil, + }, + { + note: "allOf with mergeable Integer types in schema", + schema: allOfIntegerSchema, + expectedError: nil, + }, + { + note: "allOf with 
mergeable Boolean types in schema",
+			schema:        allOfBooleanSchema,
+			expectedError: nil,
+		},
+		{
+			note:          "allOf with mergeable Array types with uneven numbers of items",
+			schema:        allOfSchemaWithUnevenArray,
+			expectedError: nil,
+		},
+		{
+			note:          "allOf schema with unmergeable Array of Arrays",
+			schema:        allOfArrayOfArrays,
+			expectedError: fmt.Errorf("unable to merge these schemas"),
+		},
+		{
+			note:          "allOf schema with Array and Object types as siblings",
+			schema:        allOfObjectAndArray,
+			expectedError: fmt.Errorf("unable to merge these schemas"),
+		},
+		{
+			note:          "allOf schema with Array type that contains different unmergeable types",
+			schema:        allOfArrayDifTypesWithError,
+			expectedError: fmt.Errorf("unable to merge these schemas"),
+		},
+		{
+			note:          "allOf schema with different unmergeable types",
+			schema:        allOfTypeErrorSchema,
+			expectedError: fmt.Errorf("unable to merge these schemas"),
+		},
+		{
+			note:          "allOf unmergeable schema with different parent and items types",
+			schema:        allOfSchemaWithParentError,
+			expectedError: fmt.Errorf("unable to merge these schemas"),
+		},
+		{
+			note:          "allOf schema of Array type with uneven numbers of items to merge",
+			schema:        allOfSchemaWithUnevenArray,
+			expectedError: nil,
+		},
+		{
+			note:          "allOf schema with unmergeable types String and Boolean",
+			schema:        allOfStringSchemaWithError,
+			expectedError: fmt.Errorf("unable to merge these schemas"),
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.note, func(t *testing.T) {
+			c := NewCompiler()
+			var schema interface{}
+			err := util.Unmarshal([]byte(tc.schema), &schema)
+			if err != nil {
+				t.Fatal("Unexpected error:", err)
+			}
+			schemaSet := NewSchemaSet()
+			schemaSet.Put(SchemaRootRef, schema)
+			c.WithSchemas(schemaSet)
+			compileStages(c, c.checkTypes)
+			if tc.expectedError != nil {
+				if len(c.Errors) == 0 || !strings.Contains(c.Errors.Error(), tc.expectedError.Error()) {
+					t.Fatalf("Expected error %v but got: %v", tc.expectedError, c.Errors)
+				}
+			} else {
+				assertNotFailed(t, c)
+			}
+		})
+	}
+}
+
+func TestWithSchema(t *testing.T) {
+	c := NewCompiler()
+	schemaSet := NewSchemaSet()
+	schemaSet.Put(SchemaRootRef, objectSchema)
+	c.WithSchemas(schemaSet)
+	if c.schemaSet == nil {
+		t.Fatalf("WithSchema did not set the schema correctly in the compiler")
+	}
+}
+
+func TestAnyOfObjectSchema1(t *testing.T) {
+	c := NewCompiler()
+	schemaSet := NewSchemaSet()
+	schemaSet.Put(SchemaRootRef, anyOfExtendCoreSchema)
+	c.WithSchemas(schemaSet)
+	if c.schemaSet == nil {
+		t.Fatalf("Did not correctly compile an object type schema with anyOf outside core schema")
+	}
+}
+
+func TestAnyOfObjectSchema2(t *testing.T) {
+	c := NewCompiler()
+	schemaSet := NewSchemaSet()
+	schemaSet.Put(SchemaRootRef, anyOfInsideCoreSchema)
+	c.WithSchemas(schemaSet)
+	if c.schemaSet == nil {
+		t.Fatalf("Did not correctly compile an object type schema with anyOf inside core schema")
+	}
+}
+
+func TestAnyOfArraySchema(t *testing.T) {
+	c := NewCompiler()
+	schemaSet := NewSchemaSet()
+	schemaSet.Put(SchemaRootRef, anyOfArraySchema)
+	c.WithSchemas(schemaSet)
+	if c.schemaSet == nil {
+		t.Fatalf("Did not correctly compile an array type schema with anyOf")
+	}
+}
+
+func TestAnyOfObjectMissing(t *testing.T) {
+	c := NewCompiler()
+	schemaSet := NewSchemaSet()
+	schemaSet.Put(SchemaRootRef, anyOfObjectMissing)
+	c.WithSchemas(schemaSet)
+	if c.schemaSet == nil {
+		t.Fatalf("Did not correctly compile an object type schema with anyOf where one of the props did not explicitly claim type")
+	}
+}
+
+func TestAnyOfArrayMissing(t *testing.T) {
+	c := NewCompiler()
+	schemaSet := NewSchemaSet()
+
schemaSet.Put(SchemaRootRef, anyOfArrayMissing) + c.WithSchemas(schemaSet) + if c.schemaSet == nil { + t.Fatalf("Did not correctly compile an array type schema with anyOf where items are inside anyOf") + } +} + +func TestRecursiveSchema(t *testing.T) { + c := NewCompiler() + schemaSet := NewSchemaSet() + schemaSet.Put(SchemaRootRef, recursiveElements) + c.WithSchemas(schemaSet) + if c.schemaSet == nil { + t.Fatalf("Did not correctly compile an object schema with recursive elements") + } +} + +const objectSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "http://example.com/example.json", + "type": "object", + "title": "The root schema", + "description": "The root schema comprises the entire JSON document.", + "required": [ + "foo", + "b" + ], + "properties": { + "foo": { + "$id": "#/properties/foo", + "type": "string", + "title": "The foo schema", + "description": "An explanation about the purpose of this instance." + }, + "b": { + "$id": "#/properties/b", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "additionalItems": false, + "items": { + "$id": "#/properties/b/items", + "type": "object", + "title": "The items schema", + "description": "An explanation about the purpose of this instance.", + "required": [ + "a", + "b", + "c" + ], + "properties": { + "a": { + "$id": "#/properties/b/items/properties/a", + "type": "integer", + "title": "The a schema", + "description": "An explanation about the purpose of this instance." + }, + "b": { + "$id": "#/properties/b/items/properties/b", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "additionalItems": false, + "items": { + "$id": "#/properties/b/items/properties/b/items", + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." + } + }, + "c": { + "$id": "#/properties/b/items/properties/c", + "type": "null", + "title": "The c schema", + "description": "An explanation about the purpose of this instance." 
+ } + }, + "additionalProperties": false + } + } + }, + "additionalProperties": false +}` + +const arrayNoItemsSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "http://example.com/example.json", + "type": "object", + "title": "The root schema", + "description": "The root schema comprises the entire JSON document.", + "required": [ + "b" + ], + "properties": { + "b": { + "$id": "#/properties/b", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "additionalItems": true + } + }, + "additionalProperties": false +}` + +const noChildrenObjectSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "http://example.com/example.json", + "type": "object", + "title": "The root schema", + "description": "The root schema comprises the entire JSON document.", + "additionalProperties": true +}` + +const untypedFieldObjectSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "http://example.com/example.json", + "type": "object", + "title": "The root schema", + "description": "The root schema comprises the entire JSON document.", + "required": [ + "foo" + ], + "properties": { + "foo": { + "$id": "#/properties/foo" + } + }, + "additionalProperties": false +}` + +const booleanSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "http://example.com/example.json", + "type": "object", + "title": "The root schema", + "description": "The root schema comprises the entire JSON document.", + "required": [ + "a" + ], + "properties": { + "a": { + "$id": "#/properties/foo", + "type": "boolean", + "title": "The foo schema", + "description": "An explanation about the purpose of this instance." + } + }, + "additionalProperties": false +}` + +const refSchema = ` +{ + "description": "Pod is a collection of containers that can run on a host. This resource is created by clients and scheduled onto hosts.", + "type": "object", + "properties": { + "apiVersion": { + "description": "APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#resources", + "type": [ + "string", + "null" + ] + }, + + "kind": { + "description": "Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#types-kinds", + "type": [ + "string", + "null" + ], + "enum": [ + "Pod" + ] + }, + "metadata": { + "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.apimachinery.pkg.apis.meta.v1.ObjectMeta", + "description": "Standard object's metadata. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#metadata" + } + } +} +` +const podSchema = ` +{ + "description": "Pod is a collection of containers that can run on a host. This resource is created by clients and scheduled onto hosts.", + "properties": { + "apiVersion": { + "description": "APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. 
More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#resources", + "type": [ + "string", + "null" + ] + }, + "kind": { + "description": "Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#types-kinds", + "type": [ + "string", + "null" + ], + "enum": [ + "Pod" + ] + }, + "metadata": { + "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.apimachinery.pkg.apis.meta.v1.ObjectMeta", + "description": "Standard object's metadata. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#metadata" + }, + "spec": { + "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.api.core.v1.PodSpec", + "description": "Specification of the desired behavior of the pod. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#spec-and-status" + }, + "status": { + "$ref": "https://kubernetesjsonschema.dev/v1.14.0/_definitions.json#/definitions/io.k8s.api.core.v1.PodStatus", + "description": "Most recently observed status of the pod. This data may not be up to date. Populated by the system. Read-only. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#spec-and-status" + } + }, + "type": "object", + "x-kubernetes-group-version-kind": [ + { + "group": "", + "kind": "Pod", + "version": "v1" + } + ], + "$schema": "http://json-schema.org/schema#" + }` + +const anyOfArraySchema = `{ + "type": "object", + "properties": { + "familyMembers": { + "type": "array", + "items": { + "anyOf": [ + { + "type": "object", + "properties": { + "age": { "type": "integer" }, + "name": {"type": "string"} + } + },{ + "type": "object", + "properties": { + "personality": { "type": "string" }, + "nickname": { "type": "string" } + } + } + ] + } + } + } +}` + +const anyOfExtendCoreSchema = `{ + "type": "object", + "properties": { + "AddressLine": { "type": "string" } + }, + "anyOf": [ + { + "type": "object", + "properties": { + "State": { "type": "string" }, + "ZipCode": { "type": "string" } + } + }, + { + "type": "object", + "properties": { + "County": { "type": "string" }, + "PostCode": { "type": "integer" } + } + } + ] +}` + +const allOfObjectSchema = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "object", + "title": "My schema", + "properties": { + "AddressLine1": { "type": "string" }, + "AddressLine2": { "type": "string" }, + "City": { "type": "string" } + }, + "allOf": [ + { + "type": "object", + "properties": { + "State": { "type": "string" }, + "ZipCode": { "type": "string" } + }, + }, + { + "type": "object", + "properties": { + "County": { "type": "string" }, + "PostCode": { "type": "string" } + }, + } + ] +}` + +const allOfArraySchema = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." + }, + "allOf": [ + { + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." 
+ } + }, + { + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." + } + } + ] +}` + +const allOfSchemaParentVariation = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "allOf": [ + { + "type": "object", + "properties": { + "State": { "type": "string" }, + "ZipCode": { "type": "string" } + }, + }, + { + "type": "object", + "properties": { + "County": { "type": "string" }, + "PostCode": { "type": "string" } + }, + } + ] +}` + +const emptySchema = `{ + "allof" : [] + }` + +const allOfArrayOfArrays = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "array", + "title": "The items schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." + } + }, + "allOf": [{ + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "array", + "title": "The items schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." + } + } + }, + { + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." 
+ } + } + ] +}` + +const anyOfInsideCoreSchema = ` { + "type": "object", + "properties": { + "AddressLine": { "type": "string" }, + "RandomInfo": { + "anyOf": [ + { "type": "object", + "properties": { + "accessMe": {"type": "string"} + } + }, + { "type": "number", "minimum": 0 } + ] + } + } +}` + +const anyOfObjectMissing = `{ + "type": "object", + "properties": { + "AddressLine": { "type": "string" } + }, + "anyOf": [ + { + "type": "object", + "properties": { + "State": { "type": "string" }, + "ZipCode": { "type": "string" } + } + }, + { + "properties": { + "County": { "type": "string" }, + "PostCode": { "type": "integer" } + } + } + ] +}` + +const allOfArrayOfObjects = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "object", + "title": "The items schema", + "description": "An explanation about the purpose of this instance.", + "properties": { + "State": { + "type": "string" + }, + "ZipCode": { + "type": "string" + } + }, + "allOf": [{ + "type": "object", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "properties": { + "County": { + "type": "string" + }, + "PostCode": { + "type": "string" + } + } + }, + { + "type": "object", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "properties": { + "Street": { + "type": "string" + }, + "House": { + "type": "string" + } + } + } + ] + } +}` + +const allOfObjectAndArray = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "object", + "title": "My schema", + "properties": { + "AddressLine1": { + "type": "string" + }, + "AddressLine2": { + "type": "string" + }, + "City": { + "type": "string" + } + }, + "allOf": [{ + "type": "object", + "properties": { + "State": { + "type": "string" + }, + "ZipCode": { + "type": "string" + } + } + }, + { + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "items": { + "type": "integer", + "title": "The items schema", + "description": "An explanation about the purpose of this instance." 
+ } + } + ] +}` + +const allOfObjectMissing = `{ + "type": "object", + "properties": { + "AddressLine": { "type": "string" } + }, + "allOf": [ + { + "type": "object", + "properties": { + "State": { "type": "string" }, + "ZipCode": { "type": "string" } + } + }, + { + "properties": { + "County": { "type": "string" }, + "PostCode": { "type": "integer" } + } + } + ] +}` + +const allOfArrayDifTypes = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "array", + "items": [{ + "type": "string" + }, + { + "type": "integer" + } + ] + }, + { + "type": "array", + "items": [{ + "type": "string" + }, + { + "type": "integer" + } + ] + } + ] +}` + +const allOfArrayInsideObject = `{ + "type": "object", + "properties": { + "familyMembers": { + "type": "array", + "items": { + "allOf": [{ + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "name": { + "type": "string" + } + } + }, { + "type": "object", + "properties": { + "personality": { + "type": "string" + }, + "nickname": { + "type": "string" + } + } + }] + } + } + } +}` + +const anyOfArrayMissing = `{ + "type": "array", + "anyOf": [ + { + "items": [ + {"type": "number"}, + {"type": "string"}] + }, + { "items": [ + {"type": "integer"}] + } + ] +}` + +const allOfArrayMissing = `{ + "type": "array", + "allOf": [{ + "items": [{ + "type": "integer" + }, + { + "type": "integer" + } + ] + }, + { + "items": [{ + "type": "integer" + }] + } + ] +}` + +const anyOfSchemaParentVariation = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "anyOf": [ + { + "type": "object", + "properties": { + "State": { "type": "string" }, + "ZipCode": { "type": "string" } + }, + }, + { + "type": "object", + "properties": { + "County": { "type": "string" }, + "PostCode": { "type": "string" } + }, + } + ] + } +}` + +const allOfInsideCoreSchema = `{ + "type": "object", + "properties": { + "AddressLine": { "type": "string" }, + "RandomInfo": { + "allOf": [ + { "type": "object", + "properties": { + "accessMe": {"type": "string"} + } + }, + { "type": "object", + "properties": { + "accessYou": {"type": "string"} + }} + ] + } + } +}` + +const allOfArrayDifTypesWithError = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "array", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "array", + "items": [{ + "type": "string" + }, + { + "type": "integer" + } + ] + }, + { + "type": "array", + "items": [{ + "type": "boolean" + }, + { + "type": "integer" + } + ] + } + ] +}` + +const allOfStringSchema = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "string", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "string", + }, + { + "type": "string", + } + ] +}` + +const allOfIntegerSchema = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "integer", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "integer", + }, + { + "type": "integer", + } + ] +}` + +const allOfBooleanSchema = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "boolean", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "boolean", + }, + { + "type": "boolean", + } + ] +}` + +const 
allOfTypeErrorSchema = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "string", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "string", + }, + { + "type": "integer", + } + ] +}` + +const allOfStringSchemaWithError = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "string", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "string", + }, + { + "type": "string", + }, + { + "type": "boolean", + } + ] +}` + +const allOfSchemaWithParentError = `{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "string", + "title": "The b schema", + "description": "An explanation about the purpose of this instance.", + "allOf": [{ + "type": "integer", + }, + { + "type": "integer", + } + ] +}` + +const allOfSchemaWithUnevenArray = `{ + "type": "array", + "allOf": [{ + "items": [{ + "type": "integer" + }, + { + "type": "integer" + } + ] + }, + { + "items": [{ + "type": "integer" + }, + { + "type": "integer" + }, + { + "type": "string" + }] + } + ] +}` + +const recursiveElements = `{ + "type": "object", + "properties": { + "Something": { + "$ref": "#/$defs/X" + } + }, + "$defs": { + "X": { + "type": "object", + "properties": { + "Y": { + "$ref": "#/$defs/Y" + } + } + }, + "Y": { + "type": "object", + "properties": { + "X": { + "$ref": "#/$defs/X" + } + } + } + } +} +` diff --git a/ast/strings.go b/ast/strings.go index 8f9928017a..e489f6977c 100644 --- a/ast/strings.go +++ b/ast/strings.go @@ -11,5 +11,8 @@ import ( // TypeName returns a human readable name for the AST element type. func TypeName(x interface{}) string { + if _, ok := x.(*lazyObj); ok { + return "object" + } return strings.ToLower(reflect.Indirect(reflect.ValueOf(x)).Type().Name()) } diff --git a/ast/term.go b/ast/term.go index 1dc5bd1bbf..79e21ccae6 100644 --- a/ast/term.go +++ b/ast/term.go @@ -8,6 +8,7 @@ package ast import ( "bytes" "encoding/json" + "errors" "fmt" "io" "math" @@ -17,9 +18,9 @@ import ( "sort" "strconv" "strings" + "sync" "github.com/OneOfOne/xxhash" - "github.com/pkg/errors" "github.com/open-policy-agent/opa/ast/location" "github.com/open-policy-agent/opa/util" @@ -133,12 +134,12 @@ func As(v Value, x interface{}) error { // Resolver defines the interface for resolving references to native Go values. type Resolver interface { - Resolve(ref Ref) (interface{}, error) + Resolve(Ref) (interface{}, error) } // ValueResolver defines the interface for resolving references to AST values. type ValueResolver interface { - Resolve(ref Ref) (Value, error) + Resolve(Ref) (Value, error) } // UnknownValueErr indicates a ValueResolver was unable to resolve a reference @@ -215,6 +216,11 @@ func valueToInterface(v Value, resolver Resolver, opt JSONOpt) (interface{}, err return nil, err } return buf, nil + case *lazyObj: + if opt.CopyMaps { + return valueToInterface(v.force(), resolver, opt) + } + return v.native, nil case Set: buf := []interface{}{} iter := func(x *Term) error { @@ -251,6 +257,7 @@ func JSON(v Value) (interface{}, error) { // JSONOpt defines parameters for AST to JSON conversion. type JSONOpt struct { SortSets bool // sort sets before serializing (this makes conversion more expensive) + CopyMaps bool // enforces copying of map[string]interface{} read from the store } // JSONWithOpt returns the JSON representation of v. 
The value must not contain any @@ -284,8 +291,10 @@ func MustInterfaceToValue(x interface{}) Value { // Term is an argument to a function. type Term struct { - Value Value `json:"value"` // the value of the Term as represented in Go - Location *Location `json:"-"` // the location of the Term in the source + Value Value `json:"value"` // the value of the Term as represented in Go + Location *Location `json:"location,omitempty"` // the location of the Term in the source + + jsonOptions JSONOptions } // NewTerm returns a new Term object. @@ -383,9 +392,13 @@ func (term *Term) Equal(other *Term) bool { // Get returns a value referred to by name from the term. func (term *Term) Get(name *Term) *Term { switch v := term.Value.(type) { + case *object: + return v.Get(name) case *Array: return v.Get(name) - case *object: + case interface { + Get(*Term) *Term + }: return v.Get(name) case Set: if v.Contains(name) { @@ -406,6 +419,10 @@ func (term *Term) IsGround() bool { return term.Value.IsGround() } +func (term *Term) setJSONOptions(opts JSONOptions) { + term.jsonOptions = opts +} + // MarshalJSON returns the JSON encoding of the term. // // Specialized marshalling logic is required to include a type hint for Value. @@ -414,6 +431,11 @@ func (term *Term) MarshalJSON() ([]byte, error) { "type": TypeName(term.Value), "value": term.Value, } + if term.jsonOptions.MarshalOptions.IncludeLocation.Term { + if term.Location != nil { + d["location"] = term.Location + } + } return json.Marshal(d) } @@ -422,7 +444,7 @@ func (term *Term) String() string { } // UnmarshalJSON parses the byte array and stores the result in term. -// Specialized unmarshalling is required to handle Value. +// Specialized unmarshalling is required to handle Value and Location. func (term *Term) UnmarshalJSON(bs []byte) error { v := map[string]interface{}{} if err := util.UnmarshalJSON(bs, &v); err != nil { @@ -433,6 +455,14 @@ func (term *Term) UnmarshalJSON(bs []byte) error { return err } term.Value = val + + if loc, ok := v["location"].(map[string]interface{}); ok { + term.Location = &Location{} + err := unmarshalLocation(term.Location, loc) + if err != nil { + return err + } + } return nil } @@ -493,6 +523,20 @@ func ContainsComprehensions(v interface{}) bool { return found } +// ContainsClosures returns true if the Value v contains closures. +func ContainsClosures(v interface{}) bool { + found := false + WalkClosures(v, func(x interface{}) bool { + switch x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: + found = true + return found + } + return found + }) + return found +} + // IsScalar returns true if the AST value is a scalar. func IsScalar(v Value) bool { switch v.(type) { @@ -878,19 +922,16 @@ func (ref Ref) Append(term *Term) Ref { // existing elements are shifted to the right. If pos > len(ref)+1 this // function panics. 
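The location round-tripping added above can be exercised directly through encoding/json, since Term.UnmarshalJSON now restores the "location" key. A minimal sketch; the input literal, file name, row, and col are made-up values for illustration:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	// A term serialized with location info, as produced when the
	// marshalling options above include term locations.
	raw := `{"type": "string", "value": "x", "location": {"file": "example.rego", "row": 2, "col": 7}}`

	var t ast.Term
	if err := json.Unmarshal([]byte(raw), &t); err != nil {
		panic(err)
	}
	fmt.Println(t.Value)                                         // "x"
	fmt.Println(t.Location.File, t.Location.Row, t.Location.Col) // example.rego 2 7
}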
func (ref Ref) Insert(x *Term, pos int) Ref { - if pos == len(ref) { + switch { + case pos == len(ref): return ref.Append(x) - } else if pos > len(ref)+1 { + case pos > len(ref)+1: panic("illegal index") } cpy := make(Ref, len(ref)+1) - for i := 0; i < pos; i++ { - cpy[i] = ref[i] - } + copy(cpy, ref[:pos]) cpy[pos] = x - for i := pos; i < len(ref); i++ { - cpy[i+1] = ref[i] - } + copy(cpy[pos+1:], ref[pos:]) return cpy } @@ -904,9 +945,8 @@ func (ref Ref) Extend(other Ref) Ref { head.Value = String(head.Value.(Var)) offset := len(ref) dst[offset] = head - for i := range other[1:] { - dst[offset+i+1] = other[i+1] - } + + copy(dst[offset+1:], other[1:]) return dst } @@ -917,10 +957,7 @@ func (ref Ref) Concat(terms []*Term) Ref { } cpy := make(Ref, len(ref)+len(terms)) copy(cpy, ref) - - for i := range terms { - cpy[len(ref)+i] = terms[i] - } + copy(cpy[len(ref):], terms) return cpy } @@ -1064,38 +1101,64 @@ func (ref Ref) String() string { } // OutputVars returns a VarSet containing variables that would be bound by evaluating -// this expression in isolation. +// this expression in isolation. func (ref Ref) OutputVars() VarSet { vis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true}) vis.Walk(ref) return vis.Vars() } +func (ref Ref) toArray() *Array { + a := NewArray() + for _, term := range ref { + if _, ok := term.Value.(String); ok { + a = a.Append(term) + } else { + a = a.Append(StringTerm(term.Value.String())) + } + } + return a +} + // QueryIterator defines the interface for querying AST documents with references. type QueryIterator func(map[Var]Value, Value) error // ArrayTerm creates a new Term with an Array value. func ArrayTerm(a ...*Term) *Term { - return &Term{Value: &Array{a, 0}} + return NewTerm(NewArray(a...)) } // NewArray creates an Array with the terms provided. The array will // use the provided term slice. func NewArray(a ...*Term) *Array { - return &Array{a, 0} + hs := make([]int, len(a)) + for i, e := range a { + hs[i] = e.Value.Hash() + } + arr := &Array{elems: a, hashs: hs, ground: termSliceIsGround(a)} + arr.rehash() + return arr } // Array represents an array as defined by the language. Arrays are similar to the // same types as defined by JSON with the exception that they can contain Vars // and References. type Array struct { - elems []*Term - hash int + elems []*Term + hashs []int // element hashes + hash int + ground bool } // Copy returns a deep copy of arr. func (arr *Array) Copy() *Array { - return &Array{termSliceCopy(arr.elems), arr.hash} + cpy := make([]int, len(arr.elems)) + copy(cpy, arr.hashs) + return &Array{ + elems: termSliceCopy(arr.elems), + hashs: cpy, + hash: arr.hash, + ground: arr.IsGround()} } // Equal returns true if arr is equal to other. @@ -1155,28 +1218,24 @@ func (arr *Array) Sorted() *Array { } sort.Sort(termSlice(cpy)) a := NewArray(cpy...) - a.hash = arr.hash + a.hashs = arr.hashs return a } // Hash returns the hash code for the Value. func (arr *Array) Hash() int { - if arr.hash == 0 { - arr.hash = termSliceHash(arr.elems) - } - return arr.hash } // IsGround returns true if all of the Array elements are ground. func (arr *Array) IsGround() bool { - return termSliceIsGround(arr.elems) + return arr.ground } // MarshalJSON returns JSON encoded bytes representing arr. 
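The Array changes above move hashing from call time to construction and update time: every element's hash is stored in the hashs slice and the aggregate is kept as a running sum, so Hash() no longer walks the elements. A self-contained toy sketch of that bookkeeping follows; the type, field names, and toyHash function are illustrative stand-ins, not OPA's API:

package main

import "fmt"

// hashedElems mirrors the Array bookkeeping: element hashes are computed
// when elements are added (NewArray/Append/set in the patch), and the
// aggregate hash is a running sum, so Hash() is O(1) per call.
type hashedElems struct {
	elems []string
	hashs []int // per-element hashes, kept in lockstep with elems
	hash  int   // cached aggregate hash
}

// toyHash stands in for Term.Value.Hash().
func toyHash(s string) int {
	h := 0
	for _, c := range s {
		h = h*31 + int(c)
	}
	return h
}

func (a *hashedElems) append(e string) {
	h := toyHash(e)
	a.elems = append(a.elems, e)
	a.hashs = append(a.hashs, h)
	a.hash += h // O(1) update instead of re-hashing all elements
}

// rehash recomputes the aggregate from the stored element hashes,
// as Array.rehash does after slicing.
func (a *hashedElems) rehash() {
	a.hash = 0
	for _, h := range a.hashs {
		a.hash += h
	}
}

func (a *hashedElems) Hash() int { return a.hash }

func main() {
	a := &hashedElems{}
	for _, e := range []string{"x", "y", "z"} {
		a.append(e)
	}
	before := a.Hash() // cheap: returns the cached sum
	a.rehash()
	fmt.Println(before == a.Hash()) // true
}

Because the aggregate is a plain sum of element hashes, it is insensitive to the order in which elements were added or re-derived, which is why rehash after slicing agrees with the incrementally maintained value.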
func (arr *Array) MarshalJSON() ([]byte, error) { if len(arr.elems) == 0 { - return json.Marshal([]interface{}{}) + return []byte(`[]`), nil } return json.Marshal(arr.elems) } @@ -1204,10 +1263,19 @@ func (arr *Array) Elem(i int) *Term { return arr.elems[i] } +// rehash updates the cached hash of arr. +func (arr *Array) rehash() { + arr.hash = 0 + for _, h := range arr.hashs { + arr.hash += h + } +} + // set sets the element i of arr. func (arr *Array) set(i int, v *Term) { + arr.ground = arr.ground && v.IsGround() arr.elems[i] = v - arr.hash = 0 + arr.hashs[i] = v.Value.Hash() } // Slice returns a slice of arr starting from i index to j. -1 @@ -1215,11 +1283,22 @@ func (arr *Array) set(i int, v *Term) { // copy and any modifications to either of arrays may be reflected to // the other. func (arr *Array) Slice(i, j int) *Array { + var elems []*Term + var hashs []int if j == -1 { - return &Array{elems: arr.elems[i:]} + elems = arr.elems[i:] + hashs = arr.hashs[i:] + } else { + elems = arr.elems[i:j] + hashs = arr.hashs[i:j] } + // If arr is ground, the slice is, too. + // If it's not, the slice could still be. + gr := arr.ground || termSliceIsGround(elems) - return &Array{elems: arr.elems[i:j]} + s := &Array{elems: elems, hashs: hashs, ground: gr} + s.rehash() + return s } // Iter calls f on each element in arr. If f returns an error, @@ -1256,7 +1335,9 @@ func (arr *Array) Foreach(f func(*Term)) { func (arr *Array) Append(v *Term) *Array { cpy := *arr cpy.elems = append(arr.elems, v) - cpy.hash = 0 + cpy.hashs = append(arr.hashs, v.Value.Hash()) + cpy.hash = arr.hash + v.Value.Hash() + cpy.ground = arr.ground && v.IsGround() return &cpy } @@ -1294,10 +1375,11 @@ func newset(n int) *set { keys = make([]*Term, 0, n) } return &set{ - elems: make(map[int]*Term, n), - keys: keys, - hash: 0, - ground: true, + elems: make(map[int]*Term, n), + keys: keys, + hash: 0, + ground: true, + sortGuard: new(sync.Once), } } @@ -1310,10 +1392,11 @@ func SetTerm(t ...*Term) *Term { } type set struct { - elems map[int]*Term - keys []*Term - hash int - ground bool + elems map[int]*Term + keys []*Term + hash int + ground bool + sortGuard *sync.Once // Prevents race condition around sorting. } // Copy returns a deep copy of s. @@ -1334,11 +1417,6 @@ func (s *set) IsGround() bool { // Hash returns a hash code for s. func (s *set) Hash() int { - if s.hash == 0 { - s.Foreach(func(x *Term) { - s.hash += x.Hash() - }) - } return s.hash } @@ -1348,7 +1426,7 @@ func (s *set) String() string { } var b strings.Builder b.WriteRune('{') - for i := range s.keys { + for i := range s.sortedKeys() { if i > 0 { b.WriteString(", ") } @@ -1358,6 +1436,13 @@ func (s *set) String() string { return b.String() } +func (s *set) sortedKeys() []*Term { + s.sortGuard.Do(func() { + sort.Sort(termSlice(s.keys)) + }) + return s.keys +} + // Compare compares s to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (s *set) Compare(other Value) int { @@ -1369,7 +1454,7 @@ func (s *set) Compare(other Value) int { return 1 } t := other.(*set) - return termSliceCompare(s.keys, t.keys) + return termSliceCompare(s.sortedKeys(), t.sortedKeys()) } // Find returns the set or dereferences the element itself. @@ -1435,7 +1520,7 @@ func (s *set) Add(t *Term) { // Iter calls f on each element in s. If f returns an error, iteration stops // and the return value is the error. 
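One observable consequence of the sortedKeys change: set iteration and printing now always follow canonical (sorted) order regardless of insertion order, while Add stays an O(1) append. A small sketch using only constructors and methods that appear in this patch:

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	// Keys are appended in O(1) and sorted lazily on first read, so
	// iteration follows sorted order, not insertion order.
	s := ast.NewSet(
		ast.IntNumberTerm(3),
		ast.IntNumberTerm(1),
		ast.IntNumberTerm(2),
	)
	s.Foreach(func(t *ast.Term) {
		fmt.Print(t, " ")
	})
	fmt.Println()  // output: 1 2 3
	fmt.Println(s) // {1, 2, 3}
}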
func (s *set) Iter(f func(*Term) error) error { - for i := range s.keys { + for i := range s.sortedKeys() { if err := f(s.keys[i]); err != nil { return err } @@ -1509,28 +1594,28 @@ func (s *set) Len() int { // MarshalJSON returns JSON encoded bytes representing s. func (s *set) MarshalJSON() ([]byte, error) { if s.keys == nil { - return json.Marshal([]interface{}{}) + return []byte(`[]`), nil } - return json.Marshal(s.keys) + return json.Marshal(s.sortedKeys()) } // Sorted returns an Array that contains the sorted elements of s. func (s *set) Sorted() *Array { cpy := make([]*Term, len(s.keys)) - for i := range s.keys { - cpy[i] = s.keys[i] - } - sort.Sort(termSlice(cpy)) + copy(cpy, s.sortedKeys()) return NewArray(cpy...) } // Slice returns a slice of terms contained in the set. func (s *set) Slice() []*Term { - return s.keys + return s.sortedKeys() } +// NOTE(philipc): We assume a many-readers, single-writer model here. +// This method should NOT be used concurrently, or else we risk data races. func (s *set) insert(x *Term) { hash := x.Hash() + insertHash := hash // This `equal` utility is duplicated and manually inlined a number of // time in this file. Inlining it avoids heap allocations, so it makes // a big performance difference: some operations like lookup become twice @@ -1608,27 +1693,23 @@ func (s *set) insert(x *Term) { equal = func(y Value) bool { return Compare(x, y) == 0 } } - for curr, ok := s.elems[hash]; ok; { + for curr, ok := s.elems[insertHash]; ok; { if equal(curr.Value) { return } - hash++ - curr, ok = s.elems[hash] + insertHash++ + curr, ok = s.elems[insertHash] } - s.elems[hash] = x - i := sort.Search(len(s.keys), func(i int) bool { return Compare(x, s.keys[i]) < 0 }) - if i < len(s.keys) { - // insert at position `i`: - s.keys = append(s.keys, nil) // add some space - copy(s.keys[i+1:], s.keys[i:]) // move things over - s.keys[i] = x // drop it in position - } else { - s.keys = append(s.keys, x) - } + s.elems[insertHash] = x + // O(1) insertion, but we'll have to re-sort the keys later. + s.keys = append(s.keys, x) + // Reset the sync.Once instance. + // See https://github.com/golang/go/issues/25955 for why we do it this way. 
+ s.sortGuard = new(sync.Once) - s.hash = 0 + s.hash += hash s.ground = s.ground && x.IsGround() } @@ -1740,7 +1821,7 @@ type Object interface { MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) Filter(filter Object) (Object, error) Keys() []*Term - Elem(i int) (*Term, *Term) + KeysIterator() ObjectKeysIterator get(k *Term) *objectElem // To prevent external implementations } @@ -1758,12 +1839,171 @@ func ObjectTerm(o ...[2]*Term) *Term { return &Term{Value: NewObject(o...)} } +func LazyObject(blob map[string]interface{}) Object { + return &lazyObj{native: blob} +} + +type lazyObj struct { + strict Object + native map[string]interface{} +} + +func (l *lazyObj) force() Object { + if l.strict == nil { + l.strict = MustInterfaceToValue(l.native).(Object) + } + return l.strict +} + +func (l *lazyObj) Compare(other Value) int { + return l.force().Compare(other) +} + +func (l *lazyObj) Copy() Object { + return l +} + +func (l *lazyObj) Diff(other Object) Object { + return l.force().Diff(other) +} + +func (l *lazyObj) Intersect(other Object) [][3]*Term { + return l.force().Intersect(other) +} + +func (l *lazyObj) Iter(f func(*Term, *Term) error) error { + return l.force().Iter(f) +} + +func (l *lazyObj) Until(f func(*Term, *Term) bool) bool { + // NOTE(sr): there could be benefits in not forcing here -- if we abort because + // `f` returns true, we could save us from converting the rest of the object. + return l.force().Until(f) +} + +func (l *lazyObj) Foreach(f func(*Term, *Term)) { + l.force().Foreach(f) +} + +func (l *lazyObj) Filter(filter Object) (Object, error) { + return l.force().Filter(filter) +} + +func (l *lazyObj) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, error) { + return l.force().Map(f) +} + +func (l *lazyObj) MarshalJSON() ([]byte, error) { + return l.force().(*object).MarshalJSON() +} + +func (l *lazyObj) Merge(other Object) (Object, bool) { + return l.force().Merge(other) +} + +func (l *lazyObj) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { + return l.force().MergeWith(other, conflictResolver) +} + +func (l *lazyObj) Len() int { + return len(l.native) +} + +func (l *lazyObj) String() string { + return l.force().String() +} + +// get is merely there to implement the Object interface -- `get` there serves the +// purpose of prohibiting external implementations. It's never called for lazyObj. 
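Both the set and object inserts now rely on the same trick: append keys in O(1) and defer sorting to the first reader by swapping in a fresh sync.Once, since a sync.Once cannot be reset (see the golang/go#25955 link above). A self-contained sketch of the pattern, which, like the patch, assumes many readers but a single writer; the lazySorted type is illustrative only:

package main

import (
	"fmt"
	"sort"
	"sync"
)

// lazySorted sketches the sortGuard pattern: inserts are O(1) appends,
// and the keys are sorted at most once per generation, no matter how
// many readers call sortedKeys concurrently.
type lazySorted struct {
	keys      []int
	sortGuard *sync.Once // replaced wholesale on every insert
}

func newLazySorted() *lazySorted {
	return &lazySorted{sortGuard: new(sync.Once)}
}

// insert assumes a single writer, matching the patch's NOTE comments.
func (l *lazySorted) insert(k int) {
	l.keys = append(l.keys, k) // O(1); sorting is deferred
	// A sync.Once cannot be reset, so swap in a fresh one instead
	// (see https://github.com/golang/go/issues/25955).
	l.sortGuard = new(sync.Once)
}

func (l *lazySorted) sortedKeys() []int {
	l.sortGuard.Do(func() { sort.Ints(l.keys) })
	return l.keys
}

func main() {
	l := newLazySorted()
	for _, k := range []int{3, 1, 2} {
		l.insert(k)
	}
	fmt.Println(l.sortedKeys()) // [1 2 3]
}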
+func (*lazyObj) get(*Term) *objectElem { + return nil +} + +func (l *lazyObj) Get(k *Term) *Term { + if l.strict != nil { + return l.strict.Get(k) + } + if s, ok := k.Value.(String); ok { + if val, ok := l.native[string(s)]; ok { + switch val := val.(type) { + case map[string]interface{}: + return NewTerm(&lazyObj{native: val}) + default: + return NewTerm(MustInterfaceToValue(val)) + } + } + } + return nil +} + +func (l *lazyObj) Insert(k, v *Term) { + l.force().Insert(k, v) +} + +func (*lazyObj) IsGround() bool { + return true +} + +func (l *lazyObj) Hash() int { + return l.force().Hash() +} + +func (l *lazyObj) Keys() []*Term { + if l.strict != nil { + return l.strict.Keys() + } + ret := make([]*Term, 0, len(l.native)) + for k := range l.native { + ret = append(ret, StringTerm(k)) + } + sort.Sort(termSlice(ret)) + return ret +} + +func (l *lazyObj) KeysIterator() ObjectKeysIterator { + return &lazyObjKeysIterator{keys: l.Keys()} +} + +type lazyObjKeysIterator struct { + current int + keys []*Term +} + +func (ki *lazyObjKeysIterator) Next() (*Term, bool) { + if ki.current == len(ki.keys) { + return nil, false + } + ki.current++ + return ki.keys[ki.current-1], true +} + +func (l *lazyObj) Find(path Ref) (Value, error) { + if l.strict != nil { + return l.strict.Find(path) + } + if len(path) == 0 { + return l, nil + } + if p0, ok := path[0].Value.(String); ok { + if v, ok := l.native[string(p0)]; ok { + switch v := v.(type) { + case map[string]interface{}: + return (&lazyObj{native: v}).Find(path[1:]) + default: + return MustInterfaceToValue(v).Find(path[1:]) + } + } + } + return nil, errFindNotFound +} + type object struct { elems map[int]*objectElem keys objectElemSlice ground int // number of key and value grounds. Counting is // required to support insert's key-value replace. - hash int + hash int + sortGuard *sync.Once // Prevents race condition around sorting. } func newobject(n int) *object { @@ -1772,10 +2012,11 @@ func newobject(n int) *object { keys = make(objectElemSlice, 0, n) } return &object{ - elems: make(map[int]*objectElem, n), - keys: keys, - ground: 0, - hash: 0, + elems: make(map[int]*objectElem, n), + keys: keys, + ground: 0, + hash: 0, + sortGuard: new(sync.Once), } } @@ -1797,9 +2038,19 @@ func Item(key, value *Term) [2]*Term { return [2]*Term{key, value} } +func (obj *object) sortedKeys() objectElemSlice { + obj.sortGuard.Do(func() { + sort.Sort(obj.keys) + }) + return obj.keys +} + // Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (obj *object) Compare(other Value) int { + if x, ok := other.(*lazyObj); ok { + other = x.force() + } o1 := sortOrder(obj) o2 := sortOrder(other) if o1 < o2 { @@ -1809,29 +2060,32 @@ func (obj *object) Compare(other Value) int { } a := obj b := other.(*object) - minLen := len(a.keys) - if len(b.keys) < len(a.keys) { - minLen = len(b.keys) + // Ensure that keys are in canonical sorted order before use! 
+ akeys := a.sortedKeys() + bkeys := b.sortedKeys() + minLen := len(akeys) + if len(b.keys) < len(akeys) { + minLen = len(bkeys) } for i := 0; i < minLen; i++ { - keysCmp := Compare(a.keys[i].key, b.keys[i].key) + keysCmp := Compare(akeys[i].key, bkeys[i].key) if keysCmp < 0 { return -1 } if keysCmp > 0 { return 1 } - valA := a.keys[i].value - valB := b.keys[i].value + valA := akeys[i].value + valB := bkeys[i].value valCmp := Compare(valA, valB) if valCmp != 0 { return valCmp } } - if len(a.keys) < len(b.keys) { + if len(akeys) < len(bkeys) { return -1 } - if len(b.keys) < len(a.keys) { + if len(bkeys) < len(akeys) { return 1 } return 0 @@ -1863,14 +2117,6 @@ func (obj *object) Get(k *Term) *Term { // Hash returns the hash code for the Value. func (obj *object) Hash() int { - if obj.hash == 0 { - for h, curr := range obj.elems { - for ; curr != nil; curr = curr.next { - obj.hash += h - obj.hash += curr.value.Hash() - } - } - } return obj.hash } @@ -1915,7 +2161,7 @@ func (obj *object) Intersect(other Object) [][3]*Term { // Iter calls the function f for each key-value pair in the object. If f // returns an error, iteration stops and the error is returned. func (obj *object) Iter(f func(*Term, *Term) error) error { - for _, node := range obj.keys { + for _, node := range obj.sortedKeys() { if err := f(node.key, node.value); err != nil { return err } @@ -1967,21 +2213,22 @@ func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, erro func (obj *object) Keys() []*Term { keys := make([]*Term, len(obj.keys)) - for i, elem := range obj.keys { + for i, elem := range obj.sortedKeys() { keys[i] = elem.key } return keys } -func (obj *object) Elem(i int) (*Term, *Term) { - return obj.keys[i].key, obj.keys[i].value +// Returns an iterator over the obj's keys. +func (obj *object) KeysIterator() ObjectKeysIterator { + return newobjectKeysIterator(obj) } // MarshalJSON returns JSON encoded bytes representing obj. func (obj *object) MarshalJSON() ([]byte, error) { sl := make([][2]*Term, obj.Len()) - for i, node := range obj.keys { + for i, node := range obj.sortedKeys() { sl[i] = Item(node.key, node.value) } return json.Marshal(sl) @@ -2061,7 +2308,7 @@ func (obj object) String() string { var b strings.Builder b.WriteRune('{') - for i, elem := range obj.keys { + for i, elem := range obj.sortedKeys() { if i > 0 { b.WriteString(", ") } @@ -2161,6 +2408,8 @@ func (obj *object) get(k *Term) *objectElem { return nil } +// NOTE(philipc): We assume a many-readers, single-writer model here. +// This method should NOT be used concurrently, or else we risk data races. func (obj *object) insert(k, v *Term) { hash := k.Hash() head := obj.elems[hash] @@ -2255,7 +2504,6 @@ func (obj *object) insert(k, v *Term) { } curr.value = v - obj.hash = 0 return } } @@ -2265,16 +2513,12 @@ func (obj *object) insert(k, v *Term) { next: head, } obj.elems[hash] = elem - i := sort.Search(len(obj.keys), func(i int) bool { return Compare(elem.key, obj.keys[i].key) < 0 }) - if i < len(obj.keys) { - // insert at position `i`: - obj.keys = append(obj.keys, nil) // add some space - copy(obj.keys[i+1:], obj.keys[i:]) // move things over - obj.keys[i] = elem // drop it in position - } else { - obj.keys = append(obj.keys, elem) - } - obj.hash = 0 + // O(1) insertion, but we'll have to re-sort the keys later. + obj.keys = append(obj.keys, elem) + // Reset the sync.Once instance. + // See https://github.com/golang/go/issues/25955 for why we do it this way. 
+ obj.sortGuard = new(sync.Once) + obj.hash += hash + v.Hash() if k.IsGround() { obj.ground++ @@ -2349,6 +2593,36 @@ func filterObject(o Value, filter Value) (Value, error) { } } +// NOTE(philipc): The only way to get an ObjectKeyIterator should be +// from an Object. This ensures that the iterator can have implementation- +// specific details internally, with no contracts except to the very +// limited interface. +type ObjectKeysIterator interface { + Next() (*Term, bool) +} + +type objectKeysIterator struct { + obj *object + numKeys int + index int +} + +func newobjectKeysIterator(o *object) ObjectKeysIterator { + return &objectKeysIterator{ + obj: o, + numKeys: o.Len(), + index: 0, + } +} + +func (oki *objectKeysIterator) Next() (*Term, bool) { + if oki.index == oki.numKeys || oki.numKeys == 0 { + return nil, false + } + oki.index++ + return oki.obj.sortedKeys()[oki.index-1].key, true +} + // ArrayComprehension represents an array comprehension as defined in the language. type ArrayComprehension struct { Term *Term `json:"term"` @@ -2642,6 +2916,14 @@ func unmarshalExpr(expr *Expr, v map[string]interface{}) error { return fmt.Errorf("ast: unable to unmarshal negated field with type: %T (expected true or false)", v["negated"]) } } + if generatedRaw, ok := v["generated"]; ok { + if b, ok := generatedRaw.(bool); ok { + expr.Generated = b + } else { + return fmt.Errorf("ast: unable to unmarshal generated field with type: %T (expected true or false)", v["generated"]) + } + } + if err := unmarshalExprIndex(expr, v); err != nil { return err } @@ -2674,6 +2956,46 @@ func unmarshalExpr(expr *Expr, v map[string]interface{}) error { expr.With = ws } } + if loc, ok := v["location"].(map[string]interface{}); ok { + expr.Location = &Location{} + if err := unmarshalLocation(expr.Location, loc); err != nil { + return err + } + } + return nil +} + +func unmarshalLocation(loc *Location, v map[string]interface{}) error { + if x, ok := v["file"]; ok { + if s, ok := x.(string); ok { + loc.File = s + } else { + return fmt.Errorf("ast: unable to unmarshal file field with type: %T (expected string)", v["file"]) + } + } + if x, ok := v["row"]; ok { + if n, ok := x.(json.Number); ok { + i64, err := n.Int64() + if err != nil { + return err + } + loc.Row = int(i64) + } else { + return fmt.Errorf("ast: unable to unmarshal row field with type: %T (expected number)", v["row"]) + } + } + if x, ok := v["col"]; ok { + if n, ok := x.(json.Number); ok { + i64, err := n.Int64() + if err != nil { + return err + } + loc.Col = int(i64) + } else { + return fmt.Errorf("ast: unable to unmarshal col field with type: %T (expected number)", v["col"]) + } + } + return nil } @@ -2691,11 +3013,22 @@ func unmarshalExprIndex(expr *Expr, v map[string]interface{}) error { } func unmarshalTerm(m map[string]interface{}) (*Term, error) { + var term Term + v, err := unmarshalValue(m) if err != nil { return nil, err } - return &Term{Value: v}, nil + term.Value = v + + if loc, ok := m["location"].(map[string]interface{}); ok { + term.Location = &Location{} + if err := unmarshalLocation(term.Location, loc); err != nil { + return nil, err + } + } + + return &term, nil } func unmarshalTermSlice(s []interface{}) ([]*Term, error) { diff --git a/ast/term_bench_test.go b/ast/term_bench_test.go index 48892898a4..30ec222462 100644 --- a/ast/term_bench_test.go +++ b/ast/term_bench_test.go @@ -32,6 +32,44 @@ func BenchmarkObjectLookup(b *testing.B) { } } +func BenchmarkObjectCreationAndLookup(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 
500000} + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + obj := NewObject() + for i := 0; i < n; i++ { + obj.Insert(StringTerm(fmt.Sprint(i)), IntNumberTerm(i)) + } + key := StringTerm(fmt.Sprint(n - 1)) + for i := 0; i < b.N; i++ { + value := obj.Get(key) + if value == nil { + b.Fatal("expected hit") + } + } + }) + } +} + +func BenchmarkSetCreationAndLookup(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 500000} + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + set := NewSet() + for i := 0; i < n; i++ { + set.Add(StringTerm(fmt.Sprint(i))) + } + key := StringTerm(fmt.Sprint(n - 1)) + for i := 0; i < b.N; i++ { + present := set.Contains(key) + if !present { + b.Fatal("expected hit") + } + } + }) + } +} + func BenchmarkSetIntersection(b *testing.B) { sizes := []int{5, 50, 500, 5000} for _, n := range sizes { @@ -109,14 +147,19 @@ func BenchmarkTermHashing(b *testing.B) { } } -var str string -var bs []byte +var ( + str string + bs []byte +) // BenchmarkObjectString generates several objects of different sizes, and // marshals them to JSON via two ways: -// map[string]int -> ast.Value -> .String() +// +// map[string]int -> ast.Value -> .String() +// // and -// map[string]int -> json.Marshal() +// +// map[string]int -> json.Marshal() +// // The difference between these two is relevant for feeding input into the // wasm vm: when calling rego.New(...) with rego.Target("wasm"), it's up to @@ -128,7 +171,6 @@ func BenchmarkObjectString(b *testing.B) { for _, n := range sizes { b.Run(fmt.Sprint(n), func(b *testing.B) { - obj := map[string]int{} for i := 0; i < n; i++ { obj[fmt.Sprint(i)] = i @@ -154,8 +196,44 @@ func BenchmarkObjectString(b *testing.B) { } } -func BenchmarkObjectConstruction(b *testing.B) { +// This benchmark works similarly to BenchmarkObjectString, but with a key +// difference: it benchmarks the String and MarshalJSON interface functions +// for the Object, instead of the underlying data structure. This ensures +// that we catch the full performance properties of Object's implementation. +func BenchmarkObjectStringInterfaces(b *testing.B) { + var err error sizes := []int{5, 50, 500, 5000, 50000} + + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + obj := map[string]int{} + for i := 0; i < n; i++ { + obj[fmt.Sprint(i)] = i + } + valString := MustInterfaceToValue(obj) + valJSON := MustInterfaceToValue(obj) + + b.Run("String()", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + str = valString.String() + } + }) + b.Run("json.Marshal", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + bs, err = json.Marshal(valJSON) + if err != nil { + b.Fatal(err) + } + } + }) + }) + } +} + +func BenchmarkObjectConstruction(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 500000} seed := time.Now().UnixNano() b.Run("shuffled keys", func(b *testing.B) { @@ -165,7 +243,7 @@ func BenchmarkObjectConstruction(b *testing.B) { for i := 0; i < n; i++ { es = append(es, struct{ k, v int }{i, i}) } - rand.Seed(seed) + rand.New(rand.NewSource(seed)) // Seed the PRNG.
rand.Shuffle(len(es), func(i, j int) { es[i], es[j] = es[j], es[i] }) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -205,7 +283,6 @@ func BenchmarkArrayString(b *testing.B) { for _, n := range sizes { b.Run(fmt.Sprint(n), func(b *testing.B) { - obj := make([]string, n) for i := 0; i < n; i++ { obj[i] = fmt.Sprint(i) @@ -232,11 +309,10 @@ func BenchmarkArrayString(b *testing.B) { } func BenchmarkSetString(b *testing.B) { - sizes := []int{5, 50, 500, 5000} + sizes := []int{5, 50, 500, 5000, 50000} for _, n := range sizes { b.Run(fmt.Sprint(n), func(b *testing.B) { - val := NewSet() for i := 0; i < n; i++ { val.Add(IntNumberTerm(i)) @@ -251,3 +327,27 @@ func BenchmarkSetString(b *testing.B) { }) } } + +func BenchmarkSetMarshalJSON(b *testing.B) { + var err error + sizes := []int{5, 50, 500, 5000, 50000} + + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + set := NewSet() + for i := 0; i < n; i++ { + set.Add(StringTerm(fmt.Sprint(i))) + } + + b.Run("json.Marshal", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + bs, err = json.Marshal(set) + if err != nil { + b.Fatal(err) + } + } + }) + }) + } +} diff --git a/ast/term_test.go b/ast/term_test.go index ca2d482b48..0e8f50478d 100644 --- a/ast/term_test.go +++ b/ast/term_test.go @@ -7,16 +7,18 @@ package ast import ( "encoding/json" "fmt" + "math/rand" "reflect" + "runtime" "sort" "strings" + "sync" "testing" "github.com/open-policy-agent/opa/util" ) func TestInterfaceToValue(t *testing.T) { - // Test util package unmarshalled inputs input := ` { @@ -83,11 +85,9 @@ func TestInterfaceToValue(t *testing.T) { t.Fatalf("Expected %v but got: %v", expected, v) } } - } func TestInterfaceToValueStructs(t *testing.T) { - var x struct { Foo struct { Baz string `json:"baz"` @@ -158,7 +158,6 @@ func TestObjectInsertGetLen(t *testing.T) { } func TestObjectSetOperations(t *testing.T) { - a := MustParseTerm(`{"a": "b", "c": "d"}`).Value.(Object) b := MustParseTerm(`{"c": "q", "d": "e"}`).Value.(Object) @@ -263,70 +262,7 @@ func TestObjectFilter(t *testing.T) { } } -func TestObjectInsertKeepsSorting(t *testing.T) { - keysSorted := func(o *object) func(int, int) bool { - return func(i, j int) bool { - return Compare(o.keys[i].key, o.keys[j].key) < 0 - } - } - - obj := NewObject( - [2]*Term{StringTerm("d"), IntNumberTerm(4)}, - [2]*Term{StringTerm("b"), IntNumberTerm(2)}, - [2]*Term{StringTerm("a"), IntNumberTerm(1)}, - ) - o := obj.(*object) - act := sort.SliceIsSorted(o.keys, keysSorted(o)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range o.keys { - t.Logf("elem[%d]: %v", i, o.keys[i].key) - } - } - - obj.Insert(StringTerm("c"), IntNumberTerm(3)) - act = sort.SliceIsSorted(o.keys, keysSorted(o)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range o.keys { - t.Logf("elem[%d]: %v", i, o.keys[i].key) - } - } -} - -func TestSetInsertKeepsKeysSorting(t *testing.T) { - keysSorted := func(s *set) func(int, int) bool { - return func(i, j int) bool { - return Compare(s.keys[i], s.keys[j]) < 0 - } - } - - s0 := NewSet( - StringTerm("d"), - StringTerm("b"), - StringTerm("a"), - ) - s := s0.(*set) - act := sort.SliceIsSorted(s.keys, keysSorted(s)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range s.keys { - t.Logf("elem[%d]: %v", i, s.keys[i]) - } - } - - s0.Add(StringTerm("c")) - act = sort.SliceIsSorted(s.keys, keysSorted(s)) - if exp := true; act != exp { - 
t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range s.keys { - t.Logf("elem[%d]: %v", i, s.keys[i]) - } - } -} - func TestTermBadJSON(t *testing.T) { - input := `{ "Value": [[ {"Value": [{"Value": "a", "Type": "var"}, {"Value": "x", "Type": "string"}], "Type": "ref"}, @@ -344,7 +280,6 @@ func TestTermBadJSON(t *testing.T) { if !reflect.DeepEqual(expected, err) { t.Errorf("Expected %v but got: %v", expected, err) } - } func TestTermEqual(t *testing.T) { @@ -382,7 +317,6 @@ func TestTermEqual(t *testing.T) { } func TestFind(t *testing.T) { - term := MustParseTerm(`{"foo": [1,{"bar": {2,3,4}}], "baz": {"qux": ["hello", "world"]}}`) tests := []struct { @@ -418,8 +352,7 @@ func TestFind(t *testing.T) { } } -func TestHash(t *testing.T) { - +func TestHashObject(t *testing.T) { doc := `{"a": [[true, {"b": [null]}, {"c": "d"}]], "e": {100: a[i].b}, "k": ["foo" | true], "o": {"foo": "bar" | true}, "sc": {"foo" | true}, "s": {1, 2, {3, 4}}, "big": 1e+1000}` stmt1 := MustParseStatement(doc) @@ -431,10 +364,78 @@ func TestHash(t *testing.T) { if obj1.Hash() != obj2.Hash() { t.Errorf("Expected hash codes to be equal") } + + // Calculate hash like we did before moving the caching to create/update: + obj := obj1.(*object) + exp := 0 + for h, curr := range obj.elems { + for ; curr != nil; curr = curr.next { + exp += h + exp += curr.value.Hash() + } + } + + if act := obj1.Hash(); exp != act { + t.Errorf("expected %v, got %v", exp, act) + } } -func TestTermIsGround(t *testing.T) { +func TestHashArray(t *testing.T) { + doc := `[{"a": [[true, {"b": [null]}, {"c": "d"}]]}, 100, true, [a[i].b], {100: a[i].b}, ["foo" | true], {"foo": "bar" | true}, {"foo" | true}, {1, 2, {3, 4}}, 1e+1000]` + + stmt1 := MustParseStatement(doc) + stmt2 := MustParseStatement(doc) + + arr1 := stmt1.(Body)[0].Terms.(*Term).Value.(*Array) + arr2 := stmt2.(Body)[0].Terms.(*Term).Value.(*Array) + + if arr1.Hash() != arr2.Hash() { + t.Errorf("Expected hash codes to be equal") + } + + // Calculate hash like we did before moving the caching to create/update: + exp := termSliceHash(arr1.elems) + if act := arr1.Hash(); exp != act { + t.Errorf("expected %v, got %v", exp, act) + } + + for j := 0; j < arr1.Len(); j++ { + for i := 0; i <= j; i++ { + slice := arr1.Slice(i, j) + exp := termSliceHash(slice.elems) + if act := slice.Hash(); exp != act { + t.Errorf("arr1[%d:%d]: expected %v, got %v", i, j, exp, act) + } + } + } +} + +func TestHashSet(t *testing.T) { + doc := `{{"a": [[true, {"b": [null]}, {"c": "d"}]]}, 100, 100, 100, true, [a[i].b], {100: a[i].b}, ["foo" | true], {"foo": "bar" | true}, {"foo" | true}, {1, 2, {3, 4}}, 1e+1000}` + + stmt1 := MustParseStatement(doc) + stmt2 := MustParseStatement(doc) + + set1 := stmt1.(Body)[0].Terms.(*Term).Value.(Set) + set2 := stmt2.(Body)[0].Terms.(*Term).Value.(Set) + + if set1.Hash() != set2.Hash() { + t.Errorf("Expected hash codes to be equal") + } + + // Calculate hash like we did before moving the caching to create/update: + exp := 0 + set1.Foreach(func(x *Term) { + exp += x.Hash() + }) + + if act := set1.Hash(); exp != act { + t.Errorf("expected %v, got %v", exp, act) + } +} + +func TestTermIsGround(t *testing.T) { tests := []struct { note string term string @@ -468,7 +469,6 @@ func TestTermIsGround(t *testing.T) { t.Errorf("Expected term %v to be %s (test case %d: %v)", term, expected, i, tc.note) } } - } func TestObjectRemainsGround(t *testing.T) { @@ -512,7 +512,6 @@ func TestIsConstant(t *testing.T) { } func TestIsScalar(t *testing.T) { - tests := []struct { term 
string expected bool @@ -566,7 +565,6 @@ func TestTermString(t *testing.T) { } func TestRefHasPrefix(t *testing.T) { - a := MustParseRef("foo.bar.baz") b := MustParseRef("foo.bar") c := MustParseRef("foo.bar[0][x]") @@ -706,7 +704,6 @@ func TestRefPtr(t *testing.T) { if _, err := ref.Ptr(); err == nil { t.Fatal("Expected error from x[1]") } - } func TestSetEqual(t *testing.T) { @@ -738,7 +735,6 @@ func TestSetEqual(t *testing.T) { } func TestSetMap(t *testing.T) { - set := MustParseTerm(`{"foo", "bar", "baz", "qux"}`).Value.(Set) result, err := set.Map(func(term *Term) (*Term, error) { @@ -748,7 +744,6 @@ func TestSetMap(t *testing.T) { } return term, nil }) - if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -766,7 +761,6 @@ func TestSetMap(t *testing.T) { if !reflect.DeepEqual(err, fmt.Errorf("oops")) { t.Fatalf("Expected oops to be returned but got: %v, %v", result, err) } - } func TestSetAddContainsLen(t *testing.T) { @@ -805,7 +799,6 @@ func TestSetAddContainsLen(t *testing.T) { } func TestSetOperations(t *testing.T) { - tests := []struct { a string b string @@ -861,8 +854,95 @@ func TestSetCopy(t *testing.T) { } } -func TestArrayOperations(t *testing.T) { +// Constructs a set, and then has several reader goroutines attempt to +// concurrently iterate across it. This should pretty consistently +// hit a race condition around sorting the underlying key slice if +// the sorting isn't guarded properly. +func TestSetConcurrentReads(t *testing.T) { + // Create array of numbers. + numbers := make([]*Term, 10000) + for i := 0; i < 10000; i++ { + numbers[i] = IntNumberTerm(i) + } + // Shuffle numbers array for random insertion order. + rand.New(rand.NewSource(10000)) // Seed the PRNG. + rand.Shuffle(len(numbers), func(i, j int) { + numbers[i], numbers[j] = numbers[j], numbers[i] + }) + // Build set with numbers in unsorted order. + s := NewSet() + for i := 0; i < len(numbers); i++ { + s.Add(numbers[i]) + } + // In-place sort on numbers. + sort.Sort(termSlice(numbers)) + + // Check if race condition on key sorting is present. + var wg sync.WaitGroup + num := runtime.NumCPU() + wg.Add(num) + for i := 0; i < num; i++ { + go func() { + defer wg.Done() + var retrieved []*Term + s.Foreach(func(v *Term) { + retrieved = append(retrieved, v) + }) + // Check for sortedness of retrieved results. + // This will hit a race condition around `s.sortedKeys`. + for n := 0; n < len(retrieved); n++ { + if retrieved[n] != numbers[n] { + t.Errorf("Expected: %v at iteration %d but got %v instead", numbers[n], n, retrieved[n]) + } + } + }() + } + wg.Wait() +} + +func TestObjectConcurrentReads(t *testing.T) { + // Create array of numbers. + numbers := make([]*Term, 10000) + for i := 0; i < 10000; i++ { + numbers[i] = IntNumberTerm(i) + } + // Shuffle numbers array for random insertion order. + rand.New(rand.NewSource(10000)) // Seed the PRNG. + rand.Shuffle(len(numbers), func(i, j int) { + numbers[i], numbers[j] = numbers[j], numbers[i] + }) + // Build an object with numbers in unsorted order. + o := NewObject() + for i := 0; i < len(numbers); i++ { + o.Insert(numbers[i], NullTerm()) + } + // In-place sort on numbers. + sort.Sort(termSlice(numbers)) + + // Check if race condition on key sorting is present. + var wg sync.WaitGroup + num := runtime.NumCPU() + wg.Add(num) + for i := 0; i < num; i++ { + go func() { + defer wg.Done() + var retrieved []*Term + o.Foreach(func(k, v *Term) { + retrieved = append(retrieved, k) + }) + // Check for sortedness of retrieved results. 
+ // This will hit a race condition around `s.sortedKeys`. + for n := 0; n < len(retrieved); n++ { + if retrieved[n] != numbers[n] { + t.Errorf("Expected: %v at iteration %d but got %v instead", numbers[n], n, retrieved[n]) + } + } + }() + } + wg.Wait() +} +func TestArrayOperations(t *testing.T) { arr := MustParseTerm(`[1,2,3,4]`).Value.(*Array) getTests := []struct { @@ -987,7 +1067,6 @@ func TestArrayOperations(t *testing.T) { } func TestValueToInterface(t *testing.T) { - // Happy path term := MustParseTerm(`{ "foo": [1, "two", true, null, {3, @@ -1060,6 +1139,53 @@ func TestValueToInterface(t *testing.T) { } } +// NOTE(sr): Without the opt-out, we don't allocate another object for +// the conversion back to interface{} if it can be avoided. As a result, +// the value held by the store could be changed. +func TestJSONWithOptLazyObjDefault(t *testing.T) { + // would live in the store + m := map[string]interface{}{ + "foo": "bar", + } + o := LazyObject(m) + + n, err := JSONWithOpt(o, JSONOpt{}) + if err != nil { + t.Fatal(err) + } + n0, ok := n.(map[string]interface{}) + if !ok { + t.Fatalf("expected %T, got %T: %[2]v", n0, n) + } + n0["baz"] = true + + if v, ok := m["baz"]; !ok || !v.(bool) { + t.Errorf("expected change in m, found none: %v", m) + } +} + +func TestJSONWithOptLazyObjOptOut(t *testing.T) { + // would live in the store + m := map[string]interface{}{ + "foo": "bar", + } + o := LazyObject(m) + + n, err := JSONWithOpt(o, JSONOpt{CopyMaps: true}) + if err != nil { + t.Fatal(err) + } + n0, ok := n.(map[string]interface{}) + if !ok { + t.Fatalf("expected %T, got %T: %[2]v", n0, n) + } + n0["baz"] = true + + if _, ok := m["baz"]; ok { + t.Errorf("expected no change in m, found one: %v", m) + } +} + func assertTermEqual(t *testing.T, x *Term, y *Term) { if !x.Equal(y) { t.Errorf("Failure on equality: \n%s and \n%s\n", x, y) @@ -1078,3 +1204,167 @@ func assertToString(t *testing.T, val Value, expected string) { t.Errorf("Expected %v but got %v", expected, result) } } + +func TestLazyObjectGet(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": map[string]interface{}{ + "b": map[string]interface{}{ + "c": true, + }, + }, + }) + y := x.Get(StringTerm("a")) + _, ok := y.Value.(*lazyObj) + if !ok { + t.Errorf("expected Get() to return another lazy object, got %v %[1]T", y.Value) + } + assertForced(t, x, false) +} + +func TestLazyObjectFind(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": map[string]interface{}{ + "b": map[string]interface{}{ + "c": true, + }, + "d": []interface{}{true, true, true}, + }, + }) + // retrieve object via Find + y, err := x.Find(Ref{StringTerm("a"), StringTerm("b")}) + if err != nil { + t.Fatal(err) + } + _, ok := y.(*lazyObj) + if !ok { + t.Errorf("expected Find() to return another lazy object, got %v %[1]T", y) + } + assertForced(t, x, false) + + // retrieve array via Find + z, err := x.Find(Ref{StringTerm("a"), StringTerm("d")}) + if err != nil { + t.Fatal(err) + } + _, ok = z.(*Array) + if !ok { + t.Errorf("expected Find() to return array, got %v %[1]T", z) + } +} + +func TestLazyObjectCopy(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": map[string]interface{}{ + "b": map[string]interface{}{ + "c": true, + }, + }, + }) + y := x.Copy() + _, ok := y.(*lazyObj) + if !ok { + t.Errorf("expected Get() to return another lazy object, got %v %[1]T", y) + } + assertForced(t, x, false) +} + +func TestLazyObjectLen(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": map[string]interface{}{ + "b": 
map[string]interface{}{ + "c": true, + }, + }, + }) + if exp, act := 1, x.Len(); exp != act { + t.Errorf("expected Len() %v, got %v", exp, act) + } + assertForced(t, x, false) +} + +func TestLazyObjectIsGround(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": map[string]interface{}{ + "b": map[string]interface{}{ + "c": true, + }, + }, + }) + if exp, act := true, x.IsGround(); exp != act { + t.Errorf("expected IsGround() %v, got %v", exp, act) + } + assertForced(t, x, false) +} + +func TestLazyObjectInsert(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": "b", + }) + x.Insert(StringTerm("c"), StringTerm("d")) + assertForced(t, x, true) + + // NOTE(sr): We compare after asserting that it was forced, since comparison + // forces the lazy object, too. + if act, exp := x, NewObject(Item(StringTerm("a"), StringTerm("b")), Item(StringTerm("c"), StringTerm("d"))); exp.Compare(act) != 0 { + t.Errorf("expected %v to be equal to %v", act, exp) + } +} + +func TestLazyObjectKeys(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": "A", + "c": "C", + "b": "B", + }) + act := x.Keys() + exp := []*Term{StringTerm("a"), StringTerm("b"), StringTerm("c")} + if !reflect.DeepEqual(exp, act) { + t.Errorf("expected Keys() %v, got %v", exp, act) + } + assertForced(t, x, false) +} + +func TestLazyObjectKeysIterator(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": "A", + "c": "C", + "b": "B", + }) + ki := x.KeysIterator() + act := make([]*Term, 0, x.Len()) + for k, next := ki.Next(); next; k, next = ki.Next() { + act = append(act, k) + } + exp := []*Term{StringTerm("a"), StringTerm("b"), StringTerm("c")} + if !reflect.DeepEqual(exp, act) { + t.Errorf("expected Keys() %v, got %v", exp, act) + } + assertForced(t, x, false) +} + +func TestLazyObjectCompare(t *testing.T) { + x := LazyObject(map[string]interface{}{ + "a": map[string]interface{}{ + "b": map[string]interface{}{ + "c": true, + }, + }, + }) + if exp, act := 1, x.Compare(NewObject()); exp != act { + t.Errorf("expected Compare() => %v, got %v", exp, act) + } + assertForced(t, x, true) +} + +func assertForced(t *testing.T, x Object, forced bool) { + t.Helper() + l, ok := x.(*lazyObj) + switch { + case !ok: + t.Errorf("expected lazy object, got %v %[1]T", x) + case !forced && l.strict != nil: + t.Errorf("expected %v to not be forced", l) + case forced && l.strict == nil: + t.Errorf("expected %v to be forced", l) + } +} diff --git a/ast/transform.go b/ast/transform.go index c7fa4c8f1d..391a164860 100644 --- a/ast/transform.go +++ b/ast/transform.go @@ -13,7 +13,7 @@ import ( // be set to nil and no transformations will be applied to children of the // element. 
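Pulling the lazy-object pieces together, the tests above pin down an aliasing contract: the default JSON conversion may hand back the wrapped map itself, while JSONOpt.CopyMaps forces the lazy object and converts a copy. A minimal sketch of that contract, using LazyObject and JSONWithOpt from this patch; the map contents are made up for illustration:

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	// A store-shaped native map, wrapped without eager conversion.
	m := map[string]interface{}{"foo": "bar"}
	obj := ast.LazyObject(m)

	// Default conversion may return the wrapped map itself, so writes
	// to the result are visible through m.
	out, err := ast.JSONWithOpt(obj, ast.JSONOpt{})
	if err != nil {
		panic(err)
	}
	out.(map[string]interface{})["baz"] = true
	fmt.Println(m["baz"]) // true

	// CopyMaps forces the lazy object and converts a copy, leaving
	// the original map untouched.
	m2 := map[string]interface{}{"foo": "bar"}
	out2, err := ast.JSONWithOpt(ast.LazyObject(m2), ast.JSONOpt{CopyMaps: true})
	if err != nil {
		panic(err)
	}
	out2.(map[string]interface{})["baz"] = true
	fmt.Println(m2["baz"]) // <nil>
}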
type Transformer interface { - Transform(v interface{}) (interface{}, error) + Transform(interface{}) (interface{}, error) } // Transform iterates the AST and calls the Transform function on the @@ -116,6 +116,9 @@ func Transform(t Transformer, x interface{}) (interface{}, error) { } return y, nil case *Head: + if y.Reference, err = transformRef(t, y.Reference); err != nil { + return nil, err + } if y.Name, err = transformVar(t, y.Name); err != nil { return nil, err } @@ -172,6 +175,26 @@ func Transform(t Transformer, x interface{}) (interface{}, error) { if y.Terms, err = transformTerm(t, ts); err != nil { return nil, err } + case *Every: + if ts.Key != nil { + ts.Key, err = transformTerm(t, ts.Key) + if err != nil { + return nil, err + } + } + ts.Value, err = transformTerm(t, ts.Value) + if err != nil { + return nil, err + } + ts.Domain, err = transformTerm(t, ts.Domain) + if err != nil { + return nil, err + } + ts.Body, err = transformBody(t, ts.Body) + if err != nil { + return nil, err + } + y.Terms = ts } for i, w := range y.With { w, err := Transform(t, w) @@ -307,7 +330,7 @@ func TransformComprehensions(x interface{}, f func(interface{}) (Value, error)) // GenericTransformer implements the Transformer interface to provide a utility // to transform AST nodes using a closure. type GenericTransformer struct { - f func(x interface{}) (interface{}, error) + f func(interface{}) (interface{}, error) } // NewGenericTransformer returns a new GenericTransformer that will transform @@ -394,3 +417,15 @@ func transformVar(t Transformer, v Var) (Var, error) { } return r, nil } + +func transformRef(t Transformer, r Ref) (Ref, error) { + r1, err := Transform(t, r) + if err != nil { + return nil, err + } + r2, ok := r1.(Ref) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", r, r2) + } + return r2, nil +} diff --git a/ast/transform_test.go b/ast/transform_test.go index 9ae8a87f4c..8c5a051aae 100644 --- a/ast/transform_test.go +++ b/ast/transform_test.go @@ -13,6 +13,7 @@ func TestTransform(t *testing.T) { import input.foo import data.bar.this as qux +import future.keywords.every p = true { "this" = "that" } p = "this" { false } @@ -22,6 +23,8 @@ p = true { ["this" | "this"] } p = n { count({"this", "that"}, n) with input.foo.this as {"this": true} } p { false } else = "this" { "this" } else = ["this"] { true } foo(x) = y { split(x, "this", y) } +p { every x in ["this"] { x == "this" } } +a.b.c.this["this"] = d { d := "this" } `) result, err := Transform(&GenericTransformer{ @@ -46,6 +49,7 @@ foo(x) = y { split(x, "this", y) } import input.foo import data.bar.that as qux +import future.keywords.every p = true { "that" = "that" } p = "that" { false } @@ -55,6 +59,8 @@ p = true { ["that" | "that"] } p = n { count({"that"}, n) with input.foo.that as {"that": true} } p { false } else = "that" { "that" } else = ["that"] { true } foo(x) = y { split(x, "that", y) } +p { every x in ["that"] { x == "that" } } +a.b.c.that["that"] = d { d := "that" } `) if !expected.Equal(resultMod) { @@ -109,3 +115,23 @@ p := 7`, ParserOptions{ProcessAnnotation: true}) } } + +func TestTransformRefsAndRuleHeads(t *testing.T) { + module := MustParseModule(`package test +p.q.this.fo[x] = y { x := "x"; y := "y" }`) + + result, err := TransformRefs(module, func(r Ref) (Value, error) { + if r[0].Value.Compare(Var("p")) == 0 { + r[2] = StringTerm("that") + } + return r, nil + }) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + resultMod := result.(*Module) + if exp, act := 
MustParseRef("p.q.that.fo[x]"), resultMod.Rules[0].Head.Reference; !act.Equal(exp) { + t.Errorf("expected %v, got %v", exp, act) + } +} diff --git a/ast/varset.go b/ast/varset.go index 16dc3f5849..14f531494b 100644 --- a/ast/varset.go +++ b/ast/varset.go @@ -91,7 +91,7 @@ func (s VarSet) Update(vs VarSet) { } func (s VarSet) String() string { - tmp := []string{} + tmp := make([]string, 0, len(s)) for v := range s { tmp = append(tmp, string(v)) } diff --git a/ast/visit.go b/ast/visit.go index 139c4de3f5..7b0d3b08e7 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -7,13 +7,15 @@ package ast // Visitor defines the interface for iterating AST elements. The Visit function // can return a Visitor w which will be used to visit the children of the AST // element v. If the Visit function returns nil, the children will not be -// visited. This is deprecated. +// visited. +// Deprecated: use GenericVisitor or another visitor implementation type Visitor interface { Visit(v interface{}) (w Visitor) } // BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before -// and after the AST has been visited. This is deprecated. +// and after the AST has been visited. +// Deprecated: use GenericVisitor or another visitor implementation type BeforeAndAfterVisitor interface { Visitor Before(x interface{}) @@ -21,7 +23,8 @@ type BeforeAndAfterVisitor interface { } // Walk iterates the AST by calling the Visit function on the Visitor -// v for x before recursing. This is deprecated. +// v for x before recursing. +// Deprecated: use GenericVisitor.Walk func Walk(v Visitor, x interface{}) { if bav, ok := v.(BeforeAndAfterVisitor); !ok { walk(v, x) @@ -33,7 +36,8 @@ func Walk(v Visitor, x interface{}) { } // WalkBeforeAndAfter iterates the AST by calling the Visit function on the -// Visitor v for x before recursing. This is deprecated. +// Visitor v for x before recursing. 
+// Deprecated: use GenericVisitor.Walk func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x interface{}) { Walk(v, x) } @@ -46,17 +50,17 @@ func walk(v Visitor, x interface{}) { switch x := x.(type) { case *Module: Walk(w, x.Package) - for _, i := range x.Imports { - Walk(w, i) + for i := range x.Imports { + Walk(w, x.Imports[i]) } - for _, r := range x.Rules { - Walk(w, r) + for i := range x.Rules { + Walk(w, x.Rules[i]) } - for _, a := range x.Annotations { - Walk(w, a) + for i := range x.Annotations { + Walk(w, x.Annotations[i]) } - for _, c := range x.Comments { - Walk(w, c) + for i := range x.Comments { + Walk(w, x.Comments[i]) } case *Package: Walk(w, x.Path) @@ -79,23 +83,21 @@ func walk(v Visitor, x interface{}) { Walk(w, x.Value) } case Body: - for _, e := range x { - Walk(w, e) + for i := range x { + Walk(w, x[i]) } case Args: - for _, t := range x { - Walk(w, t) + for i := range x { + Walk(w, x[i]) } case *Expr: switch ts := x.Terms.(type) { - case *SomeDecl: + case *Term, *SomeDecl, *Every: Walk(w, ts) case []*Term: - for _, t := range ts { - Walk(w, t) + for i := range ts { + Walk(w, ts[i]) } - case *Term: - Walk(w, ts) } for i := range x.With { Walk(w, x.With[i]) @@ -106,8 +108,8 @@ func walk(v Visitor, x interface{}) { case *Term: Walk(w, x.Value) case Ref: - for _, t := range x { - Walk(w, t) + for i := range x { + Walk(w, x[i]) } case *object: x.Foreach(func(k, vv *Term) { @@ -133,8 +135,19 @@ func walk(v Visitor, x interface{}) { Walk(w, x.Term) Walk(w, x.Body) case Call: - for _, t := range x { - Walk(w, t) + for i := range x { + Walk(w, x[i]) + } + case *Every: + if x.Key != nil { + Walk(w, x.Key) + } + Walk(w, x.Value) + Walk(w, x.Domain) + Walk(w, x.Body) + case *SomeDecl: + for i := range x.Symbols { + Walk(w, x.Symbols[i]) } } } @@ -155,8 +168,8 @@ func WalkVars(x interface{}, f func(Var) bool) { // returns true, AST nodes under the last node will not be visited. 
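With *Every treated as a closure below, helpers built on WalkClosures, such as the ContainsClosures function added in term.go, pick up `every` expressions as well as comprehensions. A sketch of that behavior; the ParserOptions/FutureKeywords fields and ParseBodyWithOpts are assumptions about this OPA version's parser API, not part of this patch:

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	// Parse a body that uses the `every` future keyword.
	opts := ast.ParserOptions{FutureKeywords: []string{"every"}}
	body, err := ast.ParseBodyWithOpts(`every x in [1, 2, 3] { x > 0 }`, opts)
	if err != nil {
		panic(err)
	}

	// *Every now counts as a closure, like the comprehensions.
	fmt.Println(ast.ContainsClosures(body))                       // true
	fmt.Println(ast.ContainsClosures(ast.MustParseBody(`x > 0`))) // false
}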
func WalkClosures(x interface{}, f func(interface{}) bool) { vis := &GenericVisitor{func(x interface{}) bool { - switch x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + switch x := x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: return f(x) } return false @@ -277,17 +290,17 @@ func (vis *GenericVisitor) Walk(x interface{}) { switch x := x.(type) { case *Module: vis.Walk(x.Package) - for _, i := range x.Imports { - vis.Walk(i) + for i := range x.Imports { + vis.Walk(x.Imports[i]) } - for _, r := range x.Rules { - vis.Walk(r) + for i := range x.Rules { + vis.Walk(x.Rules[i]) } - for _, a := range x.Annotations { - vis.Walk(a) + for i := range x.Annotations { + vis.Walk(x.Annotations[i]) } - for _, c := range x.Comments { - vis.Walk(c) + for i := range x.Comments { + vis.Walk(x.Comments[i]) } case *Package: vis.Walk(x.Path) @@ -310,23 +323,21 @@ func (vis *GenericVisitor) Walk(x interface{}) { vis.Walk(x.Value) } case Body: - for _, e := range x { - vis.Walk(e) + for i := range x { + vis.Walk(x[i]) } case Args: - for _, t := range x { - vis.Walk(t) + for i := range x { + vis.Walk(x[i]) } case *Expr: switch ts := x.Terms.(type) { - case *SomeDecl: + case *Term, *SomeDecl, *Every: vis.Walk(ts) case []*Term: - for _, t := range ts { - vis.Walk(t) + for i := range ts { + vis.Walk(ts[i]) } - case *Term: - vis.Walk(ts) } for i := range x.With { vis.Walk(x.With[i]) @@ -337,21 +348,27 @@ func (vis *GenericVisitor) Walk(x interface{}) { case *Term: vis.Walk(x.Value) case Ref: - for _, t := range x { - vis.Walk(t) + for i := range x { + vis.Walk(x[i]) } case *object: x.Foreach(func(k, v *Term) { vis.Walk(k) vis.Walk(x.Get(k)) }) + case Object: + x.Foreach(func(k, v *Term) { + vis.Walk(k) + vis.Walk(x.Get(k)) + }) case *Array: x.Foreach(func(t *Term) { vis.Walk(t) }) case Set: - for _, t := range x.Slice() { - vis.Walk(t) + xSlice := x.Slice() + for i := range xSlice { + vis.Walk(xSlice[i]) } case *ArrayComprehension: vis.Walk(x.Term) @@ -364,8 +381,19 @@ func (vis *GenericVisitor) Walk(x interface{}) { vis.Walk(x.Term) vis.Walk(x.Body) case Call: - for _, t := range x { - vis.Walk(t) + for i := range x { + vis.Walk(x[i]) + } + case *Every: + if x.Key != nil { + vis.Walk(x.Key) + } + vis.Walk(x.Value) + vis.Walk(x.Domain) + vis.Walk(x.Body) + case *SomeDecl: + for i := range x.Symbols { + vis.Walk(x.Symbols[i]) } } } @@ -398,17 +426,17 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) { switch x := x.(type) { case *Module: vis.Walk(x.Package) - for _, i := range x.Imports { - vis.Walk(i) + for i := range x.Imports { + vis.Walk(x.Imports[i]) } - for _, r := range x.Rules { - vis.Walk(r) + for i := range x.Rules { + vis.Walk(x.Rules[i]) } - for _, a := range x.Annotations { - vis.Walk(a) + for i := range x.Annotations { + vis.Walk(x.Annotations[i]) } - for _, c := range x.Comments { - vis.Walk(c) + for i := range x.Comments { + vis.Walk(x.Comments[i]) } case *Package: vis.Walk(x.Path) @@ -422,32 +450,34 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) { vis.Walk(x.Else) } case *Head: - vis.Walk(x.Name) - vis.Walk(x.Args) - if x.Key != nil { - vis.Walk(x.Key) + if len(x.Reference) > 0 { + vis.Walk(x.Reference) + } else { + vis.Walk(x.Name) + if x.Key != nil { + vis.Walk(x.Key) + } } + vis.Walk(x.Args) if x.Value != nil { vis.Walk(x.Value) } case Body: - for _, e := range x { - vis.Walk(e) + for i := range x { + vis.Walk(x[i]) } case Args: - for _, t := range x { - vis.Walk(t) + for i := range x { + vis.Walk(x[i]) } case 
@@ -398,17 +426,17 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) {
     switch x := x.(type) {
     case *Module:
         vis.Walk(x.Package)
-        for _, i := range x.Imports {
-            vis.Walk(i)
+        for i := range x.Imports {
+            vis.Walk(x.Imports[i])
         }
-        for _, r := range x.Rules {
-            vis.Walk(r)
+        for i := range x.Rules {
+            vis.Walk(x.Rules[i])
         }
-        for _, a := range x.Annotations {
-            vis.Walk(a)
+        for i := range x.Annotations {
+            vis.Walk(x.Annotations[i])
         }
-        for _, c := range x.Comments {
-            vis.Walk(c)
+        for i := range x.Comments {
+            vis.Walk(x.Comments[i])
         }
     case *Package:
         vis.Walk(x.Path)
@@ -422,32 +450,34 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) {
             vis.Walk(x.Else)
         }
     case *Head:
-        vis.Walk(x.Name)
-        vis.Walk(x.Args)
-        if x.Key != nil {
-            vis.Walk(x.Key)
+        if len(x.Reference) > 0 {
+            vis.Walk(x.Reference)
+        } else {
+            vis.Walk(x.Name)
+            if x.Key != nil {
+                vis.Walk(x.Key)
+            }
         }
+        vis.Walk(x.Args)
         if x.Value != nil {
             vis.Walk(x.Value)
         }
     case Body:
-        for _, e := range x {
-            vis.Walk(e)
+        for i := range x {
+            vis.Walk(x[i])
         }
     case Args:
-        for _, t := range x {
-            vis.Walk(t)
+        for i := range x {
+            vis.Walk(x[i])
         }
     case *Expr:
         switch ts := x.Terms.(type) {
-        case *SomeDecl:
+        case *Term, *SomeDecl, *Every:
             vis.Walk(ts)
         case []*Term:
-            for _, t := range ts {
-                vis.Walk(t)
+            for i := range ts {
+                vis.Walk(ts[i])
             }
-        case *Term:
-            vis.Walk(ts)
         }
         for i := range x.With {
             vis.Walk(x.With[i])
@@ -458,21 +488,27 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) {
     case *Term:
         vis.Walk(x.Value)
     case Ref:
-        for _, t := range x {
-            vis.Walk(t)
+        for i := range x {
+            vis.Walk(x[i])
         }
     case *object:
         x.Foreach(func(k, v *Term) {
             vis.Walk(k)
             vis.Walk(x.Get(k))
         })
+    case Object:
+        x.Foreach(func(k, v *Term) {
+            vis.Walk(k)
+            vis.Walk(x.Get(k))
+        })
     case *Array:
         x.Foreach(func(t *Term) {
             vis.Walk(t)
         })
     case Set:
-        for _, t := range x.Slice() {
-            vis.Walk(t)
+        xSlice := x.Slice()
+        for i := range xSlice {
+            vis.Walk(xSlice[i])
         }
     case *ArrayComprehension:
         vis.Walk(x.Term)
@@ -485,8 +521,19 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) {
         vis.Walk(x.Term)
         vis.Walk(x.Body)
     case Call:
-        for _, t := range x {
-            vis.Walk(t)
+        for i := range x {
+            vis.Walk(x[i])
+        }
+    case *Every:
+        if x.Key != nil {
+            vis.Walk(x.Key)
+        }
+        vis.Walk(x.Value)
+        vis.Walk(x.Domain)
+        vis.Walk(x.Body)
+    case *SomeDecl:
+        for i := range x.Symbols {
+            vis.Walk(x.Symbols[i])
         }
     }
 }
@@ -527,6 +574,8 @@ func (vis *VarVisitor) Vars() VarSet {
     return vis.vars
 }
 
+// visit controls whether the VarVisitor recurses into x: if it returns `true`,
+// the subtree under x is _skipped_ and not visited further.
 func (vis *VarVisitor) visit(v interface{}) bool {
     if vis.params.SkipObjectKeys {
         if o, ok := v.(Object); ok {
@@ -538,16 +587,23 @@ func (vis *VarVisitor) visit(v interface{}) bool {
     }
     if vis.params.SkipRefHead {
         if r, ok := v.(Ref); ok {
-            for _, t := range r[1:] {
-                vis.Walk(t)
+            rSlice := r[1:]
+            for i := range rSlice {
+                vis.Walk(rSlice[i])
             }
             return true
         }
     }
     if vis.params.SkipClosures {
-        switch v.(type) {
+        switch v := v.(type) {
         case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
             return true
+        case *Expr:
+            if ev, ok := v.Terms.(*Every); ok {
+                vis.Walk(ev.Domain)
+                // We're _not_ walking ev.Body -- that's the closure here
+                return true
+            }
         }
     }
     if vis.params.SkipWithTarget {
@@ -565,14 +621,15 @@ func (vis *VarVisitor) visit(v interface{}) bool {
         switch v := v.(type) {
         case *Expr:
             if terms, ok := v.Terms.([]*Term); ok {
-                for _, t := range terms[0].Value.(Ref)[1:] {
-                    vis.Walk(t)
+                termSlice := terms[0].Value.(Ref)[1:]
+                for i := range termSlice {
+                    vis.Walk(termSlice[i])
                 }
                 for i := 1; i < len(terms); i++ {
                     vis.Walk(terms[i])
                 }
-                for _, w := range v.With {
-                    vis.Walk(w)
+                for i := range v.With {
+                    vis.Walk(v.With[i])
                 }
                 return true
             }
@@ -585,6 +642,22 @@ func (vis *VarVisitor) visit(v interface{}) bool {
                 vis.Walk(v[i])
             }
             return true
+        case *With:
+            if ref, ok := v.Target.Value.(Ref); ok {
+                refSlice := ref[1:]
+                for i := range refSlice {
+                    vis.Walk(refSlice[i])
+                }
+            }
+            if ref, ok := v.Value.Value.(Ref); ok {
+                refSlice := ref[1:]
+                for i := range refSlice {
+                    vis.Walk(refSlice[i])
+                }
+            } else {
+                vis.Walk(v.Value)
+            }
+            return true
         }
     }
     if v, ok := v.(Var); ok {
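The new doc comment spells out the inverted convention: `visit` returning true means the subtree is skipped, with any parts that should still be seen (the ref tail under SkipRefHead, an `every` expression's domain under SkipClosures) walked explicitly before returning. A small sketch of the skip parameters in use, based only on the API as exercised by the tests in this patch:

    package main

    import (
        "fmt"

        "github.com/open-policy-agent/opa/ast"
    )

    func main() {
        body := ast.MustParseBody(`foo = [x | data.a[i] = x]`)

        vis := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{SkipClosures: true})
        vis.Walk(body)

        // With SkipClosures, the comprehension body is treated as a closure,
        // so x and i are not collected; expect something like [foo, eq].
        fmt.Println(vis.Vars())
    }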
@@ -604,14 +677,14 @@ func (vis *VarVisitor) Walk(x interface{}) {
     switch x := x.(type) {
     case *Module:
         vis.Walk(x.Package)
-        for _, i := range x.Imports {
-            vis.Walk(i)
+        for i := range x.Imports {
+            vis.Walk(x.Imports[i])
         }
-        for _, r := range x.Rules {
-            vis.Walk(r)
+        for i := range x.Rules {
+            vis.Walk(x.Rules[i])
         }
-        for _, c := range x.Comments {
-            vis.Walk(c)
+        for i := range x.Comments {
+            vis.Walk(x.Comments[i])
         }
     case *Package:
         vis.Walk(x.Path)
@@ -625,32 +698,35 @@ func (vis *VarVisitor) Walk(x interface{}) {
             vis.Walk(x.Else)
         }
     case *Head:
-        vis.Walk(x.Name)
-        vis.Walk(x.Args)
-        if x.Key != nil {
-            vis.Walk(x.Key)
+        if len(x.Reference) > 0 {
+            vis.Walk(x.Reference)
+        } else {
+            vis.Walk(x.Name)
+            if x.Key != nil {
+                vis.Walk(x.Key)
+            }
         }
+        vis.Walk(x.Args)
+
         if x.Value != nil {
             vis.Walk(x.Value)
         }
     case Body:
-        for _, e := range x {
-            vis.Walk(e)
+        for i := range x {
+            vis.Walk(x[i])
         }
     case Args:
-        for _, t := range x {
-            vis.Walk(t)
+        for i := range x {
+            vis.Walk(x[i])
         }
     case *Expr:
         switch ts := x.Terms.(type) {
-        case *SomeDecl:
+        case *Term, *SomeDecl, *Every:
             vis.Walk(ts)
         case []*Term:
-            for _, t := range ts {
-                vis.Walk(t)
+            for i := range ts {
+                vis.Walk(ts[i])
             }
-        case *Term:
-            vis.Walk(ts)
         }
         for i := range x.With {
             vis.Walk(x.With[i])
@@ -661,8 +737,8 @@ func (vis *VarVisitor) Walk(x interface{}) {
     case *Term:
         vis.Walk(x.Value)
     case Ref:
-        for _, t := range x {
-            vis.Walk(t)
+        for i := range x {
+            vis.Walk(x[i])
         }
     case *object:
         x.Foreach(func(k, v *Term) {
@@ -674,8 +750,9 @@ func (vis *VarVisitor) Walk(x interface{}) {
             vis.Walk(t)
         })
     case Set:
-        for _, t := range x.Slice() {
-            vis.Walk(t)
+        xSlice := x.Slice()
+        for i := range xSlice {
+            vis.Walk(xSlice[i])
         }
     case *ArrayComprehension:
         vis.Walk(x.Term)
@@ -688,8 +765,19 @@ func (vis *VarVisitor) Walk(x interface{}) {
         vis.Walk(x.Term)
         vis.Walk(x.Body)
     case Call:
-        for _, t := range x {
-            vis.Walk(t)
+        for i := range x {
+            vis.Walk(x[i])
+        }
+    case *Every:
+        if x.Key != nil {
+            vis.Walk(x.Key)
+        }
+        vis.Walk(x.Value)
+        vis.Walk(x.Domain)
+        vis.Walk(x.Body)
+    case *SomeDecl:
+        for i := range x.Symbols {
+            vis.Walk(x.Symbols[i])
         }
     }
 }
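All three visitors now walk `*Every` nodes in the same fixed order (Key if present, then Value, Domain, Body) and walk the symbols of `*SomeDecl`. A hedged usage sketch, assuming an OPA version in which `every` is enabled via the future-keywords import:

    package main

    import (
        "fmt"

        "github.com/open-policy-agent/opa/ast"
    )

    func main() {
        module := ast.MustParseModule(`package demo
    import future.keywords.every

    p {
        every k, v in {"a": 1} { v > 0; k != "" }
    }
    `)

        // Collect all variables, including the every key/value that the
        // reworked visitors now reach.
        vars := ast.NewVarSet()
        ast.WalkVars(module, func(v ast.Var) bool {
            vars.Add(v)
            return false // false = keep walking below this node
        })
        fmt.Println(vars)
    }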
diff --git a/ast/visit_test.go b/ast/visit_test.go
index 1b3694508b..23d4f57e72 100644
--- a/ast/visit_test.go
+++ b/ast/visit_test.go
@@ -29,6 +29,7 @@ t[x] = y {
     y = [[x, z] | x = "x"; z = "z"]
     z = {"foo": [x, z] | x = "x"; z = "z"}
     s = {1 | a[i] = "foo"}
+    some x0, y0, z0
     count({1, 2, 3}, n) with input.foo.bar as x
 }
@@ -37,264 +38,10 @@ p { false } else { false } else { true }
 fn([x, y]) = z { json.unmarshal(x, z); z > y }
 `)
     vis := &testVis{}
     NewGenericVisitor(vis.Visit).Walk(rule)
-    /*
-        mod
-        package
-        data.a.b
-        term
-        data
-        term
-        a
-        term
-        b
-        import
-        term
-        input.x.y
-        term
-        input
-        term
-        x
-        term
-        y
-        z
-        rule
-        head
-        t
-        args
-        term
-        x
-        term
-        y
-        body
-        expr1
-        term
-        ref
-        term
-        =
-        term
-        ref1
-        term
-        p
-        term
-        x
-        term
-        object1
-        term
-        "foo"
-        term
-        array
-        term
-        y
-        term
-        2
-        term
-        object2
-        term
-        "bar"
-        term
-        3
-        expr2
-        term
-        ref2
-        term
-        q
-        term
-        x
-        expr3
-        term
-        ref
-        term
-        =
-        term
-        y
-        term
-        compr
-        term
-        array
-        term
-        x
-        term
-        z
-        body
-        expr4
-        term
-        ref
-        term
-        =
-        term
-        x
-        term
-        "x"
-        expr5
-        term
-        ref
-        term
-        =
-        term
-        z
-        term
-        "z"
-        expr4
-        term
-        ref
-        term
-        =
-        term
-        z
-        term
-        compr
-        key
-        term
-        "foo"
-        value
-        array
-        term
-        x
-        term
-        z
-        body
-        expr1
-        term
-        ref
-        term
-        =
-        term
-        x
-        term
-        "x"
-        expr2
-        term
-        ref
-        term
-        =
-        term
-        z
-        term
-        "z"
-        expr5
-        term
-        ref
-        term
-        =
-        term
-        s
-        term
-        compr
-        term
-        1
-        body
-        expr1
-        term
-        ref
-        term
-        =
-        term
-        ref
-        term
-        a
-        term
-        i
-
-        term
-        "foo"
-        expr6
-        term
-        ref
-        term
-        count
-        term
-        set
-        term
-        1
-        term
-        2
-        term
-        3
-        term
-        n
-        with
-        term
-        input.foo.bar
-        term
-        input
-        term
-        foo
-        term
-        bar
-        term
-        baz
-        rule
-        head
-        p
-        args
-        # not counted
-        term
-        true
-        body
-        expr
-        term
-        false
-        rule
-        head
-        p
-        args
-        # not counted
-        term
-        true
-        body
-        expr
-        term
-        false
-        rule
-        head
-        p
-        args
-        # not counted
-        term
-        true
-        body
-        expr
-        term
-        true
-        func
-        head
-        fn
-        args
-        term
-        array
-        term
-        x
-        term
-        y
-        term
-        z
-        body
-        expr1
-        term
-        ref
-        term
-        json
-        term
-        unmarshal
-        term
-        x
-        term
-        z
-        expr2
-        term
-        ref
-        term
-        >
-        term
-        z
-        term
-        y
-    */
 
-    if len(vis.elems) != 246 {
-        t.Errorf("Expected exactly 246 elements in AST but got %d: %v", len(vis.elems), vis.elems)
+    if exp, act := 254, len(vis.elems); exp != act {
+        t.Errorf("Expected exactly %d elements in AST but got %d: %v", exp, act, vis.elems)
     }
 }
@@ -344,6 +91,7 @@ t[x] = y {
     y = [[x, z] | x = "x"; z = "z"]
     z = {"foo": [x, z] | x = "x"; z = "z"}
     s = {1 | a[i] = "foo"}
+    some x0, y0, z0
     count({1, 2, 3}, n) with input.foo.bar as x
 }
@@ -359,8 +107,8 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y }
     })
     vis.Walk(rule)
 
-    if len(elems) != 246 {
-        t.Errorf("Expected exactly 246 elements in AST but got %d: %v", len(elems), elems)
+    if len(elems) != 254 {
+        t.Errorf("Expected exactly 254 elements in AST but got %d: %v", len(elems), elems)
     }
 }
@@ -375,6 +123,7 @@ t[x] = y {
     y = [[x, z] | x = "x"; z = "z"]
     z = {"foo": [x, z] | x = "x"; z = "z"}
     s = {1 | a[i] = "foo"}
+    some x0, y0, z0
     count({1, 2, 3}, n) with input.foo.bar as x
 }
@@ -393,12 +142,12 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y }
     })
     vis.Walk(rule)
 
-    if len(before) != 246 {
-        t.Errorf("Expected exactly 246 before elements in AST but got %d: %v", len(before), before)
+    if exp, act := 264, len(before); exp != act {
+        t.Errorf("Expected exactly %d before elements in AST but got %d: %v", exp, act, before)
     }
-    if len(after) != 246 {
-        t.Errorf("Expected exactly 246 after elements in AST but got %d: %v", len(after), after)
+    if exp, act := 264, len(after); exp != act {
+        t.Errorf("Expected exactly %d after elements in AST but got %d: %v", exp, act, after)
     }
 }
@@ -414,21 +163,66 @@ func TestVarVisitor(t *testing.T) {
         {"data.foo[x] = bar.baz[y]", VarVisitorParams{SkipRefHead: true}, "[x, y]"},
         {`foo = [x | data.a[i] = x]`, VarVisitorParams{SkipClosures: true}, "[foo, eq]"},
         {`x = 1; y = 2; z = x + y; count([x, y, z], z)`, VarVisitorParams{}, "[x, y, z, eq, plus, count]"},
+        {"some x, y", VarVisitorParams{}, "[x, y]"},
     }
     for _, tc := range tests {
-        stmt := MustParseStatement(tc.stmt)
+        t.Run(tc.stmt, func(t *testing.T) {
+            stmt := MustParseStatement(tc.stmt)
+
+            expected := NewVarSet()
+            MustParseTerm(tc.expected).Value.(*Array).Foreach(func(x *Term) {
+                expected.Add(x.Value.(Var))
+            })
 
-        expected := NewVarSet()
-        MustParseTerm(tc.expected).Value.(*Array).Foreach(func(x *Term) {
-            expected.Add(x.Value.(Var))
+            vis := NewVarVisitor().WithParams(tc.params)
+            vis.Walk(stmt)
+
+            if !vis.Vars().Equal(expected) {
+                t.Errorf("Params %#v expected %v but got: %v", tc.params, expected, vis.Vars())
+            }
         })
+    }
+}
+
+func TestGenericVisitorLazyObject(t *testing.T) {
+    o := LazyObject(map[string]interface{}{"foo": 3})
+    act := 0
+    WalkTerms(o, func(n *Term) bool {
+        switch n.Value {
+        case String("foo"):
+            act++
+        case Number("3"):
+            act++
+        }
 
-        vis := NewVarVisitor().WithParams(tc.params)
-        vis.Walk(stmt)
+        return false
+    })
+    if exp := 2; exp != act {
+        t.Errorf("expected %v, got %v", exp, act)
+    }
+}
 
-        if !vis.Vars().Equal(expected) {
-            t.Errorf("For %v w/ %v expected %v but got: %v", stmt, tc.params, expected, vis.Vars())
+func TestGenericBeforeAfterVisitorLazyObject(t *testing.T) {
+    o := LazyObject(map[string]interface{}{"foo": 3})
+    act := 0
+    vis := NewBeforeAfterVisitor(func(x interface{}) bool {
+        t, ok := x.(*Term)
+        if !ok {
+            return false
         }
+        switch t.Value {
+        case String("foo"):
+            act++
+        case Number("3"):
+            act++
+        }
+
+        return false
+    },
+        func(interface{}) {})
+    vis.Walk(o)
+    if exp := 2; exp != act {
+        t.Errorf("expected %v, got %v", exp, act)
     }
 }
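The two new tests exercise `LazyObject`, which wraps a native Go map; walking it forces the wrapped keys and values to be converted to terms, which is exactly what the tests assert. A minimal sketch against the same API, assuming the import path used elsewhere in this patch:

    package main

    import (
        "fmt"

        "github.com/open-policy-agent/opa/ast"
    )

    func main() {
        // LazyObject defers converting native Go values into AST terms;
        // a full walk visits (and therefore converts) every entry.
        o := ast.LazyObject(map[string]interface{}{"foo": 3})

        ast.WalkTerms(o, func(t *ast.Term) bool {
            fmt.Println("term:", t)
            return false // false = continue walking
        })
    }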
case String("foo"): + act++ + case Number("3"): + act++ + } + + return false + }, + func(interface{}) {}) + vis.Walk(o) + if exp := 2; exp != act { + t.Errorf("expected %v, got %v", exp, act) } } diff --git a/build/binary-smoke-test.sh b/build/binary-smoke-test.sh new file mode 100755 index 0000000000..f06c3b6afd --- /dev/null +++ b/build/binary-smoke-test.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +set -eo pipefail +OPA_EXEC="$1" +TARGET="$2" + +PATH_SEPARATOR="/" +BASE_PATH=$(pwd) +TEST_PATH="${BASE_PATH}/test/cli/smoke/namespace/data.json" +if [[ $OPA_EXEC == *".exe" ]]; then + PATH_SEPARATOR="\\" + BASE_PATH=$(pwd -W) + TEST_PATH="$(echo ${BASE_PATH}/test/cli/smoke/namespace/data.json | sed 's/^\///' | sed 's/\//\\\\/g')" + BASE_PATH=$(echo ${BASE_PATH} | sed 's/^\///' | sed 's/\//\\/g') +fi + +github_actions_group() { + local args="$*" + echo "::group::$args" + $args + echo "::endgroup::" +} + +opa() { + local args="$*" + github_actions_group $OPA_EXEC $args +} + +# assert_contains checks if the actual string contains the expected string. +assert_contains() { + local expected="$1" + local actual="$2" + if [[ "$actual" != *"$expected"* ]]; then + echo "Expected '$expected' but got '$actual'" + exit 1 + fi +} + +# assert_not_contains checks if the actual string does not contain the expected string. +assert_not_contains() { + local expected="$1" + local actual="$2" + if [[ "$actual" == *"$expected"* ]]; then + echo "Didn't expect '$expected' in '$actual'" + exit 0 + fi +} + +opa version +opa eval -t $TARGET 'time.now_ns()' +opa eval --format pretty --bundle test/cli/smoke/golden-bundle.tar.gz --input test/cli/smoke/input.json data.test.result --fail +opa exec --bundle test/cli/smoke/golden-bundle.tar.gz --decision test/result test/cli/smoke/input.json +opa build --output o0.tar.gz test/cli/smoke/data.yaml test/cli/smoke/test.rego +echo '{"yay": "bar"}' | opa eval --format pretty --bundle o0.tar.gz -I data.test.result --fail +opa build --optimize 1 --output o1.tar.gz test/cli/smoke/data.yaml test/cli/smoke/test.rego +echo '{"yay": "bar"}' | opa eval --format pretty --bundle o1.tar.gz -I data.test.result --fail +opa build --optimize 2 --output o2.tar.gz test/cli/smoke/data.yaml test/cli/smoke/test.rego +echo '{"yay": "bar"}' | opa eval --format pretty --bundle o2.tar.gz -I data.test.result --fail + +# Tar paths +opa build --output o3.tar.gz test/cli/smoke +github_actions_group assert_contains '/test/cli/smoke/test.rego' "$(tar -tf o3.tar.gz /test/cli/smoke/test.rego)" + +# Data files - correct namespaces +echo "::group:: Data files - correct namespaces" +assert_contains "data.namespace | test${PATH_SEPARATOR}cli${PATH_SEPARATOR}smoke${PATH_SEPARATOR}namespace${PATH_SEPARATOR}data.json" "$(opa inspect test/cli/smoke)" +echo "::endgroup::" + +# Data files - correct root path +echo "::group:: Data files - correct root path" +assert_contains "${TEST_PATH}" "$(opa inspect ${BASE_PATH}/test/cli/smoke -f json)" +assert_not_contains "\\\\${TEST_PATH}" "$(opa inspect ${BASE_PATH}/test/cli/smoke -f json)" +echo "::endgroup::" \ No newline at end of file diff --git a/build/changelog.py b/build/changelog.py index 08b1c9e73d..75195a0e45 100755 --- a/build/changelog.py +++ b/build/changelog.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 """ changelog.py helps generate the CHANGELOG.md message for a particular release. 
""" @@ -8,7 +8,7 @@ import subprocess import shlex import re -import urllib2 +import urllib.request, urllib.error, urllib.parse import sys import json @@ -31,35 +31,21 @@ def get_commit_message(commit_id): def fetch(url, token): - req = urllib2.Request(url) + req = urllib.request.Request(url) if token: req.add_header('Authorization', "token {}".format(token)) try: - rsp = urllib2.urlopen(req) + rsp = urllib.request.urlopen(req) result = json.loads(rsp.read()) except Exception as e: if hasattr(e, 'reason'): - print >> sys.stderr, 'Failed to fetch URL {}: {}'.format(url, e.reason) + print('Failed to fetch URL {}: {}'.format(url, e.reason), file=sys.stderr) elif hasattr(e, 'code'): - print >> sys.stderr, 'Failed to fetch URL {}: Code {}'.format(url, e.code) + print('Failed to fetch URL {}: Code {}'.format(url, e.code), file=sys.stderr) return {} else: return result - -def get_maintainers(): - with open("MAINTAINERS.md", "r") as f: - contents = f.read() - maintainers = re.findall(r"[^\s]+@[^\s]+", contents) - return maintainers - -maintainers = get_maintainers() - - -def is_maintainer(commit_message): - author = author_email(commit_message) - return author in maintainers - org_members_usernames = [] def get_org_members(token): url = "https://api.github.com/orgs/open-policy-agent/members?per_page=100" @@ -72,7 +58,6 @@ def get_org_members(token): if login: org_members_usernames.append(str(login)) if email: - maintainers.append(str(email)) github_ids[email]=login def author_email(commit_message): @@ -101,16 +86,14 @@ def get_github_id(commit_message, commit_id, token): def mention_author(commit_message, commit_id, token): username = get_github_id(commit_message, commit_id, token) - if username not in org_members_usernames: - return "authored by @[{author}](https://github.com/{author})".format(author=username) - return "" + return "authored by @{author}".format(author=username) def get_issue_reporter(issue_id, token): url = "https://api.github.com/repos/open-policy-agent/opa/issues/{issue_id}".format(issue_id=issue_id) issue_data = fetch(url, token) username = issue_data.get("user", "").get("login", "") if username not in org_members_usernames: - return "reported by @[{reporter}](https://github.com/{reporter})".format(reporter=username) + return "reported by @{reporter}".format(reporter=username) return "" def fixes_issue_id(commit_message): @@ -163,8 +146,7 @@ def main(): mention = "" reporter = "" commit_message = get_commit_message(commit_id) - if not is_maintainer(commit_message): - mention = mention_author(commit_message, commit_id, args.token) + mention = mention_author(commit_message, commit_id, args.token) issue_id = fixes_issue_id(commit_message) if issue_id: reporter = get_issue_reporter(issue_id, args.token) diff --git a/build/check-fuzz.sh b/build/check-fuzz.sh deleted file mode 100755 index f4fa5d2a5d..0000000000 --- a/build/check-fuzz.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env bash - -set -e - -OPA_DIR=$(dirname "${BASH_SOURCE}")/.. - -usage() { - echo "check-fuzz.sh