diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..1900eb4 --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,133 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, 
threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +conduct@buf.build. All complaints will be reviewed and investigated promptly +and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. 
This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +[https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available +at [https://www.contributor-covenant.org/translations][translations]. 
+ +[homepage]: https://www.contributor-covenant.org +[v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations + diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..5ace460 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..8d0800f --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,36 @@ +name: ci +on: + push: + branches: [main] + tags: ['v*'] + pull_request: + branches: [main] + schedule: + - cron: '15 22 * * *' + workflow_dispatch: {} # support manual runs +permissions: + contents: read +jobs: + ci: + runs-on: ubuntu-latest + strategy: + matrix: + go-version: [1.22.x, 1.23.x] + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + fetch-depth: 1 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + - name: Test + run: make test + - name: Lint + # Often, lint & gofmt guidelines depend on the Go version. To prevent + # conflicting guidance, run only on the most recent supported version. + # For the same reason, only check generated code on the most recent + # supported version. 
+ if: matrix.go-version == '1.23.x' + run: make checkgenerate && make lint diff --git a/.github/workflows/emergency-review-bypass.yaml b/.github/workflows/emergency-review-bypass.yaml new file mode 100644 index 0000000..3d0b436 --- /dev/null +++ b/.github/workflows/emergency-review-bypass.yaml @@ -0,0 +1,12 @@ +name: Bypass review in case of emergency +on: + pull_request: + types: + - labeled +permissions: + pull-requests: write +jobs: + approve: + if: github.event.label.name == 'Emergency Bypass Review' + uses: bufbuild/base-workflows/.github/workflows/emergency-review-bypass.yaml@main + secrets: inherit diff --git a/.github/workflows/notify-approval-bypass.yaml b/.github/workflows/notify-approval-bypass.yaml new file mode 100644 index 0000000..14af550 --- /dev/null +++ b/.github/workflows/notify-approval-bypass.yaml @@ -0,0 +1,13 @@ +name: PR Approval Bypass Notifier +on: + pull_request: + types: + - closed + branches: + - main +permissions: + pull-requests: read +jobs: + approval: + uses: bufbuild/base-workflows/.github/workflows/notify-approval-bypass.yaml@main + secrets: inherit diff --git a/.github/workflows/pr-title.yaml b/.github/workflows/pr-title.yaml new file mode 100644 index 0000000..b114603 --- /dev/null +++ b/.github/workflows/pr-title.yaml @@ -0,0 +1,18 @@ +name: Lint PR Title +# Prevent writing to the repository using the CI token. +# Ref: https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#permissions +permissions: + pull-requests: read +on: + pull_request: + # By default, a workflow only runs when a pull_request's activity type is opened, + # synchronize, or reopened. We explicitly override here so that PR titles are + # re-linted when the PR text content is edited. 
+ types: + - opened + - edited + - reopened + - synchronize +jobs: + lint: + uses: bufbuild/base-workflows/.github/workflows/pr-title.yaml@main diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..987d6a1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/.tmp/ +*.pprof +*.svg +cover.out diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..290ac4c --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,103 @@ +linters-settings: + errcheck: + check-type-assertions: true + forbidigo: + forbid: + - '^fmt\.Print' + - '^log\.' + - '^print$' + - '^println$' + - '^panic$' + godox: + # TODO, OPT, etc. comments are fine to commit. Use FIXME comments for + # temporary hacks, and use godox to prevent committing them. + keywords: [FIXME] + varnamelen: + ignore-decls: + - T any + - i int + - wg sync.WaitGroup + - id string +linters: + enable-all: true + disable: + - cyclop # covered by gocyclo + - depguard # unnecessary for small libraries + - err113 # way too noisy + - exhaustruct # many exceptions + - funlen # rely on code review to limit function length + - gochecknoglobals # many exceptions + - gocognit # dubious "cognitive overhead" quantification + - gofumpt # prefer standard gofmt + - goimports # rely on gci instead + - gomnd # some unnamed constants are okay + - inamedparam # not standard style + - interfacebloat # many exceptions + - ireturn # "accept interfaces, return structs" isn't ironclad + - lll # don't want hard limits for line length + - maintidx # covered by gocyclo + - nlreturn # generous whitespace violates house style + - testifylint # does not want us to use assert + - testpackage # internal tests are fine + - thelper # we want to print out the whole stack + - wrapcheck # don't _always_ need to wrap errors + - wsl # generous whitespace violates house style +issues: + exclude-dirs-use-default: false + exclude-rules: + - linters: + - revive + path: check/client.go + text: "CheckCallOption" + - linters: + - revive + path: 
check/check_service_handler.go + text: "CheckServiceHandlerOption" + - linters: + - exhaustive + path: check/options.go + text: "reflect.Pointer|reflect.Ptr" + - linters: + - gocritic + path: check/file.go + text: "commentFormatting" + - linters: + - gocritic + path: check/location.go + text: "commentFormatting" + - linters: + - nilnil + path: check/rule.go + - linters: + - nilnil + path: check/response_writer.go + - linters: + - unparam + path: check/category_spec.go + - linters: + - unparam + path: check/annotation.go + - linters: + - unparam + path: check/response.go + - linters: + - nilnil + path: check/checktest/checktest.go + - linters: + - varnamelen + path: check/internal/example + - linters: + - dupl + path: check/checkutil/breaking.go + - linters: + - varnamelen + path: check/checkutil/breaking.go + - linters: + - varnamelen + path: check/checkutil/lint.go + - linters: + - varnamelen + path: check/checkutil/util.go + - linters: + - varnamelen + path: internal/pkg/xslices/xslices.go diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1040748 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2024 Buf Technologies, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4161b78 --- /dev/null +++ b/Makefile @@ -0,0 +1,93 @@ +# See https://tech.davis-hansson.com/p/make/ +SHELL := bash +.DELETE_ON_ERROR: +.SHELLFLAGS := -eu -o pipefail -c +.DEFAULT_GOAL := all +MAKEFLAGS += --warn-undefined-variables +MAKEFLAGS += --no-builtin-rules +MAKEFLAGS += --no-print-directory +BIN := .tmp/bin +export PATH := $(abspath $(BIN)):$(PATH) +export GOBIN := $(abspath $(BIN)) +COPYRIGHT_YEARS := 2024 +LICENSE_IGNORE := --ignore testdata/ + +BUF_VERSION := v1.39.0 +GO_MOD_GOTOOLCHAIN := go1.23.0 +GOLANGCI_LINT_VERSION := v1.60.1 +# https://github.com/golangci/golangci-lint/issues/4837 +GOLANGCI_LINT_GOTOOLCHAIN := $(GO_MOD_GOTOOLCHAIN) +# Remove when we want to upgrade past Go 1.21 +GO_GET_PKGS := github.com/antlr4-go/antlr/v4@v4.13.0 + +.PHONY: help +help: ## Describe useful make targets + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "%-30s %s\n", $$1, $$2}' + +.PHONY: all +all: ## Build, test, and lint (default) + $(MAKE) test + $(MAKE) lint + +.PHONY: clean +clean: ## Delete intermediate build artifacts + @# -X only removes untracked files, -d recurses into directories, -f actually removes files/dirs + git clean -Xdf + +.PHONY: test +test: build ## Run unit tests + go test -vet=off -race -cover ./... 
+ +.PHONY: build +build: generate ## Build all packages + go build ./... + +.PHONY: install +install: ## Install all binaries + go install ./... + +.PHONY: lint +lint: $(BIN)/golangci-lint ## Lint + go vet ./... + GOTOOLCHAIN=$(GOLANGCI_LINT_GOTOOLCHAIN) golangci-lint run --modules-download-mode=readonly --timeout=3m0s + +.PHONY: lintfix +lintfix: $(BIN)/golangci-lint ## Automatically fix some lint errors + GOTOOLCHAIN=$(GOLANGCI_LINT_GOTOOLCHAIN) golangci-lint run --fix --modules-download-mode=readonly --timeout=3m0s + +.PHONY: generate +generate: $(BIN)/buf $(BIN)/protoc-gen-pluginrpc-go $(BIN)/license-header ## Regenerate code and licenses + buf generate + cd ./check/internal/example; buf generate + license-header \ + --license-type apache \ + --copyright-holder "Buf Technologies, Inc." \ + --year-range "$(COPYRIGHT_YEARS)" $(LICENSE_IGNORE) + +.PHONY: upgrade +upgrade: ## Upgrade dependencies + go mod edit -toolchain=$(GO_MOD_GOTOOLCHAIN) + go get -u -t ./... $(GO_GET_PKGS) + go mod tidy -v + +.PHONY: checkgenerate +checkgenerate: + @# Used in CI to verify that `make generate` doesn't produce a diff. 
+ test -z "$$(git status --porcelain | tee /dev/stderr)" + +$(BIN)/buf: Makefile + @mkdir -p $(@D) + go install github.com/bufbuild/buf/cmd/buf@$(BUF_VERSION) + +$(BIN)/license-header: Makefile + @mkdir -p $(@D) + go install github.com/bufbuild/buf/private/pkg/licenseheader/cmd/license-header@$(BUF_VERSION) + +$(BIN)/golangci-lint: Makefile + @mkdir -p $(@D) + GOTOOLCHAIN=$(GOLANGCI_LINT_GOTOOLCHAIN) go install github.com/golangci/golangci-lint/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION) + +.PHONY: $(BIN)/protoc-gen-pluginrpc-go +$(BIN)/protoc-gen-pluginrpc-go: + @mkdir -p $(@D) + go install pluginrpc.com/pluginrpc/cmd/protoc-gen-pluginrpc-go diff --git a/README.md b/README.md new file mode 100644 index 0000000..17d5247 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +bufplugin-go +============== + +[![Build](https://github.com/bufbuild/bufplugin-go/actions/workflows/ci.yaml/badge.svg?branch=main)](https://github.com/bufbuild/bufplugin-go/actions/workflows/ci.yaml) +[![Report Card](https://goreportcard.com/badge/github.com/bufbuild/bufplugin-go)](https://goreportcard.com/report/github.com/bufbuild/bufplugin-go) +[![GoDoc](https://pkg.go.dev/badge/github.com/bufbuild/bufplugin-go.svg)](https://pkg.go.dev/github.com/bufbuild/bufplugin-go) +[![Slack](https://img.shields.io/badge/slack-buf-%23e01563)](https://buf.build/links/slack) + +This is the Golang SDK for [bufplugin](https://github.com/bufbuild/bufplugin). + +This is very early, but see the [example](check/internal/example) for how this works in practice. + +## Status: Alpha + +Bufplugin is as early as it gets - [buf](https://github.com/bufbuild/buf) doesn't actually support +plugins yet! We're publishing this publicly to get early feedback as we approach stability. + +## Legal + +Offered under the [Apache 2 license](https://github.com/bufbuild/bufplugin-go/blob/main/LICENSE). 
diff --git a/buf.gen.yaml b/buf.gen.yaml new file mode 100644 index 0000000..8dfc51f --- /dev/null +++ b/buf.gen.yaml @@ -0,0 +1,15 @@ +version: v2 +inputs: + - module: buf.build/bufbuild/bufplugin +managed: + enabled: true + disable: + - file_option: go_package_prefix + module: buf.build/bufbuild/bufplugin + - file_option: go_package_prefix + module: buf.build/bufbuild/protovalidate +plugins: + - local: protoc-gen-pluginrpc-go + out: internal/gen + opt: paths=source_relative +clean: true diff --git a/check/annotation.go b/check/annotation.go new file mode 100644 index 0000000..a514965 --- /dev/null +++ b/check/annotation.go @@ -0,0 +1,122 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "errors" + "sort" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" +) + +// Annotation represents a rule Failure. +// +// An annotation always contains the ID of the Rule that failed. It also optionally +// contains a user-readable message, a location of the failure, and a location of the +// failure in the against Files. +// +// Annotations are created on the server-side via ResponseWriters, and returned +// from Clients on Responses. +type Annotation interface { + // RuleID is the ID of the Rule that failed. + // + // This will always be present. + RuleID() string + // Message is a user-readable message describing the failure. 
+ Message() string + // Location is the location of the failure. + Location() Location + // AgainstLocation is the Location of the failure in the against Files. + // + // Will only potentially be produced for breaking change rules. + AgainstLocation() Location + + toProto() *checkv1.Annotation + + isAnnotation() +} + +// *** PRIVATE *** + +type annotation struct { + ruleID string + message string + location Location + againstLocation Location +} + +func newAnnotation( + ruleID string, + message string, + location Location, + againstLocation Location, +) (*annotation, error) { + if ruleID == "" { + return nil, errors.New("check.Annotation: RuleID is empty") + } + return &annotation{ + ruleID: ruleID, + message: message, + location: location, + againstLocation: againstLocation, + }, nil +} + +func (a *annotation) RuleID() string { + return a.ruleID +} + +func (a *annotation) Message() string { + return a.message +} + +func (a *annotation) Location() Location { + return a.location +} + +func (a *annotation) AgainstLocation() Location { + return a.againstLocation +} + +func (a *annotation) toProto() *checkv1.Annotation { + if a == nil { + return nil + } + var protoLocation *checkv1.Location + if a.location != nil { + protoLocation = a.location.toProto() + } + var protoAgainstLocation *checkv1.Location + if a.againstLocation != nil { + protoAgainstLocation = a.againstLocation.toProto() + } + return &checkv1.Annotation{ + RuleId: a.RuleID(), + Message: a.Message(), + Location: protoLocation, + AgainstLocation: protoAgainstLocation, + } +} + +func (*annotation) isAnnotation() {} + +func sortAnnotations(annotations []Annotation) { + sort.Slice( + annotations, + func(i int, j int) bool { + return CompareAnnotations(annotations[i], annotations[j]) < 0 + }, + ) +} diff --git a/check/category.go b/check/category.go new file mode 100644 index 0000000..0583176 --- /dev/null +++ b/check/category.go @@ -0,0 +1,157 @@ +// Copyright 2024 Buf Technologies, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "errors" + "fmt" + "slices" + "sort" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" +) + +// Category is a rule category. +// +// Categories have unique IDs. On the server-side (i.e. the plugin), Categories are created +// by CategorySpecs. Clients can list all available plugin Categories by calling ListCategories. +type Category interface { + // ID is the ID of the Category. + // + // Always present. + // + // This uniquely identifies the Category. + ID() string + // A user-displayable purpose of the category. + // + // Always present. + Purpose() string + // Deprecated returns whether or not this Category is deprecated. + // + // If the Category is deprecated, it may be replaced by zero or more Categories. These will + // be denoted by ReplacementIDs. + Deprecated() bool + // ReplacementIDs returns the IDs of the Categories that replace this Category, if this Category is deprecated. + // + // This means that the combination of the Categories specified by ReplacementIDs replace this Category entirely, + // and this Category is considered equivalent to the AND of the categories specified by ReplacementIDs. + // + // This will only be non-empty if Deprecated is true. + // + // It is not valid for a deprecated Category to specify another deprecated Category as a replacement. 
+ ReplacementIDs() []string + + toProto() *checkv1.Category + + isCategory() +} + +// *** PRIVATE *** + +type category struct { + id string + purpose string + deprecated bool + replacementIDs []string +} + +func newCategory( + id string, + purpose string, + deprecated bool, + replacementIDs []string, +) (*category, error) { + if id == "" { + return nil, errors.New("check.Category: ID is empty") + } + if purpose == "" { + return nil, errors.New("check.Category: Purpose is empty") + } + if !deprecated && len(replacementIDs) > 0 { + return nil, fmt.Errorf("check.Category: Deprecated is false but ReplacementIDs %v specified", replacementIDs) + } + return &category{ + id: id, + purpose: purpose, + deprecated: deprecated, + replacementIDs: replacementIDs, + }, nil +} + +func (r *category) ID() string { + return r.id +} + +func (r *category) Purpose() string { + return r.purpose +} + +func (r *category) Deprecated() bool { + return r.deprecated +} + +func (r *category) ReplacementIDs() []string { + return slices.Clone(r.replacementIDs) +} + +func (r *category) toProto() *checkv1.Category { + if r == nil { + return nil + } + return &checkv1.Category{ + Id: r.id, + Purpose: r.purpose, + Deprecated: r.deprecated, + ReplacementIds: r.replacementIDs, + } +} + +func (*category) isCategory() {} + +func categoryForProtoCategory(protoCategory *checkv1.Category) (Category, error) { + return newCategory( + protoCategory.GetId(), + protoCategory.GetPurpose(), + protoCategory.GetDeprecated(), + protoCategory.GetReplacementIds(), + ) +} + +func sortCategories(categories []Category) { + sort.Slice(categories, func(i int, j int) bool { return CompareCategories(categories[i], categories[j]) < 0 }) +} + +func validateCategories(categories []Category) error { + return validateNoDuplicateCategoryIDs(xslices.Map(categories, Category.ID)) +} + +func validateNoDuplicateCategoryIDs(ids []string) error { + idToCount := make(map[string]int, len(ids)) + for _, id := range ids { + idToCount[id]++ + 
} + var duplicateIDs []string + for id, count := range idToCount { + if count > 1 { + duplicateIDs = append(duplicateIDs, id) + } + } + if len(duplicateIDs) > 0 { + sort.Strings(duplicateIDs) + return newDuplicateCategoryIDError(duplicateIDs) + } + return nil +} diff --git a/check/category_spec.go b/check/category_spec.go new file mode 100644 index 0000000..be9f57a --- /dev/null +++ b/check/category_spec.go @@ -0,0 +1,102 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "sort" + + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" +) + +// CategorySpec is the spec for a Category. +// +// It is used to construct a Category on the server-side (i.e. within the plugin). It specifies the +// ID, purpose, and a CategoryHandler to actually run the Category logic. +// +// Generally, these are provided to Main. This library will handle Check and ListCategories calls +// based on the provided CategorySpecs. +type CategorySpec struct { + // Required. + ID string + // Required. + Purpose string + Deprecated bool + ReplacementIDs []string +} + +// *** PRIVATE *** + +// Assumes that the CategorySpec is validated. 
+func categorySpecToCategory(categorySpec *CategorySpec) (Category, error) { + return newCategory( + categorySpec.ID, + categorySpec.Purpose, + categorySpec.Deprecated, + categorySpec.ReplacementIDs, + ) +} + +func validateCategorySpecs( + categorySpecs []*CategorySpec, + ruleSpecs []*RuleSpec, +) error { + categoryIDs := xslices.Map(categorySpecs, func(categorySpec *CategorySpec) string { return categorySpec.ID }) + if err := validateNoDuplicateCategoryIDs(categoryIDs); err != nil { + return err + } + categoryIDForRulesMap := make(map[string]struct{}) + for _, ruleSpec := range ruleSpecs { + for _, categoryID := range ruleSpec.CategoryIDs { + categoryIDForRulesMap[categoryID] = struct{}{} + } + } + categoryIDToCategorySpec := make(map[string]*CategorySpec) + for _, categorySpec := range categorySpecs { + if err := validateID(categorySpec.ID); err != nil { + return wrapValidateCategorySpecError(err) + } + categoryIDToCategorySpec[categorySpec.ID] = categorySpec + } + for _, categorySpec := range categorySpecs { + if err := validatePurpose(categorySpec.ID, categorySpec.Purpose); err != nil { + return wrapValidateCategorySpecError(err) + } + if len(categorySpec.ReplacementIDs) > 0 && !categorySpec.Deprecated { + return newValidateCategorySpecErrorf("ID %q had ReplacementIDs but Deprecated was false", categorySpec.ID) + } + for _, replacementID := range categorySpec.ReplacementIDs { + replacementCategorySpec, ok := categoryIDToCategorySpec[replacementID] + if !ok { + return newValidateCategorySpecErrorf("ID %q specified replacement ID %q which was not found", categorySpec.ID, replacementID) + } + if replacementCategorySpec.Deprecated { + return newValidateCategorySpecErrorf("Deprecated ID %q specified replacement ID %q which also deprecated", categorySpec.ID, replacementID) + } + } + if _, ok := categoryIDForRulesMap[categorySpec.ID]; !ok { + return newValidateCategorySpecErrorf("no Rule has a Category ID of %q", categorySpec.ID) + } + } + return nil +} + +func 
sortCategorySpecs(categorySpecs []*CategorySpec) { + sort.Slice( + categorySpecs, + func(i int, j int) bool { + return compareCategorySpecs(categorySpecs[i], categorySpecs[j]) < 0 + }, + ) +} diff --git a/check/check.go b/check/check.go new file mode 100644 index 0000000..3fc9de7 --- /dev/null +++ b/check/check.go @@ -0,0 +1,16 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package check implements the SDK for custom lint and breaking change plugins. +package check diff --git a/check/check_server.go b/check/check_server.go new file mode 100644 index 0000000..48c1d0c --- /dev/null +++ b/check/check_server.go @@ -0,0 +1,42 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package check + +import ( + "github.com/bufbuild/bufplugin-go/internal/gen/buf/plugin/check/v1/v1pluginrpc" + "pluginrpc.com/pluginrpc" +) + +// NewCheckServiceServer is a convenience function that creates a new pluginrpc.Server for +// the given v1pluginrpc.CheckServiceHandler. +// +// This registers the Check RPC on the command "check", the ListRules RPC on the command +// "list-rules", and the ListCategories RPC on the command "list-categories". No options +// are passed to any of the types necessary to create this Server. If further customization +// is necessary, this can be done manually. +func NewCheckServiceServer(checkServiceHandler v1pluginrpc.CheckServiceHandler) (pluginrpc.Server, error) { + spec, err := v1pluginrpc.CheckServiceSpecBuilder{ + Check: []pluginrpc.ProcedureOption{pluginrpc.ProcedureWithArgs("check")}, + ListRules: []pluginrpc.ProcedureOption{pluginrpc.ProcedureWithArgs("list-rules")}, + ListCategories: []pluginrpc.ProcedureOption{pluginrpc.ProcedureWithArgs("list-categories")}, + }.Build() + if err != nil { + return nil, err + } + serverRegistrar := pluginrpc.NewServerRegistrar() + checkServiceServer := v1pluginrpc.NewCheckServiceServer(pluginrpc.NewHandler(spec), checkServiceHandler) + v1pluginrpc.RegisterCheckServiceServer(serverRegistrar, checkServiceServer) + return pluginrpc.NewServer(spec, serverRegistrar) +} diff --git a/check/check_service_handler.go b/check/check_service_handler.go new file mode 100644 index 0000000..0a43d06 --- /dev/null +++ b/check/check_service_handler.go @@ -0,0 +1,307 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "context" + "fmt" + "slices" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "github.com/bufbuild/bufplugin-go/internal/gen/buf/plugin/check/v1/v1pluginrpc" + "github.com/bufbuild/bufplugin-go/internal/pkg/thread" + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" + "github.com/bufbuild/protovalidate-go" + "pluginrpc.com/pluginrpc" +) + +const defaultPageSize = 250 + +// NewCheckServiceHandler returns a new v1pluginrpc.CheckServiceHandler for the given Spec. +// +// The Spec will be validated. +func NewCheckServiceHandler(spec *Spec, options ...CheckServiceHandlerOption) (v1pluginrpc.CheckServiceHandler, error) { + return newCheckServiceHandler(spec, options...) +} + +// CheckServiceHandlerOption is an option for CheckServiceHandler. +type CheckServiceHandlerOption func(*checkServiceHandlerOptions) + +// CheckServiceHandlerWithParallelism returns a new CheckServiceHandlerOption that sets the parallelism +// by which Rules will be run. +// +// If this is set to a value >= 1, this many concurrent Rules can be run at the same time. +// A value of 0 indicates the default behavior, which is to use runtime.GOMAXPROCS(0). +// +// A value if < 0 has no effect. 
+func CheckServiceHandlerWithParallelism(parallelism int) CheckServiceHandlerOption { + return func(checkServiceHandlerOptions *checkServiceHandlerOptions) { + if parallelism < 0 { + parallelism = 0 + } + checkServiceHandlerOptions.parallelism = parallelism + } +} + +// *** PRIVATE *** + +type checkServiceHandler struct { + spec *Spec + parallelism int + validator *protovalidate.Validator + rules []Rule + ruleIDToRule map[string]Rule + ruleIDToRuleHandler map[string]RuleHandler + ruleIDToIndex map[string]int + categories []Category + categoryIDToCategory map[string]Category + categoryIDToIndex map[string]int +} + +func newCheckServiceHandler(spec *Spec, options ...CheckServiceHandlerOption) (*checkServiceHandler, error) { + checkServiceHandlerOptions := newCheckServiceHandlerOptions() + for _, option := range options { + option(checkServiceHandlerOptions) + } + if err := ValidateSpec(spec); err != nil { + return nil, err + } + categorySpecs := slices.Clone(spec.Categories) + sortCategorySpecs(categorySpecs) + categories := make([]Category, len(categorySpecs)) + categoryIDToCategory := make(map[string]Category, len(categorySpecs)) + categoryIDToIndex := make(map[string]int, len(categorySpecs)) + for i, categorySpec := range categorySpecs { + category, err := categorySpecToCategory(categorySpec) + if err != nil { + return nil, err + } + id := category.ID() + // Should never happen after validating the Spec. 
+ if _, ok := categoryIDToCategory[id]; ok { + return nil, fmt.Errorf("duplicate Category ID: %q", id) + } + categories[i] = category + categoryIDToCategory[id] = category + categoryIDToIndex[id] = i + } + ruleSpecs := slices.Clone(spec.Rules) + sortRuleSpecs(ruleSpecs) + rules := make([]Rule, len(ruleSpecs)) + ruleIDToRuleHandler := make(map[string]RuleHandler, len(ruleSpecs)) + ruleIDToRule := make(map[string]Rule, len(ruleSpecs)) + ruleIDToIndex := make(map[string]int, len(ruleSpecs)) + for i, ruleSpec := range ruleSpecs { + rule, err := ruleSpecToRule(ruleSpec, categoryIDToCategory) + if err != nil { + return nil, err + } + id := rule.ID() + // Should never happen after validating the Spec. + if _, ok := ruleIDToRule[id]; ok { + return nil, fmt.Errorf("duplicate Rule ID: %q", id) + } + rules[i] = rule + ruleIDToRuleHandler[id] = ruleSpec.Handler + ruleIDToRule[id] = rule + ruleIDToIndex[id] = i + } + validator, err := protovalidate.New() + if err != nil { + return nil, err + } + return &checkServiceHandler{ + spec: spec, + parallelism: checkServiceHandlerOptions.parallelism, + validator: validator, + rules: rules, + ruleIDToRuleHandler: ruleIDToRuleHandler, + ruleIDToRule: ruleIDToRule, + ruleIDToIndex: ruleIDToIndex, + categories: categories, + categoryIDToCategory: categoryIDToCategory, + categoryIDToIndex: categoryIDToIndex, + }, nil +} + +func (c *checkServiceHandler) Check( + ctx context.Context, + checkRequest *checkv1.CheckRequest, +) (*checkv1.CheckResponse, error) { + if err := c.validator.Validate(checkRequest); err != nil { + return nil, pluginrpc.NewError(pluginrpc.CodeInvalidArgument, err) + } + request, err := RequestForProtoRequest(checkRequest) + if err != nil { + return nil, err + } + if c.spec.Before != nil { + ctx, request, err = c.spec.Before(ctx, request) + if err != nil { + return nil, err + } + } + rules := xslices.Filter(c.rules, func(rule Rule) bool { return rule.Default() }) + if ruleIDs := request.RuleIDs(); len(ruleIDs) > 0 { + rules 
= make([]Rule, 0) + for _, ruleID := range ruleIDs { + rule, ok := c.ruleIDToRule[ruleID] + if !ok { + return nil, pluginrpc.NewErrorf(pluginrpc.CodeInvalidArgument, "unknown rule ID: %q", ruleID) + } + rules = append(rules, rule) + } + } + multiResponseWriter, err := newMultiResponseWriter(request) + if err != nil { + return nil, err + } + if err := thread.Parallelize( + ctx, + xslices.Map( + rules, + func(rule Rule) func(context.Context) error { + return func(ctx context.Context) error { + ruleHandler, ok := c.ruleIDToRuleHandler[rule.ID()] + if !ok { + // This should never happen. + return fmt.Errorf("no RuleHandler for id %q", rule.ID()) + } + return ruleHandler.Handle( + ctx, + multiResponseWriter.newResponseWriter(rule.ID()), + request, + ) + } + }, + ), + thread.WithParallelism(c.parallelism), + ); err != nil { + return nil, err + } + response, err := multiResponseWriter.toResponse() + if err != nil { + return nil, err + } + checkResponse := response.toProto() + if err := c.validator.Validate(checkResponse); err != nil { + return nil, err + } + return checkResponse, nil +} + +func (c *checkServiceHandler) ListRules(_ context.Context, listRulesRequest *checkv1.ListRulesRequest) (*checkv1.ListRulesResponse, error) { + if err := c.validator.Validate(listRulesRequest); err != nil { + return nil, pluginrpc.NewError(pluginrpc.CodeInvalidArgument, err) + } + rules, nextPageToken, err := c.getRulesAndNextPageToken( + int(listRulesRequest.GetPageSize()), + listRulesRequest.GetPageToken(), + ) + if err != nil { + return nil, err + } + listRulesResponse := &checkv1.ListRulesResponse{ + NextPageToken: nextPageToken, + Rules: xslices.Map(rules, Rule.toProto), + } + if err := c.validator.Validate(listRulesResponse); err != nil { + return nil, err + } + return listRulesResponse, nil +} + +func (c *checkServiceHandler) ListCategories(_ context.Context, listCategoriesRequest *checkv1.ListCategoriesRequest) (*checkv1.ListCategoriesResponse, error) { + if err := 
c.validator.Validate(listCategoriesRequest); err != nil { + return nil, pluginrpc.NewError(pluginrpc.CodeInvalidArgument, err) + } + categories, nextPageToken, err := c.getCategoriesAndNextPageToken( + int(listCategoriesRequest.GetPageSize()), + listCategoriesRequest.GetPageToken(), + ) + if err != nil { + return nil, err + } + listCategoriesResponse := &checkv1.ListCategoriesResponse{ + NextPageToken: nextPageToken, + Categories: xslices.Map(categories, Category.toProto), + } + if err := c.validator.Validate(listCategoriesResponse); err != nil { + return nil, err + } + return listCategoriesResponse, nil +} + +func (c *checkServiceHandler) getRulesAndNextPageToken(pageSize int, pageToken string) ([]Rule, string, error) { + index := 0 + if pageToken != "" { + var ok bool + index, ok = c.ruleIDToIndex[pageToken] + if !ok { + return nil, "", pluginrpc.NewErrorf(pluginrpc.CodeInvalidArgument, "unknown page token: %q", pageToken) + } + } + if pageSize == 0 { + pageSize = defaultPageSize + } + resultRules := make([]Rule, 0, len(c.rules)-index) + for i := 0; i < pageSize; i++ { + if index >= len(c.rules) { + break + } + resultRules = append(resultRules, c.rules[index]) + index++ + } + var nextPageToken string + if index < len(c.rules) { + nextPageToken = c.rules[index].ID() + } + return resultRules, nextPageToken, nil +} + +func (c *checkServiceHandler) getCategoriesAndNextPageToken(pageSize int, pageToken string) ([]Category, string, error) { + index := 0 + if pageToken != "" { + var ok bool + index, ok = c.categoryIDToIndex[pageToken] + if !ok { + return nil, "", pluginrpc.NewErrorf(pluginrpc.CodeInvalidArgument, "unknown page token: %q", pageToken) + } + } + if pageSize == 0 { + pageSize = defaultPageSize + } + resultCategories := make([]Category, 0, len(c.categories)-index) + for i := 0; i < pageSize; i++ { + if index >= len(c.categories) { + break + } + resultCategories = append(resultCategories, c.categories[index]) + index++ + } + var nextPageToken string + if 
index < len(c.categories) { + nextPageToken = c.categories[index].ID() + } + return resultCategories, nextPageToken, nil +} + +type checkServiceHandlerOptions struct { + parallelism int +} + +func newCheckServiceHandlerOptions() *checkServiceHandlerOptions { + return &checkServiceHandlerOptions{} +} diff --git a/check/check_service_handler_test.go b/check/check_service_handler_test.go new file mode 100644 index 0000000..1d9683e --- /dev/null +++ b/check/check_service_handler_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package check + +import ( + "context" + "testing" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + "pluginrpc.com/pluginrpc" +) + +func TestCheckServiceHandlerUniqueFiles(t *testing.T) { + t.Parallel() + + checkServiceHandler, err := NewCheckServiceHandler( + &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + }, + }, + ) + require.NoError(t, err) + + _, err = checkServiceHandler.Check( + context.Background(), + &checkv1.CheckRequest{ + Files: []*checkv1.File{ + { + FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + }, + }, + }, + AgainstFiles: []*checkv1.File{ + { + FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + }, + }, + }, + }, + ) + require.NoError(t, err) + + _, err = checkServiceHandler.Check( + context.Background(), + &checkv1.CheckRequest{ + Files: []*checkv1.File{ + { + FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + }, + }, + { + FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + }, + }, + }, + }, + ) + pluginrpcError := &pluginrpc.Error{} + require.ErrorAs(t, err, &pluginrpcError) + require.Equal(t, pluginrpc.CodeInvalidArgument, pluginrpcError.Code()) + + _, err = checkServiceHandler.Check( + context.Background(), + &checkv1.CheckRequest{ + Files: []*checkv1.File{ + { + FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + }, + }, + }, + AgainstFiles: []*checkv1.File{ + { + 
FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + Name: proto.String("bar.proto"), + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + }, + }, + { + FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + Name: proto.String("bar.proto"), + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + }, + }, + }, + }, + ) + pluginrpcError = &pluginrpc.Error{} + require.ErrorAs(t, err, &pluginrpcError) + require.Equal(t, pluginrpc.CodeInvalidArgument, pluginrpcError.Code()) +} diff --git a/check/checktest/checktest.go b/check/checktest/checktest.go new file mode 100644 index 0000000..d3715b3 --- /dev/null +++ b/check/checktest/checktest.go @@ -0,0 +1,444 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checktest provides testing helpers when writing lint and breaking change plugins. +// +// The easiest entry point is TestCase. This allows you to set up a test and run it extremely +// easily. Other functions provide lower-level primitives if TestCase doesn't meet your needs. 
+package checktest + +import ( + "context" + "errors" + "path/filepath" + "strconv" + "testing" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "github.com/bufbuild/bufplugin-go/check" + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" + "github.com/bufbuild/protocompile" + "github.com/bufbuild/protocompile/linker" + "github.com/bufbuild/protocompile/parser" + "github.com/bufbuild/protocompile/protoutil" + "github.com/bufbuild/protocompile/reporter" + "github.com/bufbuild/protocompile/wellknownimports" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" +) + +// SpecTest tests your spec with check.ValidateSpec. +// +// Almost every plugin should run a test with SpecTest. +// +// func TestSpec(t *testing.T) { +// t.Parallel() +// checktest.SpecTest(t, yourSpec) +// } +func SpecTest(t *testing.T, spec *check.Spec) { + require.NoError(t, check.ValidateSpec(spec)) +} + +// CheckTest is a single Check test to run against a Spec. +type CheckTest struct { + // Request is the request spec to test. + Request *RequestSpec + // Spec is the Spec to test. + // + // Required. + Spec *check.Spec + // ExpectedAnnotations are the expected Annotations that should be returned. + ExpectedAnnotations []ExpectedAnnotation +} + +// Run runs the test. +// +// This will: +// +// - Build the Files and AgainstFiles. +// - Create a new Request. +// - Create a new Client based on the Spec. +// - Call Check on the Client. +// - Compare the resulting Annotations with the ExpectedAnnotations, failing if there is a mismatch. 
+func (c CheckTest) Run(t *testing.T) { + ctx := context.Background() + + require.NotNil(t, c.Request) + require.NotNil(t, c.Spec) + + request, err := c.Request.ToRequest(ctx) + require.NoError(t, err) + client, err := check.NewClientForSpec(c.Spec) + require.NoError(t, err) + response, err := client.Check(ctx, request) + require.NoError(t, err) + AssertAnnotationsEqual(t, c.ExpectedAnnotations, response.Annotations()) +} + +// RequestSpec specifies request parameters to be compiled for testing. +// +// This allows a Request to be built from a directory of .proto files. +type RequestSpec struct { + // Files specifies the input files to test against. + // + // Required. + Files *ProtoFileSpec + // AgainstFiles specifies the input against files to test against, if anoy. + AgainstFiles *ProtoFileSpec + // RuleIDs are the specific RuleIDs to run. + RuleIDs []string + // Options are any options to pass to the plugin. + Options map[string]any +} + +// ToRequest converts the spec into a check.Request. +// +// If r is nil, this returns nil. +func (r *RequestSpec) ToRequest(ctx context.Context) (check.Request, error) { + if r == nil { + return nil, nil + } + + if r.Files == nil { + return nil, errors.New("RequestSpec.Files not set") + } + + againstFiles, err := r.AgainstFiles.ToFiles(ctx) + if err != nil { + return nil, err + } + options, err := check.NewOptions(r.Options) + if err != nil { + return nil, err + } + requestOptions := []check.RequestOption{ + check.WithAgainstFiles(againstFiles), + check.WithOptions(options), + check.WithRuleIDs(r.RuleIDs...), + } + + files, err := r.Files.ToFiles(ctx) + if err != nil { + return nil, err + } + return check.NewRequest(files, requestOptions...) +} + +// ProtoFileSpec specifies files to be compiled for testing. +// +// This allows tests to effectively point at a directory, and get back a +// *descriptorpb.FileDesriptorSet, or more to the point, check.Files +// that can be passed on a Request. 
+type ProtoFileSpec struct { + // DirPaths are the paths where .proto files are contained. + // + // Imports within .proto files should derive from one of these directories. + // This must contain at least one element. + // + // This corresponds to the -I flag in protoc. + DirPaths []string + // FilePaths are the specific paths to build within the DirPaths. + // + // Any imports of the FilePaths will be built as well, and marked as imports. + // This must contain at least one element. + // Paths should be relative to DirPaths. + // + // This corresponds to arguments passed to protoc. + FilePaths []string +} + +// ToFiles compiles the files into check.Files. +// +// If p is nil, this returns an empty slice. +func (p *ProtoFileSpec) ToFiles(ctx context.Context) ([]check.File, error) { + if p == nil { + return nil, nil + } + if err := validateProtoFileSpec(p); err != nil { + return nil, err + } + return compile(ctx, p.DirPaths, p.FilePaths) +} + +// ExpectedAnnotation contains the values expected from an Annotation. +type ExpectedAnnotation struct { + // RuleID is the ID of the Rule. + // + // Required. + RuleID string + // Message is the message returned from the annoation. + // + // If Message is not set on ExpectedAnnotation, this field will *not* be compared + // against the value in Annotation. That is, it is valid to have an Annotation return + // a message but to not set it on ExpectedAnnotation. + Message string + // Location is the location of the failure. + Location *ExpectedLocation + // AgainstLocation is the against location of the failure. + AgainstLocation *ExpectedLocation +} + +// String implements fmt.Stringer. +func (ea ExpectedAnnotation) String() string { + return "ruleID=\"" + ea.RuleID + "\"" + + " message=\"" + ea.Message + "\"" + + " location=\"" + ea.Location.String() + "\"" + + " againstLocation=\"" + ea.AgainstLocation.String() + "\"" +} + +// ExpectedLocation contains the values expected from a Location. 
+type ExpectedLocation struct { + // FileName is the name of the file. + FileName string + // StartLine is the zero-indexed start line. + StartLine int + // StartColumn is the zero-indexed start column. + StartColumn int + // EndLine is the zero-indexed end line. + EndLine int + // EndColumn is the zero-indexed end column. + EndColumn int +} + +// String implements fmt.Stringer. +func (el *ExpectedLocation) String() string { + if el == nil { + return "nil" + } + return el.FileName + + " startLine=" + strconv.Itoa(el.StartLine) + + " startColumn=" + strconv.Itoa(el.StartColumn) + + " endLine=" + strconv.Itoa(el.EndLine) + + " endColumn=" + strconv.Itoa(el.EndColumn) +} + +// AssertAnnotationsEqual asserts that the Annotations equal the expected Annotations. +func AssertAnnotationsEqual(t *testing.T, expectedAnnotations []ExpectedAnnotation, actualAnnotations []check.Annotation) { + if len(expectedAnnotations) == 0 { + expectedAnnotations = nil + } + if len(actualAnnotations) == 0 { + actualAnnotations = nil + } + actualExpectedAnnotations := expectedAnnotationsForAnnotations(actualAnnotations) + msgAndArgs := []any{"expected:\n%v\nactual:\n%v", expectedAnnotations, actualExpectedAnnotations} + require.Equal(t, len(expectedAnnotations), len(actualExpectedAnnotations), msgAndArgs...) + for i, expectedAnnotation := range expectedAnnotations { + if expectedAnnotation.Message == "" { + actualExpectedAnnotations[i].Message = "" + } + } + assert.Equal(t, expectedAnnotations, actualExpectedAnnotations, msgAndArgs...) +} + +// RequireAnnotationsEqual requires that the Annotations equal the expected Annotations. 
+func RequireAnnotationsEqual(t *testing.T, expectedAnnotations []ExpectedAnnotation, actualAnnotations []check.Annotation) { + if len(expectedAnnotations) == 0 { + expectedAnnotations = nil + } + if len(actualAnnotations) == 0 { + actualAnnotations = nil + } + actualExpectedAnnotations := expectedAnnotationsForAnnotations(actualAnnotations) + msgAndArgs := []any{"expected:\n%v\nactual:\n%v", expectedAnnotations, actualExpectedAnnotations} + require.Equal(t, len(expectedAnnotations), len(actualExpectedAnnotations), msgAndArgs...) + for i, expectedAnnotation := range expectedAnnotations { + if expectedAnnotation.Message == "" { + actualExpectedAnnotations[i].Message = "" + } + } + require.Equal(t, expectedAnnotations, actualExpectedAnnotations, msgAndArgs...) +} + +// *** PRIVATE *** + +func validateProtoFileSpec(protoFileSpec *ProtoFileSpec) error { + if len(protoFileSpec.DirPaths) == 0 { + return errors.New("no DirPaths specified on ProtoFileSpec") + } + if len(protoFileSpec.FilePaths) == 0 { + return errors.New("no FilePaths specified on ProtoFileSpec") + } + return nil +} + +// expectedAnnotationsForAnnotations returns ExpectedAnnotations for the given Annotations. +// +// Callers will need to filter out the Messages from the returned ExpectedAnnotations to conform +// to the ExpectedAnnotations that are being compared against. See the note on ExpectedAnnotation.Message. +func expectedAnnotationsForAnnotations(annotations []check.Annotation) []ExpectedAnnotation { + return xslices.Map(annotations, expectedAnnotationForAnnotation) +} + +// expectedAnnotationForAnnotation returns an ExpectedAnnotation for the given Annotation. +// +// Callers will need to filter out the Messages from the returned ExpectedAnnotations to conform +// to the ExpectedAnnotations that are being compared against. See the note on ExpectedAnnotation.Message. 
// expectedAnnotationForAnnotation converts a single Annotation into the
// ExpectedAnnotation form used for comparisons in this package.
func expectedAnnotationForAnnotation(annotation check.Annotation) ExpectedAnnotation {
	expectedAnnotation := ExpectedAnnotation{
		RuleID:  annotation.RuleID(),
		Message: annotation.Message(),
	}
	// Locations are optional on an Annotation; only build ExpectedLocations
	// for the ones actually present.
	if location := annotation.Location(); location != nil {
		expectedAnnotation.Location = &ExpectedLocation{
			FileName:    location.File().FileDescriptor().Path(),
			StartLine:   location.StartLine(),
			StartColumn: location.StartColumn(),
			EndLine:     location.EndLine(),
			EndColumn:   location.EndColumn(),
		}
	}
	if againstLocation := annotation.AgainstLocation(); againstLocation != nil {
		expectedAnnotation.AgainstLocation = &ExpectedLocation{
			FileName:    againstLocation.File().FileDescriptor().Path(),
			StartLine:   againstLocation.StartLine(),
			StartColumn: againstLocation.StartColumn(),
			EndLine:     againstLocation.EndLine(),
			EndColumn:   againstLocation.EndColumn(),
		}
	}
	return expectedAnnotation
}

// compile compiles the .proto files at filePaths, resolving imports against
// dirPaths, and converts the result into check.Files.
//
// Compiler warnings are collected and used to populate IsSyntaxUnspecified and
// UnusedDependency on the resulting files. Files reached only via imports
// (i.e. not listed in filePaths) are marked IsImport.
func compile(ctx context.Context, dirPaths []string, filePaths []string) ([]check.File, error) {
	// Normalize incoming slash-separated paths to OS-specific paths for the
	// compiler, but keep a slash-separated set for matching descriptor names
	// (FileDescriptorProto names are slash-separated).
	dirPaths = fromSlashPaths(dirPaths)
	filePaths = fromSlashPaths(filePaths)
	toSlashFilePathMap := make(map[string]struct{}, len(filePaths))
	for _, filePath := range filePaths {
		toSlashFilePathMap[filepath.ToSlash(filePath)] = struct{}{}
	}

	var warningErrorsWithPos []reporter.ErrorWithPos
	compiler := protocompile.Compiler{
		Resolver: wellknownimports.WithStandardImports(
			&protocompile.SourceResolver{
				ImportPaths: dirPaths,
			},
		),
		Reporter: reporter.NewReporter(
			// Error handler returns nil; per the protocompile reporter
			// contract this presumably lets compilation continue - TODO
			// confirm that fatal errors still surface from Compile below.
			func(reporter.ErrorWithPos) error {
				return nil
			},
			// Warnings are retained and classified after compilation.
			func(errorWithPos reporter.ErrorWithPos) {
				warningErrorsWithPos = append(warningErrorsWithPos, errorWithPos)
			},
		),
		// This is what buf uses.
		SourceInfoMode: protocompile.SourceInfoExtraOptionLocations,
	}
	files, err := compiler.Compile(ctx, filePaths...)
	if err != nil {
		return nil, err
	}
	// Classify the collected warnings: missing-syntax warnings and
	// unused-import warnings feed into the per-file metadata below.
	syntaxUnspecifiedFilePaths := make(map[string]struct{})
	filePathToUnusedDependencyFilePaths := make(map[string]map[string]struct{})
	for _, warningErrorWithPos := range warningErrorsWithPos {
		maybeAddSyntaxUnspecified(syntaxUnspecifiedFilePaths, warningErrorWithPos)
		maybeAddUnusedDependency(filePathToUnusedDependencyFilePaths, warningErrorWithPos)
	}
	fileDescriptorSet := fileDescriptorSetForFileDescriptors(files)

	protoFiles := make([]*checkv1.File, len(fileDescriptorSet.GetFile()))
	for i, fileDescriptorProto := range fileDescriptorSet.GetFile() {
		// A file is an import unless it was explicitly requested.
		_, isNotImport := toSlashFilePathMap[fileDescriptorProto.GetName()]
		_, isSyntaxUnspecified := syntaxUnspecifiedFilePaths[fileDescriptorProto.GetName()]
		unusedDependencyIndexes := unusedDependencyIndexesForFilePathToUnusedDependencyFilePaths(
			fileDescriptorProto,
			filePathToUnusedDependencyFilePaths[fileDescriptorProto.GetName()],
		)
		protoFiles[i] = &checkv1.File{
			FileDescriptorProto: fileDescriptorProto,
			IsImport:            !isNotImport,
			IsSyntaxUnspecified: isSyntaxUnspecified,
			UnusedDependency:    unusedDependencyIndexes,
		}
	}
	return check.FilesForProtoFiles(protoFiles)
}

// unusedDependencyIndexesForFilePathToUnusedDependencyFilePaths translates a set of
// unused dependency file paths into indexes into the FileDescriptorProto's
// dependency list, as required by checkv1.File.UnusedDependency.
func unusedDependencyIndexesForFilePathToUnusedDependencyFilePaths(
	fileDescriptorProto *descriptorpb.FileDescriptorProto,
	unusedDependencyFilePaths map[string]struct{},
) []int32 {
	unusedDependencyIndexes := make([]int32, 0, len(unusedDependencyFilePaths))
	if len(unusedDependencyFilePaths) == 0 {
		return unusedDependencyIndexes
	}
	dependencyFilePaths := fileDescriptorProto.GetDependency()
	for i := 0; i < len(dependencyFilePaths); i++ {
		if _, ok := unusedDependencyFilePaths[dependencyFilePaths[i]]; ok {
			unusedDependencyIndexes = append(unusedDependencyIndexes, int32(i))
		}
	}
	return unusedDependencyIndexes
}

// maybeAddSyntaxUnspecified records the warning's file as syntax-unspecified if
// the warning is protocompile's "no syntax" sentinel; other warnings are ignored.
func maybeAddSyntaxUnspecified(
	syntaxUnspecifiedFilePaths map[string]struct{},
	errorWithPos reporter.ErrorWithPos,
) {
	if !errors.Is(errorWithPos, parser.ErrNoSyntax) {
		return
	}
	syntaxUnspecifiedFilePaths[errorWithPos.GetPosition().Filename] = struct{}{}
}

// maybeAddUnusedDependency records an unused-import warning, keyed by the file
// containing the unused import; other warnings are ignored.
func maybeAddUnusedDependency(
	filePathToUnusedDependencyFilePaths map[string]map[string]struct{},
	errorWithPos reporter.ErrorWithPos,
) {
	var errorUnusedImport linker.ErrorUnusedImport
	if !errors.As(errorWithPos, &errorUnusedImport) {
		return
	}
	pos := errorWithPos.GetPosition()
	unusedDependencyFilePaths, ok := filePathToUnusedDependencyFilePaths[pos.Filename]
	if !ok {
		unusedDependencyFilePaths = make(map[string]struct{})
		filePathToUnusedDependencyFilePaths[pos.Filename] = unusedDependencyFilePaths
	}
	unusedDependencyFilePaths[errorUnusedImport.UnusedImport()] = struct{}{}
}

// fileDescriptorSetForFileDescriptors flattens the given FileDescriptors and all
// of their transitive imports into a deduplicated FileDescriptorSet.
func fileDescriptorSetForFileDescriptors[D protoreflect.FileDescriptor](files []D) *descriptorpb.FileDescriptorSet {
	soFar := make(map[string]struct{}, len(files))
	slice := make([]*descriptorpb.FileDescriptorProto, 0, len(files))
	for _, file := range files {
		toFileDescriptorProtoSlice(file, &slice, soFar)
	}
	return &descriptorpb.FileDescriptorSet{File: slice}
}

// toFileDescriptorProtoSlice appends file (and, first, its transitive imports)
// to results, using soFar (keyed by path) to visit each file exactly once.
func toFileDescriptorProtoSlice(file protoreflect.FileDescriptor, results *[]*descriptorpb.FileDescriptorProto, soFar map[string]struct{}) {
	if _, exists := soFar[file.Path()]; exists {
		return
	}
	soFar[file.Path()] = struct{}{}
	// Add dependencies first so the resulting slice is in topological order
	imports := file.Imports()
	for i, length := 0, imports.Len(); i < length; i++ {
		toFileDescriptorProtoSlice(imports.Get(i).FileDescriptor, results, soFar)
	}
	*results = append(*results, protoutil.ProtoFromFileDescriptor(file))
}

// fromSlashPaths converts slash-separated paths to cleaned, OS-specific paths.
func fromSlashPaths(paths []string) []string {
	fromSlashPaths := make([]string, len(paths))
	for i, path := range paths {
		fromSlashPaths[i] = filepath.Clean(filepath.FromSlash(path))
	}
	return fromSlashPaths
}
index 0000000..0d77d36 --- /dev/null +++ b/check/checkutil/breaking.go @@ -0,0 +1,315 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checkutil + +import ( + "context" + + "github.com/bufbuild/bufplugin-go/check" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// NewFilePairRuleHandler returns a new RuleHandler that will call f for every file pair +// within the check.Request's Files() and AgainstFiles(). +// +// The Files will be paired up by name. Files that cannot be paired up are skipped. +// +// This is typically used for breaking change Rules. 
+func NewFilePairRuleHandler( + f func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + file check.File, + againstFile check.File, + ) error, + options ...IteratorOption, +) check.RuleHandler { + iteratorOptions := newIteratorOptions() + for _, option := range options { + option(iteratorOptions) + } + return check.RuleHandlerFunc( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + ) error { + files := filterFiles(request.Files(), iteratorOptions.withoutImports) + againstFiles := filterFiles(request.AgainstFiles(), iteratorOptions.withoutImports) + pathToFile, err := getPathToFile(files) + if err != nil { + return err + } + againstPathToFile, err := getPathToFile(againstFiles) + if err != nil { + return err + } + for againstPath, againstFile := range againstPathToFile { + if file, ok := pathToFile[againstPath]; ok { + if err = f(ctx, responseWriter, request, file, againstFile); err != nil { + return err + } + } + } + return nil + }, + ) +} + +// NewEnumPairRuleHandler returns a new RuleHandler that will call f for every enum pair +// within the check.Request's Files() and AgainstFiles(). +// +// The enums will be paired up by fully-qualified name. Enums that cannot be paired up are skipped. +// +// This is typically used for breaking change Rules. 
+func NewEnumPairRuleHandler( + f func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + enumDescriptor protoreflect.EnumDescriptor, + againstEnumDescriptor protoreflect.EnumDescriptor, + ) error, + options ...IteratorOption, +) check.RuleHandler { + iteratorOptions := newIteratorOptions() + for _, option := range options { + option(iteratorOptions) + } + return check.RuleHandlerFunc( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + ) error { + files := filterFiles(request.Files(), iteratorOptions.withoutImports) + againstFiles := filterFiles(request.AgainstFiles(), iteratorOptions.withoutImports) + fullNameToEnumDescriptor, err := getFullNameToEnumDescriptor(files) + if err != nil { + return err + } + againstFullNameToEnumDescriptor, err := getFullNameToEnumDescriptor(againstFiles) + if err != nil { + return err + } + for againstFullName, againstEnumDescriptor := range againstFullNameToEnumDescriptor { + if enumDescriptor, ok := fullNameToEnumDescriptor[againstFullName]; ok { + if err = f(ctx, responseWriter, request, enumDescriptor, againstEnumDescriptor); err != nil { + return err + } + } + } + return nil + }, + ) +} + +// NewMessagePairRuleHandler returns a new RuleHandler that will call f for every message pair +// within the check.Request's Files() and AgainstFiles(). +// +// The messages will be paired up by fully-qualified name. Messages that cannot be paired up are skipped. +// +// This is typically used for breaking change Rules. 
+func NewMessagePairRuleHandler( + f func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + messageDescriptor protoreflect.MessageDescriptor, + againstMessageDescriptor protoreflect.MessageDescriptor, + ) error, + options ...IteratorOption, +) check.RuleHandler { + iteratorOptions := newIteratorOptions() + for _, option := range options { + option(iteratorOptions) + } + return check.RuleHandlerFunc( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + ) error { + files := filterFiles(request.Files(), iteratorOptions.withoutImports) + againstFiles := filterFiles(request.AgainstFiles(), iteratorOptions.withoutImports) + fullNameToMessageDescriptor, err := getFullNameToMessageDescriptor(files) + if err != nil { + return err + } + againstFullNameToMessageDescriptor, err := getFullNameToMessageDescriptor(againstFiles) + if err != nil { + return err + } + for againstFullName, againstMessageDescriptor := range againstFullNameToMessageDescriptor { + if messageDescriptor, ok := fullNameToMessageDescriptor[againstFullName]; ok { + if err = f(ctx, responseWriter, request, messageDescriptor, againstMessageDescriptor); err != nil { + return err + } + } + } + return nil + }, + ) +} + +// NewFieldPairRuleHandler returns a new RuleHandler that will call f for every field pair +// within the check.Request's Files() and AgainstFiles(). +// +// The fields will be paired up by the fully-qualified name of the message, and the field number. +// Fields that cannot be paired up are skipped. +// +// This includes extensions. +// +// This is typically used for breaking change Rules. 
+func NewFieldPairRuleHandler( + f func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + fieldDescriptor protoreflect.FieldDescriptor, + againstFieldDescriptor protoreflect.FieldDescriptor, + ) error, + options ...IteratorOption, +) check.RuleHandler { + iteratorOptions := newIteratorOptions() + for _, option := range options { + option(iteratorOptions) + } + return check.RuleHandlerFunc( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + ) error { + files := filterFiles(request.Files(), iteratorOptions.withoutImports) + againstFiles := filterFiles(request.AgainstFiles(), iteratorOptions.withoutImports) + containingMessageFullNameToNumberToFieldDescriptor, err := getContainingMessageFullNameToNumberToFieldDescriptor(files) + if err != nil { + return err + } + againstContainingMessageFullNameToNumberToFieldDescriptor, err := getContainingMessageFullNameToNumberToFieldDescriptor(againstFiles) + if err != nil { + return err + } + for againstContainingMessageFullName, againstNumberToFieldDescriptor := range againstContainingMessageFullNameToNumberToFieldDescriptor { + if numberToFieldDescriptor, ok := containingMessageFullNameToNumberToFieldDescriptor[againstContainingMessageFullName]; ok { + for againstNumber, againstFieldDescriptor := range againstNumberToFieldDescriptor { + if fieldDescriptor, ok := numberToFieldDescriptor[againstNumber]; ok { + if err = f(ctx, responseWriter, request, fieldDescriptor, againstFieldDescriptor); err != nil { + return err + } + } + } + } + } + return nil + }, + ) +} + +// NewServicePairRuleHandler returns a new RuleHandler that will call f for every service pair +// within the check.Request's Files() and AgainstFiles(). +// +// The services will be paired up by fully-qualified name. Services that cannot be paired up are skipped. +// +// This is typically used for breaking change Rules. 
+func NewServicePairRuleHandler( + f func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + serviceDescriptor protoreflect.ServiceDescriptor, + againstServiceDescriptor protoreflect.ServiceDescriptor, + ) error, + options ...IteratorOption, +) check.RuleHandler { + iteratorOptions := newIteratorOptions() + for _, option := range options { + option(iteratorOptions) + } + return check.RuleHandlerFunc( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + ) error { + files := filterFiles(request.Files(), iteratorOptions.withoutImports) + againstFiles := filterFiles(request.AgainstFiles(), iteratorOptions.withoutImports) + fullNameToServiceDescriptor, err := getFullNameToServiceDescriptor(files) + if err != nil { + return err + } + againstFullNameToServiceDescriptor, err := getFullNameToServiceDescriptor(againstFiles) + if err != nil { + return err + } + for againstFullName, againstServiceDescriptor := range againstFullNameToServiceDescriptor { + if serviceDescriptor, ok := fullNameToServiceDescriptor[againstFullName]; ok { + if err = f(ctx, responseWriter, request, serviceDescriptor, againstServiceDescriptor); err != nil { + return err + } + } + } + return nil + }, + ) +} + +// NewMethodPairRuleHandler returns a new RuleHandler that will call f for every method pair +// within the check.Request's Files() and AgainstFiles(). +// +// The services will be paired up by fully-qualified name of the service, and name of the method. +// Methods that cannot be paired up are skipped. +// +// This is typically used for breaking change Rules. 
+func NewMethodPairRuleHandler( + f func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + methodDescriptor protoreflect.MethodDescriptor, + againstMethodDescriptor protoreflect.MethodDescriptor, + ) error, + options ...IteratorOption, +) check.RuleHandler { + return NewServicePairRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + serviceDescriptor protoreflect.ServiceDescriptor, + againstServiceDescriptor protoreflect.ServiceDescriptor, + ) error { + nameToMethodDescriptor, err := getNameToMethodDescriptor(serviceDescriptor) + if err != nil { + return err + } + againstNameToMethodDescriptor, err := getNameToMethodDescriptor(againstServiceDescriptor) + if err != nil { + return err + } + for againstName, againstMethodDescriptor := range againstNameToMethodDescriptor { + if methodDescriptor, ok := nameToMethodDescriptor[againstName]; ok { + if err = f(ctx, responseWriter, request, methodDescriptor, againstMethodDescriptor); err != nil { + return err + } + } + } + return nil + }, + options..., + ) +} diff --git a/check/checkutil/checkutil.go b/check/checkutil/checkutil.go new file mode 100644 index 0000000..fc879a0 --- /dev/null +++ b/check/checkutil/checkutil.go @@ -0,0 +1,42 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checkutil implements helpers for the check package. 
+package checkutil + +// IteratorOption is an option for any of the New.*RuleHandler functions in this package. +type IteratorOption func(*iteratorOptions) + +// WithoutImports returns a new IteratorOption that will not call the provided function +// for any imports. +// +// For lint RuleHandlers, this is generally an option you will want to pass. For breaking +// RuleHandlers, you generally want to consider imports as part of breaking changes. +// +// The default is to call the provided function for all imports. +func WithoutImports() IteratorOption { + return func(iteratorOptions *iteratorOptions) { + iteratorOptions.withoutImports = true + } +} + +// *** PRIVATE *** + +type iteratorOptions struct { + withoutImports bool +} + +func newIteratorOptions() *iteratorOptions { + return &iteratorOptions{} +} diff --git a/check/checkutil/lint.go b/check/checkutil/lint.go new file mode 100644 index 0000000..461fa9a --- /dev/null +++ b/check/checkutil/lint.go @@ -0,0 +1,268 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checkutil + +import ( + "context" + + "github.com/bufbuild/bufplugin-go/check" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// NewFileRuleHandler returns a new RuleHandler that will call f for every file +// within the check.Request's Files(). +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. 
+func NewFileRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, check.File) error, + options ...IteratorOption, +) check.RuleHandler { + iteratorOptions := newIteratorOptions() + for _, option := range options { + option(iteratorOptions) + } + return check.RuleHandlerFunc( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + ) error { + for _, file := range request.Files() { + if iteratorOptions.withoutImports && file.IsImport() { + continue + } + if err := f(ctx, responseWriter, request, file); err != nil { + return err + } + } + return nil + }, + ) +} + +// NewFileImportRuleHandler returns a new RuleHandler that will call f for every "import" statement +// within the check.Request's Files(). +// +// Note that terms are overloaded here: check.File.IsImport denotes whether the File is an import +// itself, while this iterates over the protoreflect.FileImports within each File. The option +// WithoutImports() is a separate concern - NewFileImportRuleHandler(f, WithoutImports()) will +// iterate over all the FileImports for the non-import Files. +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. +func NewFileImportRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.FileImport) error, + options ...IteratorOption, +) check.RuleHandler { + return NewFileRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + file check.File, + ) error { + return forEachFileImport( + file.FileDescriptor(), + func(fileImport protoreflect.FileImport) error { + return f(ctx, responseWriter, request, fileImport) + }, + ) + }, + options..., + ) +} + +// NewEnumRuleHandler returns a new RuleHandler that will call f for every enum +// within the check.Request's Files(). +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. 
+func NewEnumRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.EnumDescriptor) error, + options ...IteratorOption, +) check.RuleHandler { + return NewFileRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + file check.File, + ) error { + return forEachEnum( + file.FileDescriptor(), + func(enumDescriptor protoreflect.EnumDescriptor) error { + return f(ctx, responseWriter, request, enumDescriptor) + }, + ) + }, + options..., + ) +} + +// NewEnumValueRuleHandler returns a new RuleHandler that will call f for every value in every enum +// within the check.Request's Files(). +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. +func NewEnumValueRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.EnumValueDescriptor) error, + options ...IteratorOption, +) check.RuleHandler { + return NewEnumRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + enumDescriptor protoreflect.EnumDescriptor, + ) error { + return forEachEnumValue( + enumDescriptor, + func(enumValueDescriptor protoreflect.EnumValueDescriptor) error { + return f(ctx, responseWriter, request, enumValueDescriptor) + }, + ) + }, + options..., + ) +} + +// NewMessageRuleHandler returns a new RuleHandler that will call f for every message +// within the check.Request's Files(). +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. 
+func NewMessageRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.MessageDescriptor) error, + options ...IteratorOption, +) check.RuleHandler { + return NewFileRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + file check.File, + ) error { + return forEachMessage( + file.FileDescriptor(), + func(messageDescriptor protoreflect.MessageDescriptor) error { + return f(ctx, responseWriter, request, messageDescriptor) + }, + ) + }, + options..., + ) +} + +// NewFieldRuleHandler returns a new RuleHandler that will call f for every field in every message +// within the check.Request's Files(). +// +// This includes extensions. +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. +func NewFieldRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.FieldDescriptor) error, + options ...IteratorOption, +) check.RuleHandler { + return NewFileRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + file check.File, + ) error { + return forEachField( + file.FileDescriptor(), + func(fieldDescriptor protoreflect.FieldDescriptor) error { + return f(ctx, responseWriter, request, fieldDescriptor) + }, + ) + }, + options..., + ) +} + +// NewOneofRuleHandler returns a new RuleHandler that will call f for every oneof in every message +// within the check.Request's Files(). +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. 
+func NewOneofRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.OneofDescriptor) error, + options ...IteratorOption, +) check.RuleHandler { + return NewMessageRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + messageDescriptor protoreflect.MessageDescriptor, + ) error { + return forEachOneof( + messageDescriptor, + func(oneofDescriptor protoreflect.OneofDescriptor) error { + return f(ctx, responseWriter, request, oneofDescriptor) + }, + ) + }, + options..., + ) +} + +// NewServiceRuleHandler returns a new RuleHandler that will call f for every service +// within the check.Request's Files(). +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. +func NewServiceRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.ServiceDescriptor) error, + options ...IteratorOption, +) check.RuleHandler { + return NewFileRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + file check.File, + ) error { + return forEachService( + file.FileDescriptor(), + func(serviceDescriptor protoreflect.ServiceDescriptor) error { + return f(ctx, responseWriter, request, serviceDescriptor) + }, + ) + }, + options..., + ) +} + +// NewMethodRuleHandler returns a new RuleHandler that will call f for every method in every service +// within the check.Request's Files(). +// +// This is typically used for lint Rules. Most callers will use the WithoutImports() options. 
+func NewMethodRuleHandler( + f func(context.Context, check.ResponseWriter, check.Request, protoreflect.MethodDescriptor) error, + options ...IteratorOption, +) check.RuleHandler { + return NewServiceRuleHandler( + func( + ctx context.Context, + responseWriter check.ResponseWriter, + request check.Request, + serviceDescriptor protoreflect.ServiceDescriptor, + ) error { + return forEachMethod( + serviceDescriptor, + func(methodDescriptor protoreflect.MethodDescriptor) error { + return f(ctx, responseWriter, request, methodDescriptor) + }, + ) + }, + options..., + ) +} diff --git a/check/checkutil/util.go b/check/checkutil/util.go new file mode 100644 index 0000000..bcb56d1 --- /dev/null +++ b/check/checkutil/util.go @@ -0,0 +1,325 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package checkutil + +import ( + "fmt" + "sort" + + "github.com/bufbuild/bufplugin-go/check" + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type container interface { + Enums() protoreflect.EnumDescriptors + Messages() protoreflect.MessageDescriptors + Extensions() protoreflect.ExtensionDescriptors +} + +func getPathToFile(files []check.File) (map[string]check.File, error) { + pathToFileMap := make(map[string]check.File, len(files)) + for _, file := range files { + path := file.FileDescriptor().Path() + if _, ok := pathToFileMap[path]; ok { + return nil, fmt.Errorf("duplicate file: %q", path) + } + pathToFileMap[path] = file + } + return pathToFileMap, nil +} + +func getFullNameToEnumDescriptor(files []check.File) (map[protoreflect.FullName]protoreflect.EnumDescriptor, error) { + fullNameToEnumDescriptorMap := make(map[protoreflect.FullName]protoreflect.EnumDescriptor) + for _, file := range files { + if err := forEachEnum( + file.FileDescriptor(), + func(enumDescriptor protoreflect.EnumDescriptor) error { + fullName := enumDescriptor.FullName() + if _, ok := fullNameToEnumDescriptorMap[fullName]; ok { + return fmt.Errorf("duplicate enum: %q", fullName) + } + fullNameToEnumDescriptorMap[fullName] = enumDescriptor + return nil + }, + ); err != nil { + return nil, err + } + } + return fullNameToEnumDescriptorMap, nil +} + +// Keeping this function around for now, this is to suppress lint unused. 
var _ = getNumberToEnumValueDescriptors

// getNumberToEnumValueDescriptors groups the enum's values by number (multiple
// values can share a number when aliasing is in play), sorting each group by
// value name for deterministic output.
func getNumberToEnumValueDescriptors(enumDescriptor protoreflect.EnumDescriptor) (map[protoreflect.EnumNumber][]protoreflect.EnumValueDescriptor, error) {
	numberToEnumValueDescriptorsMap := make(map[protoreflect.EnumNumber][]protoreflect.EnumValueDescriptor)
	if err := forEachEnumValue(
		enumDescriptor,
		func(enumValueDescriptor protoreflect.EnumValueDescriptor) error {
			numberToEnumValueDescriptorsMap[enumValueDescriptor.Number()] = append(
				numberToEnumValueDescriptorsMap[enumValueDescriptor.Number()],
				enumValueDescriptor,
			)
			return nil
		},
	); err != nil {
		return nil, err
	}
	// Sort each group by name so iteration over a group is stable.
	for _, enumValueDescriptors := range numberToEnumValueDescriptorsMap {
		sort.Slice(
			enumValueDescriptors,
			func(i int, j int) bool {
				return enumValueDescriptors[i].Name() < enumValueDescriptors[j].Name()
			},
		)
	}
	return numberToEnumValueDescriptorsMap, nil
}

// getFullNameToMessageDescriptor indexes every message (including nested
// messages) in the given Files by fully-qualified name, erroring on duplicates.
func getFullNameToMessageDescriptor(files []check.File) (map[protoreflect.FullName]protoreflect.MessageDescriptor, error) {
	fullNameToMessageDescriptorMap := make(map[protoreflect.FullName]protoreflect.MessageDescriptor)
	for _, file := range files {
		if err := forEachMessage(
			file.FileDescriptor(),
			func(messageDescriptor protoreflect.MessageDescriptor) error {
				fullName := messageDescriptor.FullName()
				if _, ok := fullNameToMessageDescriptorMap[fullName]; ok {
					return fmt.Errorf("duplicate message: %q", fullName)
				}
				fullNameToMessageDescriptorMap[fullName] = messageDescriptor
				return nil
			},
		); err != nil {
			return nil, err
		}
	}
	return fullNameToMessageDescriptorMap, nil
}

// getContainingMessageFullNameToNumberToFieldDescriptor builds a two-level index
// over every field (including extensions, via forEachField) in the given Files:
// containing message fully-qualified name -> field number -> field descriptor.
// Duplicate field numbers within a message are an error.
func getContainingMessageFullNameToNumberToFieldDescriptor(
	files []check.File,
) (map[protoreflect.FullName]map[protoreflect.FieldNumber]protoreflect.FieldDescriptor, error) {
	containingMessageFullNameToNumberToFieldDescriptorMap := make(
		map[protoreflect.FullName]map[protoreflect.FieldNumber]protoreflect.FieldDescriptor,
	)
	for _, file := range files {
		if err := forEachField(
			file.FileDescriptor(),
			func(fieldDescriptor protoreflect.FieldDescriptor) error {
				number := fieldDescriptor.Number()
				containingMessage := fieldDescriptor.ContainingMessage()
				if containingMessage == nil {
					return fmt.Errorf("containing message was nil for field %d", number)
				}
				fullName := containingMessage.FullName()
				numberToFieldDescriptor, ok := containingMessageFullNameToNumberToFieldDescriptorMap[fullName]
				if !ok {
					numberToFieldDescriptor = make(map[protoreflect.FieldNumber]protoreflect.FieldDescriptor)
					containingMessageFullNameToNumberToFieldDescriptorMap[fullName] = numberToFieldDescriptor
				}
				if _, ok := numberToFieldDescriptor[number]; ok {
					return fmt.Errorf("duplicate field on message %q: %d", fullName, number)
				}
				numberToFieldDescriptor[number] = fieldDescriptor
				return nil
			},
		); err != nil {
			return nil, err
		}
	}
	return containingMessageFullNameToNumberToFieldDescriptorMap, nil
}

// getFullNameToServiceDescriptor indexes every service in the given Files by
// fully-qualified name, erroring on duplicates.
func getFullNameToServiceDescriptor(files []check.File) (map[protoreflect.FullName]protoreflect.ServiceDescriptor, error) {
	fullNameToServiceDescriptorMap := make(map[protoreflect.FullName]protoreflect.ServiceDescriptor)
	for _, file := range files {
		if err := forEachService(
			file.FileDescriptor(),
			func(serviceDescriptor protoreflect.ServiceDescriptor) error {
				fullName := serviceDescriptor.FullName()
				if _, ok := fullNameToServiceDescriptorMap[fullName]; ok {
					return fmt.Errorf("duplicate service: %q", fullName)
				}
				fullNameToServiceDescriptorMap[fullName] = serviceDescriptor
				return nil
			},
		); err != nil {
			return nil, err
		}
	}
	return fullNameToServiceDescriptorMap, nil
}

// getNameToMethodDescriptor indexes the service's methods by (unqualified)
// name, erroring on duplicates.
func getNameToMethodDescriptor(serviceDescriptor protoreflect.ServiceDescriptor) (map[protoreflect.Name]protoreflect.MethodDescriptor, error) {
	nameToMethodDescriptorMap := make(map[protoreflect.Name]protoreflect.MethodDescriptor)
	if err := forEachMethod(
		serviceDescriptor,
		func(methodDescriptor protoreflect.MethodDescriptor) error {
			name := methodDescriptor.Name()
			if _, ok := nameToMethodDescriptorMap[name]; ok {
				return fmt.Errorf("duplicate method on service %q: %q", serviceDescriptor.FullName(), name)
			}
			nameToMethodDescriptorMap[name] = methodDescriptor
			return nil
		},
	); err != nil {
		return nil, err
	}
	return nameToMethodDescriptorMap, nil
}

// forEachFileImport calls f for each import statement of the file, stopping at
// the first error.
func forEachFileImport(
	fileDescriptor protoreflect.FileDescriptor,
	f func(protoreflect.FileImport) error,
) error {
	fileImports := fileDescriptor.Imports()
	for i := 0; i < fileImports.Len(); i++ {
		if err := f(fileImports.Get(i)); err != nil {
			return err
		}
	}
	return nil
}

// forEachEnum calls f for each enum in the container, recursing into messages
// so nested enums are included. Stops at the first error.
func forEachEnum(
	container container,
	f func(protoreflect.EnumDescriptor) error,
) error {
	enums := container.Enums()
	for i := 0; i < enums.Len(); i++ {
		if err := f(enums.Get(i)); err != nil {
			return err
		}
	}
	messages := container.Messages()
	for i := 0; i < messages.Len(); i++ {
		// Nested enums.
		if err := forEachEnum(messages.Get(i), f); err != nil {
			return err
		}
	}
	return nil
}

// forEachEnumValue calls f for each value of the enum, stopping at the first error.
func forEachEnumValue(
	enumDescriptor protoreflect.EnumDescriptor,
	f func(protoreflect.EnumValueDescriptor) error,
) error {
	enumValues := enumDescriptor.Values()
	for i := 0; i < enumValues.Len(); i++ {
		if err := f(enumValues.Get(i)); err != nil {
			return err
		}
	}
	return nil
}

// forEachMessage calls f for each message in the container, depth-first with
// each parent visited before its nested messages. Stops at the first error.
func forEachMessage(
	container container,
	f func(protoreflect.MessageDescriptor) error,
) error {
	messages := container.Messages()
	for i := 0; i < messages.Len(); i++ {
		messageDescriptor := messages.Get(i)
		if err := f(messageDescriptor); err != nil {
			return err
		}
		// Nested messages.
		if err := forEachMessage(messageDescriptor, f); err != nil {
			return err
		}
	}
	return nil
}

// forEachField calls f for each field in the container: the fields and
// extensions declared on every (possibly nested) message, plus the extensions
// declared directly on the container itself. Stops at the first error.
func forEachField(
	container container,
	f func(protoreflect.FieldDescriptor) error,
) error {
	if err := forEachMessage(
		container,
		func(messageDescriptor protoreflect.MessageDescriptor) error {
			fields := messageDescriptor.Fields()
			for i := 0; i < fields.Len(); i++ {
				if err := f(fields.Get(i)); err != nil {
					return err
				}
			}
			extensions := messageDescriptor.Extensions()
			for i := 0; i < extensions.Len(); i++ {
				if err := f(extensions.Get(i)); err != nil {
					return err
				}
			}
			return nil
		},
	); err != nil {
		return err
	}
	// Top-level extensions declared on the container (e.g. file-level extensions).
	extensions := container.Extensions()
	for i := 0; i < extensions.Len(); i++ {
		if err := f(extensions.Get(i)); err != nil {
			return err
		}
	}
	return nil
}

// forEachOneof calls f for each oneof of the message, stopping at the first error.
func forEachOneof(
	messageDescriptor protoreflect.MessageDescriptor,
	f func(protoreflect.OneofDescriptor) error,
) error {
	oneofs := messageDescriptor.Oneofs()
	for i := 0; i < oneofs.Len(); i++ {
		if err := f(oneofs.Get(i)); err != nil {
			return err
		}
	}
	return nil
}

// forEachService calls f for each service of the file, stopping at the first error.
func forEachService(
	fileDescriptor protoreflect.FileDescriptor,
	f func(protoreflect.ServiceDescriptor) error,
) error {
	services := fileDescriptor.Services()
	for i := 0; i < services.Len(); i++ {
		if err := f(services.Get(i)); err != nil {
			return err
		}
	}
	return nil
}

// forEachMethod calls f for each method of the service, stopping at the first error.
func forEachMethod(
	serviceDescriptor protoreflect.ServiceDescriptor,
	f func(protoreflect.MethodDescriptor) error,
) error {
	methods := serviceDescriptor.Methods()
	for i := 0; i < methods.Len(); i++ {
		if err := f(methods.Get(i)); err != nil {
			return err
		}
	}
	return nil
}

// filterFiles returns files unchanged when withoutImports is false; otherwise
// only the files for which IsImport() is false.
func filterFiles(files []check.File, withoutImports bool) []check.File {
	if !withoutImports {
		return files
	}
	return xslices.Filter(files, func(file check.File) bool { return !file.IsImport() })
}
// Copyright 2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package check

import (
	"context"
	"fmt"

	checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1"
	"github.com/bufbuild/bufplugin-go/internal/gen/buf/plugin/check/v1/v1pluginrpc"
	"github.com/bufbuild/bufplugin-go/internal/pkg/cache"
	"github.com/bufbuild/bufplugin-go/internal/pkg/xslices"
	"pluginrpc.com/pluginrpc"
)

const (
	// Page sizes used when paging through ListRules/ListCategories responses.
	listRulesPageSize      = 250
	listCategoriesPageSize = 250
)

// Client is a client for a custom lint or breaking change plugin.
type Client interface {
	// Check invokes a check using the plugin.
	Check(ctx context.Context, request Request, options ...CheckCallOption) (Response, error)
	// ListRules lists all available Rules from the plugin.
	//
	// The Rules will be sorted by Rule ID.
	// Returns error if duplicate Rule IDs were detected from the underlying source.
	ListRules(ctx context.Context, options ...ListRulesCallOption) ([]Rule, error)
	// ListCategories lists all available Categories from the plugin.
	//
	// The Categories will be sorted by Category ID.
	// Returns error if duplicate Category IDs were detected from the underlying source.
	ListCategories(ctx context.Context, options ...ListCategoriesCallOption) ([]Category, error)

	isClient()
}

// NewClient returns a new Client for the given pluginrpc.Client.
func NewClient(pluginrpcClient pluginrpc.Client, options ...ClientOption) Client {
	clientOptions := newClientOptions()
	for _, option := range options {
		option.applyToClient(clientOptions)
	}
	return newClient(pluginrpcClient, clientOptions.cacheRulesAndCategories)
}

// ClientOption is an option for a new Client.
type ClientOption interface {
	ClientForSpecOption

	applyToClient(opts *clientOptions)
}

// ClientWithCacheRulesAndCategories returns a new ClientOption that will result in the Rules from
// ListRules and the Categories from ListCategories being cached.
//
// The default is to not cache Rules or Categories.
func ClientWithCacheRulesAndCategories() ClientOption {
	return clientWithCacheRulesAndCategoriesOption{}
}

// NewClientForSpec returns a new Client that directly uses the given Spec.
//
// This should primarily be used for testing.
func NewClientForSpec(spec *Spec, options ...ClientForSpecOption) (Client, error) {
	clientForSpecOptions := newClientForSpecOptions()
	for _, option := range options {
		option.applyToClientForSpec(clientForSpecOptions)
	}
	// Wrap the Spec in an in-process handler/server pair so the regular
	// pluginrpc client machinery can be reused without an external process.
	checkServiceHandler, err := NewCheckServiceHandler(spec)
	if err != nil {
		return nil, err
	}
	checkServiceServer, err := NewCheckServiceServer(checkServiceHandler)
	if err != nil {
		return nil, err
	}
	return newClient(
		pluginrpc.NewClient(
			pluginrpc.NewServerRunner(checkServiceServer),
		),
		clientForSpecOptions.cacheRulesAndCategories,
	), nil
}

// ClientForSpecOption is an option for a new Client constructed with NewClientForSpec.
type ClientForSpecOption interface {
	applyToClientForSpec(opts *clientForSpecOptions)
}

// CheckCallOption is an option for a Client.Check call.
type CheckCallOption func(*checkCallOptions)

// ListRulesCallOption is an option for a Client.ListRules call.
type ListRulesCallOption func(*listRulesCallOptions)

// ListCategoriesCallOption is an option for a Client.ListCategories call.
type ListCategoriesCallOption func(*listCategoriesCallOptions)

// *** PRIVATE ***

// client implements Client on top of a pluginrpc.Client.
type client struct {
	pluginrpcClient pluginrpc.Client

	// NOTE(review): cacheRulesAndCategories is stored but not consulted in the
	// code visible here — the rules/categories Singletons appear to cache
	// unconditionally. Confirm the intended behavior against cache.Singleton.
	cacheRulesAndCategories bool

	// Singleton ordering: rules -> categories -> checkServiceClient
	rules              *cache.Singleton[[]Rule]
	categories         *cache.Singleton[[]Category]
	checkServiceClient *cache.Singleton[v1pluginrpc.CheckServiceClient]
}

// newClient constructs a client, wiring each lazily-initialized Singleton to
// its corresponding uncached loader method.
func newClient(
	pluginrpcClient pluginrpc.Client,
	cacheRulesAndCategories bool,
) *client {
	client := &client{
		pluginrpcClient:         pluginrpcClient,
		cacheRulesAndCategories: cacheRulesAndCategories,
	}
	client.rules = cache.NewSingleton(client.listRulesUncached)
	client.categories = cache.NewSingleton(client.listCategoriesUncached)
	client.checkServiceClient = cache.NewSingleton(client.getCheckServiceClientUncached)
	return client
}

// Check implements Client. It splits the Request into protos, invokes the
// plugin's Check RPC for each, and folds all returned annotations into a
// single Response.
func (c *client) Check(ctx context.Context, request Request, _ ...CheckCallOption) (Response, error) {
	checkServiceClient, err := c.checkServiceClient.Get(ctx)
	if err != nil {
		return nil, err
	}
	multiResponseWriter, err := newMultiResponseWriter(request)
	if err != nil {
		return nil, err
	}
	protoRequests, err := request.toProtos()
	if err != nil {
		return nil, err
	}
	for _, protoRequest := range protoRequests {
		protoResponse, err := checkServiceClient.Check(ctx, protoRequest)
		if err != nil {
			return nil, err
		}
		for _, protoAnnotation := range protoResponse.GetAnnotations() {
			multiResponseWriter.addAnnotation(
				protoAnnotation.GetRuleId(),
				WithMessage(protoAnnotation.GetMessage()),
				WithFileNameAndSourcePath(
					protoAnnotation.GetLocation().GetFileName(),
					protoAnnotation.GetLocation().GetSourcePath(),
				),
				WithAgainstFileNameAndSourcePath(
					protoAnnotation.GetAgainstLocation().GetFileName(),
					protoAnnotation.GetAgainstLocation().GetSourcePath(),
				),
			)
		}
	}
	return multiResponseWriter.toResponse()
}

// ListRules implements Client via the rules Singleton.
func (c *client) ListRules(ctx context.Context, _ ...ListRulesCallOption) ([]Rule, error) {
	return c.rules.Get(ctx)
}

// ListCategories implements Client via the categories Singleton.
func (c *client) ListCategories(ctx context.Context, _ ...ListCategoriesCallOption) ([]Category, error) {
	return c.categories.Get(ctx)
}

// listRulesUncached pages through the plugin's ListRules RPC, resolves each
// Rule's Categories, validates, and sorts the result by Rule ID.
func (c *client) listRulesUncached(ctx context.Context) ([]Rule, error) {
	checkServiceClient, err := c.checkServiceClient.Get(ctx)
	if err != nil {
		return nil, err
	}
	var protoRules []*checkv1.Rule
	var pageToken string
	for {
		response, err := checkServiceClient.ListRules(
			ctx,
			&checkv1.ListRulesRequest{
				PageSize:  listRulesPageSize,
				PageToken: pageToken,
			},
		)
		if err != nil {
			return nil, err
		}
		protoRules = append(protoRules, response.GetRules()...)
		pageToken = response.GetNextPageToken()
		if pageToken == "" {
			break
		}
	}

	// We acquire rules before categories.
	categories, err := c.ListCategories(ctx)
	if err != nil {
		return nil, err
	}
	categoryIDToCategory := make(map[string]Category)
	for _, category := range categories {
		// We know there are no duplicate IDs from validation.
		categoryIDToCategory[category.ID()] = category
	}
	rules, err := xslices.MapError(
		protoRules,
		func(protoRule *checkv1.Rule) (Rule, error) {
			return ruleForProtoRule(protoRule, categoryIDToCategory)
		},
	)
	if err != nil {
		return nil, err
	}
	if err := validateRules(rules); err != nil {
		return nil, err
	}
	sortRules(rules)
	return rules, nil
}

// listCategoriesUncached pages through the plugin's ListCategories RPC,
// validates, and sorts the result by Category ID.
func (c *client) listCategoriesUncached(ctx context.Context) ([]Category, error) {
	checkServiceClient, err := c.checkServiceClient.Get(ctx)
	if err != nil {
		return nil, err
	}
	var protoCategories []*checkv1.Category
	var pageToken string
	for {
		response, err := checkServiceClient.ListCategories(
			ctx,
			&checkv1.ListCategoriesRequest{
				PageSize:  listCategoriesPageSize,
				PageToken: pageToken,
			},
		)
		if err != nil {
			return nil, err
		}
		protoCategories = append(protoCategories, response.GetCategories()...)
		pageToken = response.GetNextPageToken()
		if pageToken == "" {
			break
		}
	}
	categories, err := xslices.MapError(protoCategories, categoryForProtoCategory)
	if err != nil {
		return nil, err
	}
	if err := validateCategories(categories); err != nil {
		return nil, err
	}
	sortCategories(categories)
	return categories, nil
}

// getCheckServiceClientUncached verifies that the plugin exposes every
// required CheckService procedure before constructing the typed client.
func (c *client) getCheckServiceClientUncached(ctx context.Context) (v1pluginrpc.CheckServiceClient, error) {
	spec, err := c.pluginrpcClient.Spec(ctx)
	if err != nil {
		return nil, err
	}
	// All of these procedures are required for a plugin to be considered a buf plugin.
	for _, procedurePath := range []string{
		v1pluginrpc.CheckServiceCheckPath,
		v1pluginrpc.CheckServiceListRulesPath,
		v1pluginrpc.CheckServiceListCategoriesPath,
	} {
		if spec.ProcedureForPath(procedurePath) == nil {
			return nil, fmt.Errorf("plugin spec not implemented: RPC %q not found", procedurePath)
		}
	}
	return v1pluginrpc.NewCheckServiceClient(c.pluginrpcClient)
}

func (*client) isClient() {}

// clientOptions carries options for NewClient.
type clientOptions struct {
	cacheRulesAndCategories bool
}

func newClientOptions() *clientOptions {
	return &clientOptions{}
}

// clientForSpecOptions carries options for NewClientForSpec.
type clientForSpecOptions struct {
	cacheRulesAndCategories bool
}

func newClientForSpecOptions() *clientForSpecOptions {
	return &clientForSpecOptions{}
}

// clientWithCacheRulesAndCategoriesOption enables caching for both option sets.
type clientWithCacheRulesAndCategoriesOption struct{}

func (clientWithCacheRulesAndCategoriesOption) applyToClient(clientOptions *clientOptions) {
	clientOptions.cacheRulesAndCategories = true
}

func (clientWithCacheRulesAndCategoriesOption) applyToClientForSpec(clientForSpecOptions *clientForSpecOptions) {
	clientForSpecOptions.cacheRulesAndCategories = true
}

// Placeholder option structs: no per-call options exist yet.
type checkCallOptions struct{}

type listRulesCallOptions struct{}

type listCategoriesCallOptions struct{}
// Copyright 2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package check

import (
	"context"
	"fmt"
	"slices"
	"testing"

	"github.com/bufbuild/bufplugin-go/internal/pkg/xslices"
	"github.com/stretchr/testify/require"
)

// TestClientListRulesCategoriesSimple exercises ListRules/ListCategories both
// with and without the rules-and-categories cache enabled.
func TestClientListRulesCategoriesSimple(t *testing.T) {
	t.Parallel()

	testClientListRulesCategoriesSimple(t)
	testClientListRulesCategoriesSimple(t, ClientWithCacheRulesAndCategories())
}

// testClientListRulesCategoriesSimple verifies that rules and categories come
// back sorted by ID and that each Rule resolves its Categories correctly.
func testClientListRulesCategoriesSimple(t *testing.T, options ...ClientForSpecOption) {
	ctx := context.Background()
	client, err := NewClientForSpec(
		&Spec{
			Rules: []*RuleSpec{
				{
					ID:      "RULE1",
					Purpose: "Test RULE1.",
					Type:    RuleTypeLint,
					Handler: nopRuleHandler,
				},
				{
					ID: "RULE2",
					CategoryIDs: []string{
						"CATEGORY1",
					},
					Purpose: "Test RULE2.",
					Type:    RuleTypeLint,
					Handler: nopRuleHandler,
				},
				{
					ID: "RULE3",
					CategoryIDs: []string{
						"CATEGORY1",
						"CATEGORY2",
					},
					Purpose: "Test RULE3.",
					Type:    RuleTypeLint,
					Handler: nopRuleHandler,
				},
			},
			Categories: []*CategorySpec{
				{
					ID:      "CATEGORY1",
					Purpose: "Test CATEGORY1.",
				},
				{
					ID:      "CATEGORY2",
					Purpose: "Test CATEGORY2.",
				},
			},
		},
		options...,
	)
	require.NoError(t, err)
	rules, err := client.ListRules(ctx)
	require.NoError(t, err)
	require.Equal(
		t,
		[]string{
			"RULE1",
			"RULE2",
			"RULE3",
		},
		xslices.Map(rules, Rule.ID),
	)
	categories, err := client.ListCategories(ctx)
	require.NoError(t, err)
	require.Equal(
		t,
		[]string{
			"CATEGORY1",
			"CATEGORY2",
		},
		xslices.Map(categories, Category.ID),
	)
	// Per-rule category resolution.
	categories = rules[0].Categories()
	require.Empty(t, categories)
	categories = rules[1].Categories()
	require.Equal(
		t,
		[]string{
			"CATEGORY1",
		},
		xslices.Map(categories, Category.ID),
	)
	categories = rules[2].Categories()
	require.Equal(
		t,
		[]string{
			"CATEGORY1",
			"CATEGORY2",
		},
		xslices.Map(categories, Category.ID),
	)
}

// TestClientListRulesCount verifies paging behavior around the page-size
// boundaries (one less, exact, one more, multiples).
func TestClientListRulesCount(t *testing.T) {
	t.Parallel()

	testClientListRulesCount(t, listRulesPageSize-1)
	testClientListRulesCount(t, listRulesPageSize)
	testClientListRulesCount(t, listRulesPageSize+1)
	testClientListRulesCount(t, listRulesPageSize*2)
	testClientListRulesCount(t, (listRulesPageSize*2)+1)
	testClientListRulesCount(t, (listRulesPageSize*4)+1)
}

// testClientListRulesCount creates count rules in reverse order and verifies
// that ListRules returns all of them sorted by ID.
func testClientListRulesCount(t *testing.T, count int) {
	// IDs use a %05d suffix, so lexicographic order only matches numeric order
	// below 10000.
	require.True(t, count < 10000, "count must be less than 10000 for sorting to work properly in this test")
	ruleSpecs := make([]*RuleSpec, count)
	for i := 0; i < count; i++ {
		ruleSpecs[i] = &RuleSpec{
			ID:      fmt.Sprintf("RULE%05d", i),
			Purpose: fmt.Sprintf("Test RULE%05d.", i),
			Type:    RuleTypeLint,
			Handler: nopRuleHandler,
		}
	}
	// Make the ruleSpecs not in sorted order.
	ruleSpecsOutOfOrder := slices.Clone(ruleSpecs)
	slices.Reverse(ruleSpecsOutOfOrder)
	client, err := NewClientForSpec(&Spec{Rules: ruleSpecsOutOfOrder})
	require.NoError(t, err)
	rules, err := client.ListRules(context.Background())
	require.NoError(t, err)
	require.Equal(t, count, len(rules))
	for i := 0; i < count; i++ {
		require.Equal(t, ruleSpecs[i].ID, rules[i].ID())
	}
}
// Copyright 2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package check

import (
	"slices"
	"strings"
)

// CompareAnnotations returns -1 if one < two, 1 if one > two, 0 otherwise.
//
// nil sorts before non-nil; otherwise comparison is by RuleID, then Location,
// then AgainstLocation, then Message.
func CompareAnnotations(one Annotation, two Annotation) int {
	if one == nil && two == nil {
		return 0
	}
	if one == nil && two != nil {
		return -1
	}
	if one != nil && two == nil {
		return 1
	}
	if compare := strings.Compare(one.RuleID(), two.RuleID()); compare != 0 {
		return compare
	}

	if compare := CompareLocations(one.Location(), two.Location()); compare != 0 {
		return compare
	}

	if compare := CompareLocations(one.AgainstLocation(), two.AgainstLocation()); compare != 0 {
		return compare
	}
	return strings.Compare(one.Message(), two.Message())
}

// CompareLocations returns -1 if one < two, 1 if one > two, 0 otherwise.
//
// nil sorts before non-nil; otherwise comparison is by file path, then
// start/end line and column, then source path, then comments.
func CompareLocations(one Location, two Location) int {
	if one == nil && two == nil {
		return 0
	}
	if one == nil && two != nil {
		return -1
	}
	if one != nil && two == nil {
		return 1
	}
	if compare := strings.Compare(one.File().FileDescriptor().Path(), two.File().FileDescriptor().Path()); compare != 0 {
		return compare
	}

	if compare := intCompare(one.StartLine(), two.StartLine()); compare != 0 {
		return compare
	}

	if compare := intCompare(one.StartColumn(), two.StartColumn()); compare != 0 {
		return compare
	}

	if compare := intCompare(one.EndLine(), two.EndLine()); compare != 0 {
		return compare
	}

	if compare := intCompare(one.EndColumn(), two.EndColumn()); compare != 0 {
		return compare
	}

	// The uncloned accessors avoid copying the underlying slices just for
	// comparison.
	if compare := slices.Compare(one.unclonedSourcePath(), two.unclonedSourcePath()); compare != 0 {
		return compare
	}

	if compare := strings.Compare(one.LeadingComments(), two.LeadingComments()); compare != 0 {
		return compare
	}

	if compare := strings.Compare(one.TrailingComments(), two.TrailingComments()); compare != 0 {
		return compare
	}
	return slices.Compare(one.unclonedLeadingDetachedComments(), two.unclonedLeadingDetachedComments())
}

// CompareRules returns -1 if one < two, 1 if one > two, 0 otherwise.
//
// nil sorts before non-nil; otherwise comparison is by ID.
func CompareRules(one Rule, two Rule) int {
	if one == nil && two == nil {
		return 0
	}
	if one == nil && two != nil {
		return -1
	}
	if one != nil && two == nil {
		return 1
	}
	return strings.Compare(one.ID(), two.ID())
}

// CompareCategories returns -1 if one < two, 1 if one > two, 0 otherwise.
//
// nil sorts before non-nil; otherwise comparison is by ID.
func CompareCategories(one Category, two Category) int {
	if one == nil && two == nil {
		return 0
	}
	if one == nil && two != nil {
		return -1
	}
	if one != nil && two == nil {
		return 1
	}
	return strings.Compare(one.ID(), two.ID())
}

// *** PRIVATE ***

// compareRuleSpecs returns -1 if one < two, 1 if one > two, 0 otherwise.
// nil sorts before non-nil; otherwise comparison is by ID.
func compareRuleSpecs(one *RuleSpec, two *RuleSpec) int {
	if one == nil && two == nil {
		return 0
	}
	if one == nil && two != nil {
		return -1
	}
	if one != nil && two == nil {
		return 1
	}
	return strings.Compare(one.ID, two.ID)
}

// compareCategorySpecs returns -1 if one < two, 1 if one > two, 0 otherwise.
// nil sorts before non-nil; otherwise comparison is by ID.
func compareCategorySpecs(one *CategorySpec, two *CategorySpec) int {
	if one == nil && two == nil {
		return 0
	}
	if one == nil && two != nil {
		return -1
	}
	if one != nil && two == nil {
		return 1
	}
	return strings.Compare(one.ID, two.ID)
}

// intCompare is a three-way comparison for ints.
func intCompare(one int, two int) int {
	if one < two {
		return -1
	}
	if one > two {
		return 1
	}
	return 0
}
// Copyright 2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package check

import (
	"errors"
	"fmt"
	"strings"
)

// duplicateRuleIDError is returned when multiple Rules share an ID.
type duplicateRuleIDError struct {
	duplicateIDs []string
}

func newDuplicateRuleIDError(duplicateIDs []string) *duplicateRuleIDError {
	return &duplicateRuleIDError{
		duplicateIDs: duplicateIDs,
	}
}

// Error implements error. Returns "" for a nil receiver or an empty ID list.
func (r *duplicateRuleIDError) Error() string {
	if r == nil {
		return ""
	}
	if len(r.duplicateIDs) == 0 {
		return ""
	}
	var sb strings.Builder
	_, _ = sb.WriteString("duplicate rule IDs: ")
	_, _ = sb.WriteString(strings.Join(r.duplicateIDs, ", "))
	return sb.String()
}

// duplicateCategoryIDError is returned when multiple Categories share an ID.
type duplicateCategoryIDError struct {
	duplicateIDs []string
}

func newDuplicateCategoryIDError(duplicateIDs []string) *duplicateCategoryIDError {
	return &duplicateCategoryIDError{
		duplicateIDs: duplicateIDs,
	}
}

// Error implements error. Returns "" for a nil receiver or an empty ID list.
func (c *duplicateCategoryIDError) Error() string {
	if c == nil {
		return ""
	}
	if len(c.duplicateIDs) == 0 {
		return ""
	}
	var sb strings.Builder
	_, _ = sb.WriteString("duplicate category IDs: ")
	_, _ = sb.WriteString(strings.Join(c.duplicateIDs, ", "))
	return sb.String()
}

// duplicateRuleOrCategoryIDError is returned when an ID is used by both a Rule
// and a Category (or otherwise collides across the two namespaces).
type duplicateRuleOrCategoryIDError struct {
	duplicateIDs []string
}

func newDuplicateRuleOrCategoryIDError(duplicateIDs []string) *duplicateRuleOrCategoryIDError {
	return &duplicateRuleOrCategoryIDError{
		duplicateIDs: duplicateIDs,
	}
}

// Error implements error. Returns "" for a nil receiver or an empty ID list.
func (o *duplicateRuleOrCategoryIDError) Error() string {
	if o == nil {
		return ""
	}
	if len(o.duplicateIDs) == 0 {
		return ""
	}
	var sb strings.Builder
	_, _ = sb.WriteString("duplicate rule or category IDs: ")
	_, _ = sb.WriteString(strings.Join(o.duplicateIDs, ", "))
	return sb.String()
}

// unexpectedOptionValueTypeError is returned when an option value has a
// different dynamic type than the caller expected.
type unexpectedOptionValueTypeError struct {
	key      string
	expected any
	actual   any
}

func newUnexpectedOptionValueTypeError(key string, expected any, actual any) *unexpectedOptionValueTypeError {
	return &unexpectedOptionValueTypeError{
		key:      key,
		expected: expected,
		actual:   actual,
	}
}

// Error implements error. Returns "" for a nil receiver.
func (u *unexpectedOptionValueTypeError) Error() string {
	if u == nil {
		return ""
	}
	var sb strings.Builder
	_, _ = sb.WriteString(`unexpected type for option value "`)
	_, _ = sb.WriteString(u.key)
	_, _ = sb.WriteString(fmt.Sprintf(`": expected %T, got %T`, u.expected, u.actual))
	return sb.String()
}

// validateRuleSpecError wraps an underlying error with RuleSpec context.
type validateRuleSpecError struct {
	delegate error
}

func newValidateRuleSpecErrorf(format string, args ...any) *validateRuleSpecError {
	return &validateRuleSpecError{
		delegate: fmt.Errorf(format, args...),
	}
}

func wrapValidateRuleSpecError(delegate error) *validateRuleSpecError {
	return &validateRuleSpecError{
		delegate: delegate,
	}
}

// Error implements error. Returns "" for a nil receiver or nil delegate.
func (vr *validateRuleSpecError) Error() string {
	if vr == nil {
		return ""
	}
	if vr.delegate == nil {
		return ""
	}
	var sb strings.Builder
	_, _ = sb.WriteString(`invalid check.RuleSpec: `)
	_, _ = sb.WriteString(vr.delegate.Error())
	return sb.String()
}

// Unwrap exposes the delegate for errors.Is/As.
func (vr *validateRuleSpecError) Unwrap() error {
	if vr == nil {
		return nil
	}
	return vr.delegate
}

// validateCategorySpecError wraps an underlying error with CategorySpec context.
type validateCategorySpecError struct {
	delegate error
}

func newValidateCategorySpecErrorf(format string, args ...any) *validateCategorySpecError {
	return &validateCategorySpecError{
		delegate: fmt.Errorf(format, args...),
	}
}

func wrapValidateCategorySpecError(delegate error) *validateCategorySpecError {
	return &validateCategorySpecError{
		delegate: delegate,
	}
}

// Error implements error. Returns "" for a nil receiver or nil delegate.
func (vr *validateCategorySpecError) Error() string {
	if vr == nil {
		return ""
	}
	if vr.delegate == nil {
		return ""
	}
	var sb strings.Builder
	_, _ = sb.WriteString(`invalid check.CategorySpec: `)
	_, _ = sb.WriteString(vr.delegate.Error())
	return sb.String()
}

// Unwrap exposes the delegate for errors.Is/As.
func (vr *validateCategorySpecError) Unwrap() error {
	if vr == nil {
		return nil
	}
	return vr.delegate
}

// validateSpecError wraps an underlying error with Spec context.
type validateSpecError struct {
	delegate error
}

func newValidateSpecError(message string) *validateSpecError {
	return &validateSpecError{
		delegate: errors.New(message),
	}
}

func wrapValidateSpecError(delegate error) *validateSpecError {
	return &validateSpecError{
		delegate: delegate,
	}
}

// Error implements error. Returns "" for a nil receiver or nil delegate.
func (vr *validateSpecError) Error() string {
	if vr == nil {
		return ""
	}
	if vr.delegate == nil {
		return ""
	}
	var sb strings.Builder
	_, _ = sb.WriteString(`invalid check.Spec: `)
	_, _ = sb.WriteString(vr.delegate.Error())
	return sb.String()
}

// Unwrap exposes the delegate for errors.Is/As.
func (vr *validateSpecError) Unwrap() error {
	if vr == nil {
		return nil
	}
	return vr.delegate
}

// Copyright 2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package check

import (
	"fmt"
	"slices"

	checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1"
	"google.golang.org/protobuf/reflect/protodesc"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/descriptorpb"
)

// File is an individual file that should be checked.
//
// Both the protoreflect FileDescriptor and the raw FileDescriptorProto interfaces
// are provided.
//
// Files also have the property of being imports or non-imports.
type File interface {
	// FileDescriptor returns the protoreflect FileDescriptor representing this File.
	//
	// This will always contain SourceCodeInfo.
	FileDescriptor() protoreflect.FileDescriptor
	// FileDescriptorProto returns the FileDescriptorProto representing this File.
	//
	// This is not a copy - do not modify!
	FileDescriptorProto() *descriptorpb.FileDescriptorProto
	// IsImport returns true if the File is an import.
	//
	// An import is a file that is either:
	//
	//   - A Well-Known Type included from the compiler and imported by a targeted file.
	//   - A file that was included from a Buf module dependency and imported by a targeted file.
	//   - A file that was not targeted, but was imported by a targeted file.
	//
	// We use "import" as this matches with the protoc concept of --include_imports, however
	// import is a bit of an overloaded term.
	IsImport() bool

	// IsSyntaxUnspecified denotes whether the file did not have a syntax explicitly specified.
	//
	// Per the FileDescriptorProto spec, it would be fine in this case to just leave the syntax field
	// unset to denote this and to set the syntax field to "proto2" if it is specified. However,
	// protoc does not set the syntax field if it was "proto2". Plugins may want to differentiate
	// between "proto2" and unset, and this field allows them to.
	IsSyntaxUnspecified() bool

	// UnusedDependencyIndexes are the indexes within the Dependency field on FileDescriptorProto for
	// those dependencies that are not used.
	//
	// This matches the shape of the PublicDependency and WeakDependency fields.
	UnusedDependencyIndexes() []int32

	toProto() *checkv1.File

	isFile()
}

// FilesForProtoFiles returns a new slice of Files for the given checkv1.Files.
func FilesForProtoFiles(protoFiles []*checkv1.File) ([]File, error) {
	if len(protoFiles) == 0 {
		return nil, nil
	}
	fileNameToProtoFile := make(map[string]*checkv1.File, len(protoFiles))
	fileDescriptorProtos := make([]*descriptorpb.FileDescriptorProto, len(protoFiles))
	for i, protoFile := range protoFiles {
		fileDescriptorProto := protoFile.GetFileDescriptorProto()
		fileName := fileDescriptorProto.GetName()
		if _, ok := fileNameToProtoFile[fileName]; ok {
			// This should have been validated via protovalidate.
			return nil, fmt.Errorf("duplicate file name: %q", fileName)
		}
		fileDescriptorProtos[i] = fileDescriptorProto
		fileNameToProtoFile[fileName] = protoFile
	}

	protoregistryFiles, err := protodesc.NewFiles(
		&descriptorpb.FileDescriptorSet{
			File: fileDescriptorProtos,
		},
	)
	if err != nil {
		return nil, err
	}

	files := make([]File, 0, len(protoFiles))
	// err is captured by the closure below and checked after RangeFiles returns.
	protoregistryFiles.RangeFiles(
		func(fileDescriptor protoreflect.FileDescriptor) bool {
			protoFile, ok := fileNameToProtoFile[fileDescriptor.Path()]
			if !ok {
				// If the protoreflect API is sane, this should never happen.
				// However, the protoreflect API is not sane.
				err = fmt.Errorf("unknown file: %q", fileDescriptor.Path())
				return false
			}
			files = append(
				files,
				newFile(
					fileDescriptor,
					protoFile.GetFileDescriptorProto(),
					protoFile.GetIsImport(),
					protoFile.GetIsSyntaxUnspecified(),
					protoFile.GetUnusedDependency(),
				),
			)
			return true
		},
	)
	if err != nil {
		return nil, err
	}
	if len(files) != len(protoFiles) {
		// If the protoreflect API is sane, this should never happen.
		// However, the protoreflect API is not sane.
		return nil, fmt.Errorf("expected %d files from protoregistry, got %d", len(protoFiles), len(files))
	}
	return files, nil
}

// *** PRIVATE ***

// file is the sole implementation of File.
type file struct {
	fileDescriptor          protoreflect.FileDescriptor
	fileDescriptorProto     *descriptorpb.FileDescriptorProto
	isImport                bool
	isSyntaxUnspecified     bool
	unusedDependencyIndexes []int32
}

func newFile(
	fileDescriptor protoreflect.FileDescriptor,
	fileDescriptorProto *descriptorpb.FileDescriptorProto,
	isImport bool,
	isSyntaxUnspecified bool,
	unusedDependencyIndexes []int32,
) *file {
	return &file{
		fileDescriptor:          fileDescriptor,
		fileDescriptorProto:     fileDescriptorProto,
		isImport:                isImport,
		isSyntaxUnspecified:     isSyntaxUnspecified,
		unusedDependencyIndexes: unusedDependencyIndexes,
	}
}

func (f *file) FileDescriptor() protoreflect.FileDescriptor {
	return f.fileDescriptor
}

func (f *file) FileDescriptorProto() *descriptorpb.FileDescriptorProto {
	return f.fileDescriptorProto
}

func (f *file) IsImport() bool {
	return f.isImport
}

func (f *file) IsSyntaxUnspecified() bool {
	return f.isSyntaxUnspecified
}

// UnusedDependencyIndexes returns a copy so callers cannot mutate internal state.
func (f *file) UnusedDependencyIndexes() []int32 {
	return slices.Clone(f.unusedDependencyIndexes)
}

func (f *file) toProto() *checkv1.File {
	return &checkv1.File{
		FileDescriptorProto: f.fileDescriptorProto,
		IsImport:            f.isImport,
		IsSyntaxUnspecified: f.isSyntaxUnspecified,
		UnusedDependency:    f.unusedDependencyIndexes,
	}
}

func (*file) isFile() {}

// validateFiles verifies that no two Files share a file name.
func validateFiles(files []File) error {
	_, err := fileNameToFileForFiles(files)
	return err
}

// fileNameToFileForFiles maps file path to File, erroring on duplicates.
func fileNameToFileForFiles(files []File) (map[string]File, error) {
	fileNameToFile := make(map[string]File, len(files))
	for _, file := range files {
		fileName := file.FileDescriptor().Path()
		if _, ok := fileNameToFile[fileName]; ok {
			return nil, fmt.Errorf("duplicate file name: %q", fileName)
		}
		fileNameToFile[fileName] = file
	}
	return fileNameToFile, nil
}
a/check/internal/example/buf.gen.yaml b/check/internal/example/buf.gen.yaml new file mode 100644 index 0000000..9d9e02d --- /dev/null +++ b/check/internal/example/buf.gen.yaml @@ -0,0 +1,13 @@ +version: v2 +inputs: + - directory: proto +managed: + enabled: true + override: + - file_option: go_package_prefix + value: github.com/bufbuild/bufplugin-go/check/internal/example/gen +plugins: + - remote: buf.build/protocolbuffers/go + out: gen + opt: paths=source_relative +clean: true diff --git a/check/internal/example/cmd/buf-plugin-field-lower-snake-case/main.go b/check/internal/example/cmd/buf-plugin-field-lower-snake-case/main.go new file mode 100644 index 0000000..462a65d --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-lower-snake-case/main.go @@ -0,0 +1,121 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main implements a simple plugin that checks that all field names are lower_snake_case. +// +// To use this plugin: +// +// # buf.yaml +// version: v2 +// lint: +// use: +// - STANDARD # omit if you do not want to use the rules builtin to buf +// - PLUGIN_FIELD_LOWER_SNAKE_CASE +// plugins: +// - plugin: buf-plugin-field-lower-snake-case +// +// Note that the buf CLI implements this check as a builtin Rule, but this is just for example. 
+package main + +import ( + "context" + "strings" + "unicode" + + "github.com/bufbuild/bufplugin-go/check" + "github.com/bufbuild/bufplugin-go/check/checkutil" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// fieldLowerSnakeCaseRuleID is the Rule ID of the timestamp suffix Rule. +// +// This has a "PLUGIN_" prefix as the buf CLI has a rule "FIELD_LOWER_SNAKE_CASE" builtin, +// and plugins/the buf CLI must have unique Rule IDs. +const fieldLowerSnakeCaseRuleID = "PLUGIN_FIELD_LOWER_SNAKE_CASE" + +var ( + // fieldLowerSnakeCaseRuleSpec is the RuleSpec for the timestamp suffix Rule. + fieldLowerSnakeCaseRuleSpec = &check.RuleSpec{ + ID: fieldLowerSnakeCaseRuleID, + Default: true, + Purpose: "Checks that all field names are lower_snake_case.", + Type: check.RuleTypeLint, + Handler: checkutil.NewFieldRuleHandler(checkFieldLowerSnakeCase, checkutil.WithoutImports()), + } + + // spec is the Spec for the timestamp suffix plugin. + spec = &check.Spec{ + Rules: []*check.RuleSpec{ + fieldLowerSnakeCaseRuleSpec, + }, + } +) + +func main() { + check.Main(spec) +} + +func checkFieldLowerSnakeCase( + _ context.Context, + responseWriter check.ResponseWriter, + _ check.Request, + fieldDescriptor protoreflect.FieldDescriptor, +) error { + fieldName := string(fieldDescriptor.Name()) + fieldNameToLowerSnakeCase := toLowerSnakeCase(fieldName) + if fieldName != fieldNameToLowerSnakeCase { + responseWriter.AddAnnotation( + check.WithMessagef("Field name %q should be lower_snake_case, such as %q.", fieldName, fieldNameToLowerSnakeCase), + check.WithDescriptor(fieldDescriptor), + ) + } + return nil +} + +func toLowerSnakeCase(s string) string { + return strings.ToLower(toSnakeCase(s)) +} + +func toSnakeCase(s string) string { + output := "" + s = strings.TrimFunc(s, isDelimiter) + for i, c := range s { + if isDelimiter(c) { + c = '_' + } + switch { + case i == 0: + output += string(c) + case isSnakeCaseNewWord(c, false) && + output[len(output)-1] != '_' && + ((i < len(s)-1 && 
!isSnakeCaseNewWord(rune(s[i+1]), true) && !isDelimiter(rune(s[i+1]))) || + (unicode.IsLower(rune(s[i-1])))): + output += "_" + string(c) + case !(isDelimiter(c) && output[len(output)-1] == '_'): + output += string(c) + } + } + return output +} + +func isSnakeCaseNewWord(r rune, newWordOnDigits bool) bool { + if newWordOnDigits { + return unicode.IsUpper(r) || unicode.IsDigit(r) + } + return unicode.IsUpper(r) +} + +func isDelimiter(r rune) bool { + return r == '.' || r == '-' || r == '_' || r == ' ' || r == '\t' || r == '\n' || r == '\r' +} diff --git a/check/internal/example/cmd/buf-plugin-field-lower-snake-case/main_test.go b/check/internal/example/cmd/buf-plugin-field-lower-snake-case/main_test.go new file mode 100644 index 0000000..3c0409f --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-lower-snake-case/main_test.go @@ -0,0 +1,52 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "testing" + + "github.com/bufbuild/bufplugin-go/check/checktest" +) + +func TestSpec(t *testing.T) { + t.Parallel() + checktest.SpecTest(t, spec) +} + +func TestSimple(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{"testdata/simple"}, + FilePaths: []string{"simple.proto"}, + }, + }, + Spec: spec, + ExpectedAnnotations: []checktest.ExpectedAnnotation{ + { + RuleID: fieldLowerSnakeCaseRuleID, + Location: &checktest.ExpectedLocation{ + FileName: "simple.proto", + StartLine: 6, + StartColumn: 2, + EndLine: 6, + EndColumn: 23, + }, + }, + }, + }.Run(t) +} diff --git a/check/internal/example/cmd/buf-plugin-field-lower-snake-case/testdata/simple/simple.proto b/check/internal/example/cmd/buf-plugin-field-lower-snake-case/testdata/simple/simple.proto new file mode 100644 index 0000000..e3249a6 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-lower-snake-case/testdata/simple/simple.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +package simple; + +message Foo { + int32 lower_snake_case = 1; + int32 PascalCase = 2; +} diff --git a/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/main.go b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/main.go new file mode 100644 index 0000000..a9981be --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/main.go @@ -0,0 +1,209 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main implements a plugin that implements two Rules: +// +// - A lint Rule that checks that every field has the option (acme.option.v1.safe_for_ml) explicitly set. +// - A breaking Rule that verifes that no field goes from having option (acme.option.v1.safe_for_ml) going +// from true to false. That is, if a field is marked as safe, it can not then be moved to unsafe. +// +// This is an example of a plugin that will check a custom option, which is a very typical +// case for a custom lint or breaking change plugin. In this case, we're saying that an organization +// wants to explicitly mark every field in its schemas as either safe to train ML models on, or +// unsafe to train models on. This plugin enforces that all fields have such markings, and that +// those fields do not transition from safe to unsafe. +// +// This plugin also demonstrates the usage of categories. The Rules have IDs: +// +// - FIELD_OPTION_SAFE_FOR_ML_SET +// - FIELD_OPTION_SAFE_FOR_ML_STAYS_TRUE +// +// However, the Rules both belong to category FIELD_OPTION_SAFE_FOR_ML. This means that you +// do not need to specify the individual rules in your configuration. You can just specify +// the Category, and all Rules in this Category will be included. 
+// +// To use this plugin: +// +// # buf.yaml +// version: v2 +// lint: +// use: +// - STANDARD # omit if you do not want to use the rules builtin to buf +// - FIELD_OPTION_SAFE_FOR_ML +// breaking: +// use: +// - WIRE_JSON # omit if you do not want to use the rules builtin to buf +// - FIELD_OPTION_SAFE_FOR_ML +// plugins: +// - plugin: buf-plugin-field-option-safe-for-ml +package main + +import ( + "context" + "fmt" + + "github.com/bufbuild/bufplugin-go/check" + "github.com/bufbuild/bufplugin-go/check/checkutil" + optionv1 "github.com/bufbuild/bufplugin-go/check/internal/example/gen/acme/option/v1" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" +) + +const ( + // fieldOptionSafeForMLSetRuleID is the Rule ID of the field option safe for ML set Rule. + fieldOptionSafeForMLSetRuleID = "FIELD_OPTION_SAFE_FOR_ML_SET" + // fieldOptionSafeForMLStaysTrueRuleID is the Rule ID of the field option safe for ML stays true Rule. + fieldOptionSafeForMLStaysTrueRuleID = "FIELD_OPTION_SAFE_FOR_ML_STAYS_TRUE" + // fieldOptionSafeForMLCategoryID is the Category ID for the rules concerning (acme.option.v1.safe_for_ml). + fieldOptionSafeForMLCategoryID = "FIELD_OPTION_SAFE_FOR_ML" +) + +var ( + // fieldOptionSafeForMLRuleSpec is the RuleSpec for the field option safe for ML set Rule. + fieldOptionSafeForMLSetRuleSpec = &check.RuleSpec{ + ID: fieldOptionSafeForMLSetRuleID, + Default: true, + Purpose: "Checks that every field has option (acme.option.v1.safe_for_ml) explicitly set.", + CategoryIDs: []string{ + fieldOptionSafeForMLCategoryID, + }, + Type: check.RuleTypeLint, + Handler: checkutil.NewFieldRuleHandler(checkFieldOptionSafeForMLSet, checkutil.WithoutImports()), + } + // fieldOptionSafeForMLStaysTrueRuleSpec is the RuleSpec for the field option safe for ML stays true Rule. 
+ fieldOptionSafeForMLStaysTrueRuleSpec = &check.RuleSpec{ + ID: fieldOptionSafeForMLStaysTrueRuleID, + Default: true, + Purpose: "Checks that every field marked with (acme.option.v1.safe_for_ml) = true does not change to false.", + CategoryIDs: []string{ + fieldOptionSafeForMLCategoryID, + }, + Type: check.RuleTypeBreaking, + Handler: checkutil.NewFieldPairRuleHandler(checkFieldOptionSafeForMLStaysTrue, checkutil.WithoutImports()), + } + fieldOptionSafeForMLCategorySpec = &check.CategorySpec{ + ID: fieldOptionSafeForMLCategoryID, + Purpose: "Checks properties around the (acme.option.v1.safe_for_ml) option.", + } + + // spec is the Spec for the syntax specified plugin. + spec = &check.Spec{ + Rules: []*check.RuleSpec{ + fieldOptionSafeForMLSetRuleSpec, + fieldOptionSafeForMLStaysTrueRuleSpec, + }, + Categories: []*check.CategorySpec{ + fieldOptionSafeForMLCategorySpec, + }, + } +) + +func main() { + check.Main(spec) +} + +func checkFieldOptionSafeForMLSet( + _ context.Context, + responseWriter check.ResponseWriter, + _ check.Request, + fieldDescriptor protoreflect.FieldDescriptor, +) error { + // Ignore the actual field options - we don't need to mark safe_for_ml as safe_for_ml. 
+ if fieldDescriptor.ContainingMessage().FullName() == "google.protobuf.FieldOptions" { + return nil + } + fieldOptions, err := getFieldOptions(fieldDescriptor) + if err != nil { + return err + } + if !proto.HasExtension(fieldOptions, optionv1.E_SafeForMl) { + responseWriter.AddAnnotation( + check.WithMessagef( + "Field %q on message %q should have option (acme.option.v1.safe_for_ml) explicitly set.", + fieldDescriptor.Name(), + fieldDescriptor.ContainingMessage().FullName(), + ), + check.WithDescriptor(fieldDescriptor), + ) + } + return nil +} + +func checkFieldOptionSafeForMLStaysTrue( + _ context.Context, + responseWriter check.ResponseWriter, + _ check.Request, + fieldDescriptor protoreflect.FieldDescriptor, + againstFieldDescriptor protoreflect.FieldDescriptor, +) error { + // Ignore the actual field options - we don't need to mark safe_for_ml as safe_for_ml. + if fieldDescriptor.ContainingMessage().FullName() == "google.protobuf.FieldOptions" { + return nil + } + againstSafeForML, err := getSafeForML(againstFieldDescriptor) + if err != nil { + return err + } + if !againstSafeForML { + // If the field does not have safe_for_ml or safe_for_ml is false, we are done. It is up to the + // lint Rule to enforce whether or not every field has this option explicitly set. + return nil + } + safeForML, err := getSafeForML(fieldDescriptor) + if err != nil { + return err + } + if !safeForML { + responseWriter.AddAnnotation( + check.WithMessagef( + "Field %q on message %q should had option (acme.option.v1.safe_for_ml) change from true to false.", + fieldDescriptor.Name(), + fieldDescriptor.ContainingMessage().FullName(), + ), + check.WithDescriptor(fieldDescriptor), + check.WithAgainstDescriptor(againstFieldDescriptor), + ) + } + return nil +} + +func getFieldOptions(fieldDescriptor protoreflect.FieldDescriptor) (*descriptorpb.FieldOptions, error) { + fieldOptions, ok := fieldDescriptor.Options().(*descriptorpb.FieldOptions) + if !ok { + // This should never happen. 
return nil, fmt.Errorf("expected *descriptorpb.FieldOptions for FieldDescriptor %q Options but got %T", fieldDescriptor.FullName(), fieldDescriptor.Options())
+ +package main + +import ( + "testing" + + "github.com/bufbuild/bufplugin-go/check/checktest" +) + +func TestSpec(t *testing.T) { + t.Parallel() + checktest.SpecTest(t, spec) +} + +func TestSimpleSuccess(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{ + "../../proto", + "testdata/simple_success", + }, + FilePaths: []string{ + "acme/option/v1/option.proto", + "simple.proto", + }, + }, + }, + Spec: spec, + }.Run(t) +} + +func TestSimpleFailure(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{ + "../../proto", + "testdata/simple_failure", + }, + FilePaths: []string{ + "acme/option/v1/option.proto", + "simple.proto", + }, + }, + }, + Spec: spec, + ExpectedAnnotations: []checktest.ExpectedAnnotation{ + { + RuleID: fieldOptionSafeForMLSetRuleID, + Location: &checktest.ExpectedLocation{ + FileName: "simple.proto", + StartLine: 8, + StartColumn: 2, + EndLine: 8, + EndColumn: 17, + }, + }, + }, + }.Run(t) +} + +func TestChangeSuccess(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{ + "../../proto", + "testdata/change_success/current", + }, + FilePaths: []string{ + "acme/option/v1/option.proto", + "simple.proto", + }, + }, + AgainstFiles: &checktest.ProtoFileSpec{ + DirPaths: []string{ + "../../proto", + "testdata/change_success/previous", + }, + FilePaths: []string{ + "acme/option/v1/option.proto", + "simple.proto", + }, + }, + }, + Spec: spec, + }.Run(t) +} + +func TestChangeFailure(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{ + "../../proto", + "testdata/change_failure/current", + }, + FilePaths: []string{ + "acme/option/v1/option.proto", + "simple.proto", + }, + }, + 
AgainstFiles: &checktest.ProtoFileSpec{ + DirPaths: []string{ + "../../proto", + "testdata/change_failure/previous", + }, + FilePaths: []string{ + "acme/option/v1/option.proto", + "simple.proto", + }, + }, + }, + Spec: spec, + ExpectedAnnotations: []checktest.ExpectedAnnotation{ + { + RuleID: fieldOptionSafeForMLStaysTrueRuleID, + Location: &checktest.ExpectedLocation{ + FileName: "simple.proto", + StartLine: 8, + StartColumn: 2, + EndLine: 8, + EndColumn: 56, + }, + AgainstLocation: &checktest.ExpectedLocation{ + FileName: "simple.proto", + StartLine: 8, + StartColumn: 2, + EndLine: 8, + EndColumn: 55, + }, + }, + }, + }.Run(t) +} diff --git a/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_failure/current/simple.proto b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_failure/current/simple.proto new file mode 100644 index 0000000..d460ea1 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_failure/current/simple.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package simple; + +import "acme/option/v1/option.proto"; + +message User { + string name = 1 [(acme.option.v1.safe_for_ml) = false]; + string age = 2 [(acme.option.v1.safe_for_ml) = false]; +} diff --git a/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_failure/previous/simple.proto b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_failure/previous/simple.proto new file mode 100644 index 0000000..2e3f29b --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_failure/previous/simple.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package simple; + +import "acme/option/v1/option.proto"; + +message User { + string name = 1 [(acme.option.v1.safe_for_ml) = false]; + string age = 2 [(acme.option.v1.safe_for_ml) = true]; +} diff --git 
a/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_success/current/simple.proto b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_success/current/simple.proto new file mode 100644 index 0000000..ae9eb99 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_success/current/simple.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package simple; + +import "acme/option/v1/option.proto"; + +message User { + string name = 1 [(acme.option.v1.safe_for_ml) = true]; + string age = 2 [(acme.option.v1.safe_for_ml) = true]; +} diff --git a/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_success/previous/simple.proto b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_success/previous/simple.proto new file mode 100644 index 0000000..2e3f29b --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/change_success/previous/simple.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package simple; + +import "acme/option/v1/option.proto"; + +message User { + string name = 1 [(acme.option.v1.safe_for_ml) = false]; + string age = 2 [(acme.option.v1.safe_for_ml) = true]; +} diff --git a/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/simple_failure/simple.proto b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/simple_failure/simple.proto new file mode 100644 index 0000000..d2279bf --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/simple_failure/simple.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package simple; + +import "acme/option/v1/option.proto"; + +message User { + string name = 1 [(acme.option.v1.safe_for_ml) = false]; + string age = 2; +} diff --git a/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/simple_success/simple.proto 
b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/simple_success/simple.proto new file mode 100644 index 0000000..2e3f29b --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-field-option-safe-for-ml/testdata/simple_success/simple.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package simple; + +import "acme/option/v1/option.proto"; + +message User { + string name = 1 [(acme.option.v1.safe_for_ml) = false]; + string age = 2 [(acme.option.v1.safe_for_ml) = true]; +} diff --git a/check/internal/example/cmd/buf-plugin-syntax-specified/main.go b/check/internal/example/cmd/buf-plugin-syntax-specified/main.go new file mode 100644 index 0000000..375c6c9 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-syntax-specified/main.go @@ -0,0 +1,83 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main implements a simple plugin that checks that syntax is specified in every file. +// +// This is just demonstrating the additional functionality that check.Files have +// over FileDescriptors and FileDescriptorProtos. +// +// To use this plugin: +// +// # buf.yaml +// version: v2 +// lint: +// use: +// - STANDARD # omit if you do not want to use the rules builtin to buf +// - PLUGIN_SYNTAX_SPECIFIED +// plugins: +// - plugin: buf-plugin-syntax-specified +// +// Note that the buf CLI implements this check by as a builtin rule, but this is just for example. 
+package main + +import ( + "context" + + "github.com/bufbuild/bufplugin-go/check" + "github.com/bufbuild/bufplugin-go/check/checkutil" +) + +// syntaxSpecifiedRuleID is the Rule ID of the syntax specified Rule. +// +// This has a "PLUGIN_" prefix as the buf CLI has a rule "SYNTAX_SPECIFIED" builtin, +// and plugins/the buf CLI must have unique Rule IDs. +const syntaxSpecifiedRuleID = "PLUGIN_SYNTAX_SPECIFIED" + +var ( + // syntaxSpecifiedRuleSpec is the RuleSpec for the syntax specified Rule. + syntaxSpecifiedRuleSpec = &check.RuleSpec{ + ID: syntaxSpecifiedRuleID, + Default: true, + Purpose: "Checks that syntax is specified.", + Type: check.RuleTypeLint, + Handler: checkutil.NewFileRuleHandler(checkSyntaxSpecified, checkutil.WithoutImports()), + } + + // spec is the Spec for the syntax specified plugin. + spec = &check.Spec{ + Rules: []*check.RuleSpec{ + syntaxSpecifiedRuleSpec, + }, + } +) + +func main() { + check.Main(spec) +} + +func checkSyntaxSpecified( + _ context.Context, + responseWriter check.ResponseWriter, + _ check.Request, + file check.File, +) error { + if file.IsSyntaxUnspecified() { + syntax := file.FileDescriptorProto().GetSyntax() + responseWriter.AddAnnotation( + check.WithMessagef("Syntax should be specified but was %q.", syntax), + check.WithDescriptor(file.FileDescriptor()), + ) + } + return nil +} diff --git a/check/internal/example/cmd/buf-plugin-syntax-specified/main_test.go b/check/internal/example/cmd/buf-plugin-syntax-specified/main_test.go new file mode 100644 index 0000000..58a41a2 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-syntax-specified/main_test.go @@ -0,0 +1,62 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "testing" + + "github.com/bufbuild/bufplugin-go/check/checktest" +) + +func TestSpec(t *testing.T) { + t.Parallel() + checktest.SpecTest(t, spec) +} + +func TestSimpleSuccess(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{"testdata/simple_success"}, + FilePaths: []string{"simple.proto"}, + }, + }, + Spec: spec, + }.Run(t) +} + +func TestSimpleFailure(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{"testdata/simple_failure"}, + FilePaths: []string{"simple.proto"}, + }, + }, + Spec: spec, + ExpectedAnnotations: []checktest.ExpectedAnnotation{ + { + RuleID: syntaxSpecifiedRuleID, + Location: &checktest.ExpectedLocation{ + FileName: "simple.proto", + }, + }, + }, + }.Run(t) +} diff --git a/check/internal/example/cmd/buf-plugin-syntax-specified/testdata/simple_failure/simple.proto b/check/internal/example/cmd/buf-plugin-syntax-specified/testdata/simple_failure/simple.proto new file mode 100644 index 0000000..322a627 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-syntax-specified/testdata/simple_failure/simple.proto @@ -0,0 +1 @@ +package simple; diff --git a/check/internal/example/cmd/buf-plugin-syntax-specified/testdata/simple_success/simple.proto b/check/internal/example/cmd/buf-plugin-syntax-specified/testdata/simple_success/simple.proto new file mode 100644 index 0000000..3a872c9 --- /dev/null +++ 
b/check/internal/example/cmd/buf-plugin-syntax-specified/testdata/simple_success/simple.proto @@ -0,0 +1,3 @@ +syntax = "proto3"; + +package simple; diff --git a/check/internal/example/cmd/buf-plugin-timestamp-suffix/main.go b/check/internal/example/cmd/buf-plugin-timestamp-suffix/main.go new file mode 100644 index 0000000..eb43274 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-timestamp-suffix/main.go @@ -0,0 +1,114 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main implements a simple plugin that checks that all +// google.protobuf.Timestamp fields end in a specific suffix. 
+// +// To use this plugin: +// +// # buf.yaml +// version: v2 +// lint: +// use: +// - STANDARD # omit if you do not want to use the rules builtin to buf +// - TIMESTAMP_SUFFIX +// plugins: +// - plugin: buf-plugin-timestamp-suffix +// +// The default suffix is "_time", but this can be overridden with the +// "timestamp_suffix" option key in your buf.yaml: +// +// # buf.yaml +// version: v2 +// lint: +// use: +// - STANDARD # omit if you do not want to use the rules builtin to buf +// - TIMESTAMP_SUFFIX +// plugins: +// - plugin: buf-plugin-timestamp-suffix +// options: +// timestamp_suffix: _timestamp +package main + +import ( + "context" + "strings" + + "github.com/bufbuild/bufplugin-go/check" + "github.com/bufbuild/bufplugin-go/check/checkutil" + "google.golang.org/protobuf/reflect/protoreflect" +) + +const ( + // timestampSuffixRuleID is the Rule ID of the timestamp suffix Rule. + timestampSuffixRuleID = "TIMESTAMP_SUFFIX" + + // timestampSuffixOptionKey is the option key to override the default timestamp suffix. + timestampSuffixOptionKey = "timestamp_suffix" + + defaultTimestampSuffix = "_time" +) + +var ( + // timestampSuffixRuleSpec is the RuleSpec for the timestamp suffix Rule. + timestampSuffixRuleSpec = &check.RuleSpec{ + ID: timestampSuffixRuleID, + Default: true, + Purpose: `Checks that all google.protobuf.Timestamps end in a specific suffix (default is "_time").`, + Type: check.RuleTypeLint, + Handler: checkutil.NewFieldRuleHandler(checkTimestampSuffix, checkutil.WithoutImports()), + } + + // spec is the Spec for the timestamp suffix plugin. 
+ spec = &check.Spec{ + Rules: []*check.RuleSpec{ + timestampSuffixRuleSpec, + }, + } +) + +func main() { + check.Main(spec) +} + +func checkTimestampSuffix( + _ context.Context, + responseWriter check.ResponseWriter, + request check.Request, + fieldDescriptor protoreflect.FieldDescriptor, +) error { + timestampSuffix := defaultTimestampSuffix + timestampSuffixOptionValue, err := check.GetStringValue(request.Options(), timestampSuffixOptionKey) + if err != nil { + return err + } + if timestampSuffixOptionValue != "" { + timestampSuffix = timestampSuffixOptionValue + } + + fieldDescriptorType := fieldDescriptor.Message() + if fieldDescriptorType == nil { + return nil + } + if string(fieldDescriptorType.FullName()) != "google.protobuf.Timestamp" { + return nil + } + if !strings.HasSuffix(string(fieldDescriptor.Name()), timestampSuffix) { + responseWriter.AddAnnotation( + check.WithMessagef("Fields of type google.protobuf.Timestamp must end in %q but field name was %q.", timestampSuffix, string(fieldDescriptor.Name())), + check.WithDescriptor(fieldDescriptor), + ) + } + return nil +} diff --git a/check/internal/example/cmd/buf-plugin-timestamp-suffix/main_test.go b/check/internal/example/cmd/buf-plugin-timestamp-suffix/main_test.go new file mode 100644 index 0000000..3077a95 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-timestamp-suffix/main_test.go @@ -0,0 +1,84 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "testing" + + "github.com/bufbuild/bufplugin-go/check/checktest" +) + +func TestSpec(t *testing.T) { + t.Parallel() + checktest.SpecTest(t, spec) +} + +func TestSimple(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{"testdata/simple"}, + FilePaths: []string{"simple.proto"}, + }, + // This linter only has a single Rule, so this has no effect in this + // test, however this is how you scope a test to a single Rule. + RuleIDs: []string{timestampSuffixRuleID}, + }, + Spec: spec, + ExpectedAnnotations: []checktest.ExpectedAnnotation{ + { + RuleID: timestampSuffixRuleID, + Location: &checktest.ExpectedLocation{ + FileName: "simple.proto", + StartLine: 8, + StartColumn: 2, + EndLine: 8, + EndColumn: 50, + }, + }, + }, + }.Run(t) +} + +func TestOption(t *testing.T) { + t.Parallel() + + checktest.CheckTest{ + Request: &checktest.RequestSpec{ + Files: &checktest.ProtoFileSpec{ + DirPaths: []string{"testdata/option"}, + FilePaths: []string{"option.proto"}, + }, + Options: map[string]any{ + timestampSuffixOptionKey: "_timestamp", + }, + }, + Spec: spec, + ExpectedAnnotations: []checktest.ExpectedAnnotation{ + { + RuleID: timestampSuffixRuleID, + Location: &checktest.ExpectedLocation{ + FileName: "option.proto", + StartLine: 8, + StartColumn: 2, + EndLine: 8, + EndColumn: 45, + }, + }, + }, + }.Run(t) +} diff --git a/check/internal/example/cmd/buf-plugin-timestamp-suffix/testdata/option/option.proto b/check/internal/example/cmd/buf-plugin-timestamp-suffix/testdata/option/option.proto new file mode 100644 index 0000000..94902fb --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-timestamp-suffix/testdata/option/option.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package option; + +import "google/protobuf/timestamp.proto"; + +message Foo { + 
google.protobuf.Timestamp valid_timestamp = 1; + google.protobuf.Timestamp invalid_time = 2; +} diff --git a/check/internal/example/cmd/buf-plugin-timestamp-suffix/testdata/simple/simple.proto b/check/internal/example/cmd/buf-plugin-timestamp-suffix/testdata/simple/simple.proto new file mode 100644 index 0000000..de2de83 --- /dev/null +++ b/check/internal/example/cmd/buf-plugin-timestamp-suffix/testdata/simple/simple.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package simple; + +import "google/protobuf/timestamp.proto"; + +message Foo { + google.protobuf.Timestamp valid_time = 1; + google.protobuf.Timestamp invalid_timestamp = 2; +} diff --git a/check/internal/example/gen/acme/option/v1/option.pb.go b/check/internal/example/gen/acme/option/v1/option.pb.go new file mode 100644 index 0000000..43ddf70 --- /dev/null +++ b/check/internal/example/gen/acme/option/v1/option.pb.go @@ -0,0 +1,119 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.34.2 +// protoc (unknown) +// source: acme/option/v1/option.proto + +package optionv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + reflect "reflect" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +var file_acme_option_v1_option_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 50001, + Name: "acme.option.v1.safe_for_ml", + Tag: "varint,50001,opt,name=safe_for_ml", + Filename: "acme/option/v1/option.proto", + }, +} + +// Extension fields to descriptorpb.FieldOptions. +var ( + // If true, the field is safe to be used for training ML models. + // + // optional bool safe_for_ml = 50001; + E_SafeForMl = &file_acme_option_v1_option_proto_extTypes[0] +) + +var File_acme_option_v1_option_proto protoreflect.FileDescriptor + +var file_acme_option_v1_option_proto_rawDesc = []byte{ + 0x0a, 0x1b, 0x61, 0x63, 0x6d, 0x65, 0x2f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x76, 0x31, + 0x2f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x61, + 0x63, 0x6d, 0x65, 0x2e, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x76, 0x31, 0x1a, 0x20, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, + 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3a, + 0x42, 0x0a, 0x0b, 0x73, 0x61, 0x66, 0x65, 0x5f, 0x66, 0x6f, 0x72, 0x5f, 0x6d, 0x6c, 0x12, 0x1d, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd1, 0x86, + 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x73, 0x61, 0x66, 0x65, 0x46, 0x6f, 0x72, 0x4d, 0x6c, + 0x88, 0x01, 0x01, 0x42, 0xd0, 0x01, 0x0a, 0x12, 0x63, 0x6f, 0x6d, 0x2e, 0x61, 0x63, 0x6d, 0x65, + 0x2e, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x76, 0x31, 0x42, 0x0b, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 
0x53, 0x67, 0x69, 0x74, 0x68, 0x75, + 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x62, + 0x75, 0x66, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2d, 0x67, 0x6f, 0x2f, 0x63, 0x68, 0x65, 0x63, + 0x6b, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, + 0x6c, 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x61, 0x63, 0x6d, 0x65, 0x2f, 0x6f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x2f, 0x76, 0x31, 0x3b, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x76, 0x31, 0xa2, 0x02, + 0x03, 0x41, 0x4f, 0x58, 0xaa, 0x02, 0x0e, 0x41, 0x63, 0x6d, 0x65, 0x2e, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x2e, 0x56, 0x31, 0xca, 0x02, 0x0e, 0x41, 0x63, 0x6d, 0x65, 0x5c, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x5c, 0x56, 0x31, 0xe2, 0x02, 0x1a, 0x41, 0x63, 0x6d, 0x65, 0x5c, 0x4f, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x5c, 0x56, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0xea, 0x02, 0x10, 0x41, 0x63, 0x6d, 0x65, 0x3a, 0x3a, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x3a, 0x3a, 0x56, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var file_acme_option_v1_option_proto_goTypes = []any{ + (*descriptorpb.FieldOptions)(nil), // 0: google.protobuf.FieldOptions +} +var file_acme_option_v1_option_proto_depIdxs = []int32{ + 0, // 0: acme.option.v1.safe_for_ml:extendee -> google.protobuf.FieldOptions + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 0, // [0:1] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_acme_option_v1_option_proto_init() } +func file_acme_option_v1_option_proto_init() { + if File_acme_option_v1_option_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_acme_option_v1_option_proto_rawDesc, + 
NumEnums: 0, + NumMessages: 0, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_acme_option_v1_option_proto_goTypes, + DependencyIndexes: file_acme_option_v1_option_proto_depIdxs, + ExtensionInfos: file_acme_option_v1_option_proto_extTypes, + }.Build() + File_acme_option_v1_option_proto = out.File + file_acme_option_v1_option_proto_rawDesc = nil + file_acme_option_v1_option_proto_goTypes = nil + file_acme_option_v1_option_proto_depIdxs = nil +} diff --git a/check/internal/example/proto/acme/option/v1/option.proto b/check/internal/example/proto/acme/option/v1/option.proto new file mode 100644 index 0000000..8569786 --- /dev/null +++ b/check/internal/example/proto/acme/option/v1/option.proto @@ -0,0 +1,24 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package acme.option.v1; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.FieldOptions { + // If true, the field is safe to be used for training ML models. + optional bool safe_for_ml = 50001; +} diff --git a/check/location.go b/check/location.go new file mode 100644 index 0000000..952ac62 --- /dev/null +++ b/check/location.go @@ -0,0 +1,128 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "slices" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// Location is a reference to a File or to a location within a File. +// +// A Location always has a file name. +type Location interface { + // File is the File associated with the Location. + // + // Always present. + File() File + // SourcePath returns the path within the FileDescriptorProto of the Location. + SourcePath() protoreflect.SourcePath + + // StartLine returns the zero-indexed start line, if known. + StartLine() int + // StartColumn returns the zero-indexed start column, if known. + StartColumn() int + // EndLine returns the zero-indexed end line, if known. + EndLine() int + // EndColumn returns the zero-indexed end column, if known. + EndColumn() int + // LeadingComments returns any leading comments, if known. + LeadingComments() string + // TrailingComments returns any trailing comments, if known. + TrailingComments() string + // LeadingDetachedComments returns any leading detached comments, if known. 
+ LeadingDetachedComments() []string + + unclonedSourcePath() protoreflect.SourcePath + unclonedLeadingDetachedComments() []string + toProto() *checkv1.Location + + isLocation() +} + +// *** PRIVATE *** + +type location struct { + file File + sourceLocation protoreflect.SourceLocation +} + +func newLocation( + file File, + sourceLocation protoreflect.SourceLocation, +) *location { + return &location{ + file: file, + sourceLocation: sourceLocation, + } +} + +func (l *location) File() File { + return l.file +} + +func (l *location) SourcePath() protoreflect.SourcePath { + return slices.Clone(l.sourceLocation.Path) +} + +func (l *location) StartLine() int { + return l.sourceLocation.StartLine +} + +func (l *location) StartColumn() int { + return l.sourceLocation.StartColumn +} + +func (l *location) EndLine() int { + return l.sourceLocation.EndLine +} + +func (l *location) EndColumn() int { + return l.sourceLocation.EndColumn +} + +func (l *location) LeadingComments() string { + return l.sourceLocation.LeadingComments +} + +func (l *location) TrailingComments() string { + return l.sourceLocation.TrailingComments +} + +func (l *location) LeadingDetachedComments() []string { + return slices.Clone(l.sourceLocation.LeadingDetachedComments) +} + +func (l *location) unclonedSourcePath() protoreflect.SourcePath { + return l.sourceLocation.Path +} + +func (l *location) unclonedLeadingDetachedComments() []string { + return l.sourceLocation.LeadingDetachedComments +} + +func (l *location) toProto() *checkv1.Location { + if l == nil { + return nil + } + return &checkv1.Location{ + FileName: l.file.FileDescriptor().Path(), + SourcePath: l.sourceLocation.Path, + } +} + +func (*location) isLocation() {} diff --git a/check/main.go b/check/main.go new file mode 100644 index 0000000..c2cd4e2 --- /dev/null +++ b/check/main.go @@ -0,0 +1,86 @@ +// Copyright 2024 Buf Technologies, Inc. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package check
+
+import (
+	"pluginrpc.com/pluginrpc"
+)
+
+// Main is the main entrypoint for a plugin that implements the given Spec.
+//
+// A plugin just needs to provide a Spec, and then call this function within main.
+//
+//	func main() {
+//		check.Main(
+//			&check.Spec {
+//				Rules: []*check.RuleSpec{
+//					{
+//						ID: "TIMESTAMP_SUFFIX",
+//						Default: true,
+//						Purpose: "Checks that all google.protobuf.Timestamps end in _time.",
+//						Type: check.RuleTypeLint,
+//						Handler: check.RuleHandlerFunc(handleTimestampSuffix),
+//					},
+//				},
+//			},
+//		)
+//	}
+func Main(spec *Spec, options ...MainOption) {
+	mainOptions := newMainOptions()
+	for _, option := range options {
+		option(mainOptions)
+	}
+	pluginrpc.Main(
+		func() (pluginrpc.Server, error) {
+			checkServiceHandler, err := NewCheckServiceHandler(
+				spec,
+				CheckServiceHandlerWithParallelism(mainOptions.parallelism),
+			)
+			if err != nil {
+				return nil, err
+			}
+			return NewCheckServiceServer(checkServiceHandler)
+		},
+	)
+}
+
+// MainOption is an option for Main.
+type MainOption func(*mainOptions)
+
+// MainWithParallelism returns a new MainOption that sets the parallelism by which Rules
+// will be run.
+//
+// If this is set to a value >= 1, this many concurrent Rules can be run at the same time.
+// A value of 0 indicates the default behavior, which is to use runtime.GOMAXPROCS(0).
+//
+// A value of < 0 has no effect.
+func MainWithParallelism(parallelism int) MainOption {
+	return func(mainOptions *mainOptions) {
+		if parallelism < 0 {
+			parallelism = 0
+		}
+		mainOptions.parallelism = parallelism
+	}
+}
+
+// *** PRIVATE ***
+
+type mainOptions struct {
+	parallelism int
+}
+
+func newMainOptions() *mainOptions {
+	return &mainOptions{}
+}
diff --git a/check/options.go b/check/options.go
new file mode 100644
index 0000000..bf03b8b
--- /dev/null
+++ b/check/options.go
@@ -0,0 +1,434 @@
+// Copyright 2024 Buf Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package check
+
+import (
+	"errors"
+	"fmt"
+	"reflect"
+
+	checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1"
+)
+
+var emptyOptions = newOptionsNoValidate(nil)
+
+// Options are key/values that can control the behavior of a RuleHandler,
+// and can control the value of the Purpose string of the Rule.
+// +// For example, if you had a Rule that checked that the suffix of all Services was "API", +// you may want an option with key "service_suffix" that can override the suffix "API" to +// another suffix such as "Service". This would result in the behavior of the check changing, +// as well as result in the Purpose string potentially changing to specify that the +// expected suffix is "Service" instead of "API". +// +// It is not possible to set a key with a not-present value. Do not add an Option with +// a given key to denote that the key is not set. +type Options interface { + // Get gets the option value for the given key. + // + // Values will be one of: + // + // - int64 + // - float64 + // - string + // - []byte + // - bool + // - A slice of any of the above, recursively (i.e. []string, [][]int64, ...) + // + // A caller should not modify a returned value. + // + // The key must have at least four characters. + // The key must start and end with a lowercase letter from a-z, and only consist + // of lowercase letters from a-z and underscores. + Get(key string) (any, bool) + // Range ranges over all key/value pairs. + // + // The range order is not deterministic. + Range(f func(key string, value any)) + + toProto() ([]*checkv1.Option, error) + + isOption() +} + +// NewOptions returns a new validated Options for the given key/value map. +func NewOptions(keyToValue map[string]any) (Options, error) { + if err := validateKeyToValue(keyToValue); err != nil { + return nil, err + } + return newOptionsNoValidate(keyToValue), nil +} + +// OptionsForProtoOptions returns a new Options for the given checkv1.Options. 
+func OptionsForProtoOptions(protoOptions []*checkv1.Option) (Options, error) { + keyToValue := make(map[string]any, len(protoOptions)) + for _, protoOption := range protoOptions { + value, err := protoValueToValue(protoOption.GetValue()) + if err != nil { + return nil, err + } + keyToValue[protoOption.GetKey()] = value + } + return NewOptions(keyToValue) +} + +// GetBoolValue gets a bool value from the Options. +// +// If the value is present and is not of type bool, an error is returned. +func GetBoolValue(options Options, key string) (bool, error) { + anyValue, ok := options.Get(key) + if !ok { + return false, nil + } + value, ok := anyValue.(bool) + if !ok { + return false, newUnexpectedOptionValueTypeError(key, false, anyValue) + } + return value, nil +} + +// GetInt64Value gets a int64 value from the Options. +// +// If the value is present and is not of type int64, an error is returned. +func GetInt64Value(options Options, key string) (int64, error) { + anyValue, ok := options.Get(key) + if !ok { + return 0, nil + } + value, ok := anyValue.(int64) + if !ok { + return 0, newUnexpectedOptionValueTypeError(key, int64(0), anyValue) + } + return value, nil +} + +// GetFloat64Value gets a float64 value from the Options. +// +// If the value is present and is not of type float64, an error is returned. +func GetFloat64Value(options Options, key string) (float64, error) { + anyValue, ok := options.Get(key) + if !ok { + return 0.0, nil + } + value, ok := anyValue.(float64) + if !ok { + return 0.0, newUnexpectedOptionValueTypeError(key, float64(0.0), anyValue) + } + return value, nil +} + +// GetStringValue gets a string value from the Options. +// +// If the value is present and is not of type string, an error is returned. 
+func GetStringValue(options Options, key string) (string, error) { + anyValue, ok := options.Get(key) + if !ok { + return "", nil + } + value, ok := anyValue.(string) + if !ok { + return "", newUnexpectedOptionValueTypeError(key, "", anyValue) + } + return value, nil +} + +// GetBytesValue gets a bytes value from the Options. +// +// If the value is present and is not of type bytes, an error is returned. +func GetBytesValue(options Options, key string) ([]byte, error) { + anyValue, ok := options.Get(key) + if !ok { + return nil, nil + } + value, ok := anyValue.([]byte) + if !ok { + return nil, newUnexpectedOptionValueTypeError(key, []byte{}, anyValue) + } + return value, nil +} + +// GetInt64SliceValue gets a []int64 value from the Options. +// +// If the value is present and is not of type []int64, an error is returned. +func GetInt64SliceValue(options Options, key string) ([]int64, error) { + anyValue, ok := options.Get(key) + if !ok { + return nil, nil + } + value, ok := anyValue.([]int64) + if !ok { + return nil, newUnexpectedOptionValueTypeError(key, []int64{}, anyValue) + } + return value, nil +} + +// GetFloat64SliceValue gets a []float64 value from the Options. +// +// If the value is present and is not of type []float64, an error is returned. +func GetFloat64SliceValue(options Options, key string) ([]float64, error) { + anyValue, ok := options.Get(key) + if !ok { + return nil, nil + } + value, ok := anyValue.([]float64) + if !ok { + return nil, newUnexpectedOptionValueTypeError(key, []float64{}, anyValue) + } + return value, nil +} + +// GetStringSliceValue gets a []string value from the Options. +// +// If the value is present and is not of type []string, an error is returned. 
+func GetStringSliceValue(options Options, key string) ([]string, error) { + anyValue, ok := options.Get(key) + if !ok { + return nil, nil + } + value, ok := anyValue.([]string) + if !ok { + return nil, newUnexpectedOptionValueTypeError(key, []string{}, anyValue) + } + return value, nil +} + +// *** PRIVATE *** + +type options struct { + keyToValue map[string]any +} + +func newOptionsNoValidate(keyToValue map[string]any) *options { + if keyToValue == nil { + keyToValue = make(map[string]any) + } + return &options{ + keyToValue: keyToValue, + } +} + +func (o *options) Get(key string) (any, bool) { + value, ok := o.keyToValue[key] + return value, ok +} + +func (o *options) Range(f func(key string, value any)) { + for key, value := range o.keyToValue { + f(key, value) + } +} + +func (o *options) toProto() ([]*checkv1.Option, error) { + if o == nil { + return nil, nil + } + protoOptions := make([]*checkv1.Option, 0, len(o.keyToValue)) + for key, value := range o.keyToValue { + protoValue, err := valueToProtoValue(value) + if err != nil { + return nil, err + } + // Assuming that we've validated that no values are empty. + protoOptions = append( + protoOptions, + &checkv1.Option{ + Key: key, + Value: protoValue, + }, + ) + } + return protoOptions, nil +} + +func (*options) isOption() {} + +// You can assume that value is a valid value. 
+func valueToProtoValue(value any) (*checkv1.Value, error) { + switch reflectValue := reflect.ValueOf(value); reflectValue.Kind() { + case reflect.Bool: + return &checkv1.Value{ + Type: &checkv1.Value_BoolValue{ + BoolValue: reflectValue.Bool(), + }, + }, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return &checkv1.Value{ + Type: &checkv1.Value_Int64Value{ + Int64Value: reflectValue.Int(), + }, + }, nil + case reflect.Float32, reflect.Float64: + return &checkv1.Value{ + Type: &checkv1.Value_DoubleValue{ + DoubleValue: reflectValue.Float(), + }, + }, nil + case reflect.String: + return &checkv1.Value{ + Type: &checkv1.Value_StringValue{ + StringValue: reflectValue.String(), + }, + }, nil + case reflect.Slice: + if t, ok := value.([]byte); ok { + return &checkv1.Value{ + Type: &checkv1.Value_BytesValue{ + BytesValue: t, + }, + }, nil + } + values := make([]*checkv1.Value, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + subValue, err := valueToProtoValue(reflectValue.Index(i).Interface()) + if err != nil { + return nil, err + } + values[i] = subValue + } + return &checkv1.Value{ + Type: &checkv1.Value_ListValue{ + ListValue: &checkv1.ListValue{ + Values: values, + }, + }, + }, nil + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer | reflect.Ptr, reflect.Struct, reflect.UnsafePointer: + return nil, fmt.Errorf("invalid type for Options value %T", value) + default: + return nil, fmt.Errorf("invalid type for Options value %T", value) + } +} + +func protoValueToValue(protoValue *checkv1.Value) (any, error) { + if protoValue == nil { + return nil, errors.New("invalid checkv1.Value: value cannot be nil") + } + switch { + case protoValue.GetBoolValue(): + return protoValue.GetBoolValue(), nil + case protoValue.GetInt64Value() 
!= 0: + return protoValue.GetInt64Value(), nil + case protoValue.GetDoubleValue() != 0: + return protoValue.GetDoubleValue(), nil + case len(protoValue.GetStringValue()) > 0: + return protoValue.GetStringValue(), nil + case len(protoValue.GetBytesValue()) > 0: + return protoValue.GetBytesValue(), nil + case protoValue.GetListValue() != nil: + protoListValue := protoValue.GetListValue() + protoListValues := protoListValue.GetValues() + if len(protoListValues) == 0 { + return nil, errors.New("invalid checkv1.Value: list_values had no values") + } + anySlice := make([]any, len(protoListValue.GetValues())) + for i, protoSubValue := range protoListValues { + subValue, err := protoValueToValue(protoSubValue) + if err != nil { + return nil, err + } + anySlice[i] = subValue + } + // We know this is of at least length 1 + anySliceFirstType := reflect.TypeOf(anySlice[0]) + for i := 1; i < len(anySlice); i++ { + anySliceSubType := reflect.TypeOf(anySlice[i]) + if anySliceFirstType != anySliceSubType { + return nil, fmt.Errorf("invalid checkv1.Value: list_values must have values of the same type but detected types %v and %v", anySliceFirstType, anySliceSubType) + } + } + reflectSlice := reflect.MakeSlice(reflect.SliceOf(anySliceFirstType), 0, len(anySlice)) + for _, anySliceSubValue := range anySlice { + reflectSlice = reflect.Append(reflectSlice, reflect.ValueOf(anySliceSubValue)) + } + return reflectSlice.Interface(), nil + default: + return nil, errors.New("invalid checkv1.Value: no value of oneof is set") + } +} + +func validateKeyToValue(keyToValue map[string]any) error { + for key, value := range keyToValue { + // This should all be validated via protovalidate, and the below doesn't + // even encapsulate all the validation. 
+ if len(key) == 0 { + return errors.New("invalid option key: key cannot be empty") + } + if err := validateValue(value); err != nil { + return err + } + } + return nil +} + +func validateValue(value any) error { + if value == nil { + return errors.New("invalid option value: value cannot be nil") + } + switch reflectValue := reflect.ValueOf(value); reflectValue.Kind() { + case reflect.Bool: + t := reflectValue.Bool() + if !t { + return errors.New("invalid option value: bool must be true") + } + return nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t := reflectValue.Int() + if t == 0 { + return errors.New("invalid option value: int must be non-zero") + } + return nil + case reflect.Float32, reflect.Float64: + t := reflectValue.Float() + if t == 0 { + return errors.New("invalid option value: float must be non-zero") + } + return nil + case reflect.String: + t := reflectValue.String() + if t == "" { + return errors.New("invalid option value: string must be non-empty") + } + return nil + case reflect.Slice: + vLen := reflectValue.Len() + if vLen == 0 { + return errors.New("invalid option value: slice must be non-empty") + } + firstValue := reflectValue.Index(0).Interface() + firstValueType := reflect.TypeOf(firstValue) + for i := 1; i < vLen; i++ { + subValue := reflectValue.Index(i).Interface() + subValueType := reflect.TypeOf(subValue) + // reflect.Types are comparable with == per documentation. 
+ if firstValueType != subValueType { + return fmt.Errorf("invalid option value: slice must have values of the same type but detected types %v and %v", firstValueType, subValueType) + } + } + return nil + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer | reflect.Ptr, reflect.Struct, reflect.UnsafePointer: + return fmt.Errorf("invalid option value: unhandled type %T", value) + default: + return fmt.Errorf("invalid option value: unhandled type %T", value) + } +} diff --git a/check/options_test.go b/check/options_test.go new file mode 100644 index 0000000..ca99dc3 --- /dev/null +++ b/check/options_test.go @@ -0,0 +1,85 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package check
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestOptionsRoundTrip(t *testing.T) {
+	t.Parallel()
+
+	testOptionsRoundTrip(t, true)
+	testOptionsRoundTrip(t, int64(1))
+	testOptionsRoundTrip(t, float64(1.0))
+	testOptionsRoundTrip(t, "foo")
+	testOptionsRoundTrip(t, []byte("foo"))
+	testOptionsRoundTrip(t, []bool{true, true})
+	testOptionsRoundTrip(t, []int64{1, 2})
+	testOptionsRoundTrip(t, []float64{1.0, 2.0})
+	testOptionsRoundTrip(t, []string{"foo", "bar"})
+	testOptionsRoundTrip(t, [][]string{{"foo", "bar"}, {"baz, bat"}})
+	testOptionsRoundTripDifferentInputOutput(
+		t,
+		[]any{"foo", "bar"},
+		[]string{"foo", "bar"},
+	)
+	testOptionsRoundTripDifferentInputOutput(
+		t,
+		[]any{[]string{"foo"}, []string{"bar"}},
+		[][]string{{"foo"}, {"bar"}},
+	)
+}
+
+func TestOptionsValidateValueError(t *testing.T) {
+	t.Parallel()
+
+	err := validateValue(false)
+	assert.Error(t, err)
+	err = validateValue(0)
+	assert.Error(t, err)
+	err = validateValue([]any{1, "foo"})
+	assert.Error(t, err)
+	err = validateValue([]any{[]string{"foo"}, "foo"})
+	assert.Error(t, err)
+}
+
+func testOptionsRoundTrip(t *testing.T, value any) {
+	protoValue, err := valueToProtoValue(value)
+	require.NoError(t, err)
+	actualValue, err := protoValueToValue(protoValue)
+	require.NoError(t, err)
+	assert.Equal(t, value, actualValue)
+}
+
+func testOptionsRoundTripDifferentInputOutput(t *testing.T, input any, expectedOutput any) {
+	protoValue, err := valueToProtoValue(input)
+	require.NoError(t, err)
+	actualValue, err := protoValueToValue(protoValue)
+	require.NoError(t, err)
+	assert.Equal(t, expectedOutput, actualValue)
+}
diff --git a/check/request.go b/check/request.go
new file mode 100644
index 0000000..f1c7bc9
--- /dev/null
+++ b/check/request.go
@@ -0,0 +1,229 @@
+// Copyright 2024 Buf Technologies, Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "slices" + "sort" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" +) + +const checkRuleIDPageSize = 250 + +// Request is a request to a plugin to run checks. +type Request interface { + // Files contains the files to check. + // + // Will never be nil or empty. + // + // Files are guaranteed to be unique with respect to their file name + Files() []File + // AgainstFiles contains the files to check against, in the case of breaking change plugins. + // + // May be empty, including in the case where we did actually specify against files. + // + // Files are guaranteed to be unique with respect to their file name + AgainstFiles() []File + // Options contains any options passed to the plugin. + // + // Will never be nil, but may have no values. + Options() Options + // RuleIDs returns the specific IDs the of Rules to use. + // + // If empty, all default Rules will be used. + // The returned RuleIDs will be sorted. + // + // This may return more than 250 IDs; the underlying Client implemention is required to do + // any necessary chunking. + // + // RuleHandlers can safely ignore this - the handling of RuleIDs will have already + // been performed prior to the Request reaching the RuleHandler. + RuleIDs() []string + + // toProtos converts the Request into one or more CheckRequests. 
+	// 
+	// If there are more than 250 Rule IDs, multiple CheckRequests will be produced by chunking up
+	// the Rule IDs.
+	toProtos() ([]*checkv1.CheckRequest, error)
+
+	isRequest()
+}
+
+// NewRequest returns a new Request for the given Files.
+//
+// Files are always required. To set against Files or options, use
+// WithAgainstFiles and WithOptions.
+func NewRequest(
+	files []File,
+	options ...RequestOption,
+) (Request, error) {
+	return newRequest(files, options...)
+}
+
+// RequestOption is an option for a new Request.
+type RequestOption func(*requestOptions)
+
+// WithAgainstFiles adds the given against Files to the Request.
+func WithAgainstFiles(againstFiles []File) RequestOption {
+	return func(requestOptions *requestOptions) {
+		requestOptions.againstFiles = againstFiles
+	}
+}
+
+// WithOptions adds the given Options to the Request.
+func WithOptions(options Options) RequestOption {
+	return func(requestOptions *requestOptions) {
+		requestOptions.options = options
+	}
+}
+
+// WithRuleIDs specifies that the given rule IDs should be used on the Request.
+//
+// Multiple calls to WithRuleIDs will result in the new rule IDs being appended.
+// If duplicate rule IDs are specified, this will result in an error.
+func WithRuleIDs(ruleIDs ...string) RequestOption {
+	return func(requestOptions *requestOptions) {
+		requestOptions.ruleIDs = append(requestOptions.ruleIDs, ruleIDs...)
+	}
+}
+
+// RequestForProtoRequest returns a new Request for the given checkv1.CheckRequest. 
+func RequestForProtoRequest(protoRequest *checkv1.CheckRequest) (Request, error) { + files, err := FilesForProtoFiles(protoRequest.GetFiles()) + if err != nil { + return nil, err + } + againstFiles, err := FilesForProtoFiles(protoRequest.GetAgainstFiles()) + if err != nil { + return nil, err + } + options, err := OptionsForProtoOptions(protoRequest.GetOptions()) + if err != nil { + return nil, err + } + return NewRequest( + files, + WithAgainstFiles(againstFiles), + WithOptions(options), + WithRuleIDs(protoRequest.GetRuleIds()...), + ) +} + +// *** PRIVATE *** + +type request struct { + files []File + againstFiles []File + options Options + ruleIDs []string +} + +func newRequest( + files []File, + options ...RequestOption, +) (*request, error) { + requestOptions := newRequestOptions() + for _, option := range options { + option(requestOptions) + } + if requestOptions.options == nil { + requestOptions.options = emptyOptions + } + if err := validateNoDuplicateRuleOrCategoryIDs(requestOptions.ruleIDs); err != nil { + return nil, err + } + sort.Strings(requestOptions.ruleIDs) + if err := validateFiles(files); err != nil { + return nil, err + } + if err := validateFiles(requestOptions.againstFiles); err != nil { + return nil, err + } + return &request{ + files: files, + againstFiles: requestOptions.againstFiles, + options: requestOptions.options, + ruleIDs: requestOptions.ruleIDs, + }, nil +} + +func (r *request) Files() []File { + return slices.Clone(r.files) +} + +func (r *request) AgainstFiles() []File { + return slices.Clone(r.againstFiles) +} + +func (r *request) Options() Options { + return r.options +} + +func (r *request) RuleIDs() []string { + return slices.Clone(r.ruleIDs) +} + +func (r *request) toProtos() ([]*checkv1.CheckRequest, error) { + if r == nil { + return nil, nil + } + protoFiles := xslices.Map(r.files, File.toProto) + protoAgainstFiles := xslices.Map(r.againstFiles, File.toProto) + protoOptions, err := r.options.toProto() + if err != nil { + 
return nil, err + } + if len(r.ruleIDs) == 0 { + return []*checkv1.CheckRequest{ + { + Files: protoFiles, + AgainstFiles: protoAgainstFiles, + Options: protoOptions, + }, + }, nil + } + var checkRequests []*checkv1.CheckRequest + for i := 0; i < len(r.ruleIDs); i += checkRuleIDPageSize { + start := i + end := start + checkRuleIDPageSize + if end > len(r.ruleIDs) { + end = len(r.ruleIDs) + } + checkRequests = append( + checkRequests, + &checkv1.CheckRequest{ + Files: protoFiles, + AgainstFiles: protoAgainstFiles, + Options: protoOptions, + RuleIds: r.ruleIDs[start:end], + }, + ) + } + return checkRequests, nil +} + +func (*request) isRequest() {} + +type requestOptions struct { + againstFiles []File + options Options + ruleIDs []string +} + +func newRequestOptions() *requestOptions { + return &requestOptions{} +} diff --git a/check/response.go b/check/response.go new file mode 100644 index 0000000..cbe66a9 --- /dev/null +++ b/check/response.go @@ -0,0 +1,59 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "slices" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" +) + +// Response is a response from a plugin for a check call. +type Response interface { + // Annotations returns all of the Annotations. + // + // The returned annotations will be sorted. 
+ Annotations() []Annotation + + toProto() *checkv1.CheckResponse + + isResponse() +} + +// *** PRIVATE *** + +type response struct { + annotations []Annotation +} + +func newResponse(annotations []Annotation) (*response, error) { + sortAnnotations(annotations) + return &response{ + annotations: annotations, + }, nil +} + +func (r *response) Annotations() []Annotation { + return slices.Clone(r.annotations) +} + +func (r *response) toProto() *checkv1.CheckResponse { + return &checkv1.CheckResponse{ + Annotations: xslices.Map(r.annotations, Annotation.toProto), + } +} + +func (*response) isResponse() {} diff --git a/check/response_writer.go b/check/response_writer.go new file mode 100644 index 0000000..6ad75a2 --- /dev/null +++ b/check/response_writer.go @@ -0,0 +1,347 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "errors" + "fmt" + "sync" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +var errCannotReuseResponseWriter = errors.New("cannot reuse ResponseWriter") + +// ResponseWriter is used by plugin implmentations to add Annotations to responses. +// +// A ResponseWriter is tied to a specific rule, and is passed to a RuleHandler. +// The ID of the Rule will be automatically populated for any added Annotations. +type ResponseWriter interface { + // AddAnnotation adds an Annotation with the rule ID that is tied to this ResponseWriter. 
+ // + // Fields of the Annotation are controlled with AddAnnotationOptions, of which there are several: + // + // - WithMessage/WithMessagef: Add a message to the Annotation. + // - WithDescriptor/WithAgainstDescriptor: Use the protoreflect.Descriptor to determine Location information. + // - WithFileName/WithAgainstFileName: Use the given file name on the Location. + // - WithFileNameAndSourcePath/WithAgainstFileNameAndSourcePath: Use the given explicit file name and source path on the Location. + // + // There are some rules to note when using AddAnnotationOptions: + // + // - Multiple calls of WithMessage/WithMessagef will overwrite previous calls. + // - You must either use WithDescriptor, or use WithFileName/WithSourcePath, but you cannot + // use these together. Location information is determined either from the Descriptor, or + // from explicit setting via WithFileName/WithFileNameAndSourcePath. Same applies to the Against equivalents. + // + // Don't worry, these rules are verified when building a Response. + // + // Most users will use WithDescriptor/WithAgainstDescriptor as opposed to their lower-level variants. + AddAnnotation(options ...AddAnnotationOption) + + isResponseWriter() +} + +// AddAnnotationOption is an option with adding an Annotation to a ResponseWriter. +type AddAnnotationOption func(*addAnnotationOptions) + +// WithMessage sets the message on the Annotation. +// +// If there are multiple calls to WithMessage or WithMessagef, the last one wins. +func WithMessage(message string) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.message = message + } +} + +// WithMessagef sets the message on the Annotation. +// +// If there are multiple calls to WithMessage or WithMessagef, the last one wins. +func WithMessagef(format string, args ...any) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.message = fmt.Sprintf(format, args...) 
+ } +} + +// WithDescriptor will set the Location on the Annotation by extracting file and source path +// information from the descriptor itself. +// +// It is not valid to use WithDescriptor if also using either WithFileName or WithSourcePath. +func WithDescriptor(descriptor protoreflect.Descriptor) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.descriptor = descriptor + } +} + +// WithFileName will set the FileName on the Annotation's Location directly. +// +// Typically, most users will use WithDescriptor to accomplish this task. +// +// This will not set any line/column information. To do so, use WithFileNameAndSourcePath. +// +// It is not valid to use WithDescriptor if also using either WithFileName +// or WithFileNameAndSourcePath. +func WithFileName(fileName string) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.fileName = fileName + } +} + +// WithFileNameAndSourcePath will set the SourcePath on the Annotation's Location directly. +// +// Typically, most users will use WithDescriptor to accomplish this task. +// +// It is not valid to use WithDescriptor if also using either WithFileName +// or WithFileNameAndSourcePath. +func WithFileNameAndSourcePath(fileName string, sourcePath protoreflect.SourcePath) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.fileName = fileName + addAnnotationOptions.sourcePath = sourcePath + } +} + +// WithAgainstDescriptor will set the AgainstLocation on the Annotation by extracting file and +// source path information from the descriptor itself. +// +// It is not valid to use WithAgainstDescriptor if also using either WithAgainstFileName or +// WithAgainstSourcePath. 
+func WithAgainstDescriptor(againstDescriptor protoreflect.Descriptor) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.againstDescriptor = againstDescriptor + } +} + +// WithAgainstFileName will set the FileName on the Annotation's AgainstLocation directly. +// +// Typically, most users will use WithAgainstDescriptor to accomplish this task. +// +// This will not set any line/column information. To do so, use WithAgainstFileNameAndSourcePath. +// +// It is not valid to use WithAgainstDescriptor if also using either WithAgainstFileName or +// WithAgainstFileNameAndSourcePath. +func WithAgainstFileName(againstFileName string) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.againstFileName = againstFileName + } +} + +// WithAgainstFileNameAndSourcePath will set the Filename and SourcePath on the +// Annotation's AgainstLocation directly. +// +// Typically, most users will use WithAgainstDescriptor to accomplish this task. +// +// It is not valid to use WithAgainstDescriptor if also using either WithAgainstFileName or +// WithAgainstFileNameAndSourcePath. +func WithAgainstFileNameAndSourcePath(againstFileName string, againstSourcePath protoreflect.SourcePath) AddAnnotationOption { + return func(addAnnotationOptions *addAnnotationOptions) { + addAnnotationOptions.againstFileName = againstFileName + addAnnotationOptions.againstSourcePath = againstSourcePath + } +} + +// *** PRIVATE *** + +// multiResponseWriter is a ResponseWriter that can be used for multiple IDs. It differs +// from a ResponseWriter in that an ID must be provided to addAnnotation. A multiResponseWriter +// itself creates ResponseWriters. +// +// multiResponseWriter is used by checkClients and checkServiceHandlers. 
+type multiResponseWriter struct { + fileNameToFile map[string]File + againstFileNameToFile map[string]File + + annotations []Annotation + written bool + errs []error + lock sync.RWMutex +} + +func newMultiResponseWriter(request Request) (*multiResponseWriter, error) { + fileNameToFile, err := fileNameToFileForFiles(request.Files()) + if err != nil { + return nil, err + } + againstFileNameToFile, err := fileNameToFileForFiles(request.AgainstFiles()) + if err != nil { + return nil, err + } + return &multiResponseWriter{ + fileNameToFile: fileNameToFile, + againstFileNameToFile: againstFileNameToFile, + }, nil +} + +func (m *multiResponseWriter) newResponseWriter(id string) *responseWriter { + return newResponseWriter(m, id) +} + +func (m *multiResponseWriter) addAnnotation( + ruleID string, + options ...AddAnnotationOption, +) { + addAnnotationOptions := newAddAnnotationOptions() + for _, option := range options { + option(addAnnotationOptions) + } + + m.lock.Lock() + defer m.lock.Unlock() + + if err := validateAddAnnotationOptions(addAnnotationOptions); err != nil { + m.errs = append(m.errs, err) + return + } + + if m.written { + m.errs = append(m.errs, errCannotReuseResponseWriter) + return + } + + location, err := getLocationForAddAnnotationOptions( + m.fileNameToFile, + addAnnotationOptions.descriptor, + addAnnotationOptions.fileName, + addAnnotationOptions.sourcePath, + ) + if err != nil { + m.errs = append(m.errs, err) + return + } + againstLocation, err := getLocationForAddAnnotationOptions( + m.againstFileNameToFile, + addAnnotationOptions.againstDescriptor, + addAnnotationOptions.againstFileName, + addAnnotationOptions.againstSourcePath, + ) + if err != nil { + m.errs = append(m.errs, err) + return + } + annotation, err := newAnnotation( + ruleID, + addAnnotationOptions.message, + location, + againstLocation, + ) + if err != nil { + m.errs = append(m.errs, err) + return + } + + m.annotations = append(m.annotations, annotation) +} + +func (m 
*multiResponseWriter) toResponse() (Response, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + if len(m.errs) > 0 { + return nil, errors.Join(m.errs...) + } + if m.written { + return nil, errCannotReuseResponseWriter + } + m.written = true + + return newResponse(m.annotations) +} + +type responseWriter struct { + multiResponseWriter *multiResponseWriter + id string +} + +func newResponseWriter( + multiResponseWriter *multiResponseWriter, + id string, +) *responseWriter { + return &responseWriter{ + multiResponseWriter: multiResponseWriter, + id: id, + } +} + +func (r *responseWriter) AddAnnotation( + options ...AddAnnotationOption, +) { + r.multiResponseWriter.addAnnotation(r.id, options...) +} + +func (*responseWriter) isResponseWriter() {} + +type addAnnotationOptions struct { + message string + descriptor protoreflect.Descriptor + againstDescriptor protoreflect.Descriptor + fileName string + sourcePath protoreflect.SourcePath + againstFileName string + againstSourcePath protoreflect.SourcePath +} + +func newAddAnnotationOptions() *addAnnotationOptions { + return &addAnnotationOptions{} +} + +func validateAddAnnotationOptions(addAnnotationOptions *addAnnotationOptions) error { + if addAnnotationOptions.descriptor != nil && + (addAnnotationOptions.fileName != "" || len(addAnnotationOptions.sourcePath) > 0) { + return errors.New("cannot call both WithDescriptor and WithFileName or WithFileNameAndSourcePath") + } + if addAnnotationOptions.againstDescriptor != nil && + (addAnnotationOptions.againstFileName != "" || len(addAnnotationOptions.againstSourcePath) > 0) { + return errors.New("cannot call both WithAgainstDescriptor and WithAgainstFileName or WithAgainstFileNameAndSourcePath") + } + if addAnnotationOptions.fileName == "" && len(addAnnotationOptions.sourcePath) > 0 { + return errors.New("must set a non-empty FileName when calling WithFileNameAndSourcePath") + } + if addAnnotationOptions.againstFileName == "" && len(addAnnotationOptions.againstSourcePath) 
> 0 { + return errors.New("must set a non-empty FileName when calling WithAgainstFileNameAndSourcePath") + } + return nil +} + +func getLocationForAddAnnotationOptions( + fileNameToFile map[string]File, + descriptor protoreflect.Descriptor, + fileName string, + path protoreflect.SourcePath, +) (Location, error) { + if descriptor != nil { + // Technically, ParentFile() can be nil. + if fileDescriptor := descriptor.ParentFile(); fileDescriptor != nil { + file, ok := fileNameToFile[fileDescriptor.Path()] + if !ok { + return nil, fmt.Errorf("cannot add annotation for unknown file: %q", fileDescriptor.Path()) + } + return newLocation( + file, + fileDescriptor.SourceLocations().ByDescriptor(descriptor), + ), nil + } + return nil, nil + } + if fileName != "" { + var sourceLocation protoreflect.SourceLocation + file, ok := fileNameToFile[fileName] + if !ok { + return nil, fmt.Errorf("cannot add annotation for unknown file: %q", fileName) + } + if len(path) > 0 { + sourceLocation = file.FileDescriptor().SourceLocations().ByPath(path) + } + return newLocation(file, sourceLocation), nil + } + return nil, nil +} diff --git a/check/rule.go b/check/rule.go new file mode 100644 index 0000000..fc4722d --- /dev/null +++ b/check/rule.go @@ -0,0 +1,236 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package check + +import ( + "errors" + "fmt" + "slices" + "sort" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" +) + +// Rule is a single lint or breaking change rule. +// +// Rules have unique IDs. On the server-side (i.e. the plugin), Rules are created +// by RuleSpecs. Clients can list all available plugin Rules by calling ListRules. +type Rule interface { + // ID is the ID of the Rule. + // + // Always present. + // + // This uniquely identifies the Rule. + ID() string + // The categories that the Rule is a part of. + // + // Optional. + // + // Buf uses categories to include or exclude sets of rules via configuration. + Categories() []Category + // Whether or not the Rule is a default Rule. + // + // If a Rule is a default Rule, it will be called if a Request specifies no specific Rule IDs. + // + // A deprecated rule cannot be a default rule. + Default() bool + // A user-displayable purpose of the rule. + // + // Always present. + // + // This should be a proper sentence that starts with a capital letter and ends in a period. + Purpose() string + // Type is the type of the Rule. + Type() RuleType + // Deprecated returns whether or not this Rule is deprecated. + // + // If the Rule is deprecated, it may be replaced by 0 or more Rules. These will be denoted + // by ReplacementIDs. + Deprecated() bool + // ReplacementIDs returns the IDs of the Rules that replace this Rule, if this Rule is deprecated. + // + // This means that the combination of the Rules specified by ReplacementIDs replace this Rule entirely, + // and this Rule is considered equivalent to the AND of the rules specified by ReplacementIDs. + // + // This will only be non-empty if Deprecated is true. + // + // It is not valid for a deprecated Rule to specfiy another deprecated Rule as a replacement. 
+ ReplacementIDs() []string + + toProto() *checkv1.Rule + + isRule() +} + +// *** PRIVATE *** + +type rule struct { + id string + categories []Category + isDefault bool + purpose string + ruleType RuleType + deprecated bool + replacementIDs []string +} + +func newRule( + id string, + categories []Category, + isDefault bool, + purpose string, + ruleType RuleType, + deprecated bool, + replacementIDs []string, +) (*rule, error) { + if id == "" { + return nil, errors.New("check.Rule: ID is empty") + } + if purpose == "" { + return nil, errors.New("check.Rule: ID is empty") + } + if isDefault && deprecated { + return nil, errors.New("check.Rule: Default and Deprecated are true") + } + if !deprecated && len(replacementIDs) > 0 { + return nil, fmt.Errorf("check.Rule: Deprecated is false but ReplacementIDs %v specified", replacementIDs) + } + return &rule{ + id: id, + categories: categories, + isDefault: isDefault, + purpose: purpose, + ruleType: ruleType, + deprecated: deprecated, + replacementIDs: replacementIDs, + }, nil +} + +func (r *rule) ID() string { + return r.id +} + +func (r *rule) Categories() []Category { + return slices.Clone(r.categories) +} + +func (r *rule) Default() bool { + return r.isDefault +} + +func (r *rule) Purpose() string { + return r.purpose +} + +func (r *rule) Type() RuleType { + return r.ruleType +} + +func (r *rule) Deprecated() bool { + return r.deprecated +} + +func (r *rule) ReplacementIDs() []string { + return slices.Clone(r.replacementIDs) +} + +func (r *rule) toProto() *checkv1.Rule { + if r == nil { + return nil + } + protoRuleType := ruleTypeToProtoRuleType[r.ruleType] + return &checkv1.Rule{ + Id: r.id, + CategoryIds: xslices.Map(r.categories, Category.ID), + Default: r.isDefault, + Purpose: r.purpose, + Type: protoRuleType, + Deprecated: r.deprecated, + ReplacementIds: r.replacementIDs, + } +} + +func (*rule) isRule() {} + +func ruleForProtoRule(protoRule *checkv1.Rule, idToCategory map[string]Category) (Rule, error) { + 
categories, err := xslices.MapError( + protoRule.GetCategoryIds(), + func(id string) (Category, error) { + category, ok := idToCategory[id] + if !ok { + return nil, fmt.Errorf("no category for ID %q", id) + } + return category, nil + }, + ) + if err != nil { + return nil, err + } + ruleType := protoRuleTypeToRuleType[protoRule.GetType()] + return newRule( + protoRule.GetId(), + categories, + protoRule.GetDefault(), + protoRule.GetPurpose(), + ruleType, + protoRule.GetDeprecated(), + protoRule.GetReplacementIds(), + ) +} + +func sortRules(rules []Rule) { + sort.Slice(rules, func(i int, j int) bool { return CompareRules(rules[i], rules[j]) < 0 }) +} + +func validateRules(rules []Rule) error { + return validateNoDuplicateRuleIDs(xslices.Map(rules, Rule.ID)) +} + +func validateNoDuplicateRuleIDs(ids []string) error { + idToCount := make(map[string]int, len(ids)) + for _, id := range ids { + idToCount[id]++ + } + var duplicateIDs []string + for id, count := range idToCount { + if count > 1 { + duplicateIDs = append(duplicateIDs, id) + } + } + if len(duplicateIDs) > 0 { + sort.Strings(duplicateIDs) + return newDuplicateRuleIDError(duplicateIDs) + } + return nil +} + +func validateNoDuplicateRuleOrCategoryIDs(ids []string) error { + idToCount := make(map[string]int, len(ids)) + for _, id := range ids { + idToCount[id]++ + } + var duplicateIDs []string + for id, count := range idToCount { + if count > 1 { + duplicateIDs = append(duplicateIDs, id) + } + } + if len(duplicateIDs) > 0 { + sort.Strings(duplicateIDs) + return newDuplicateRuleOrCategoryIDError(duplicateIDs) + } + return nil +} diff --git a/check/rule_handler.go b/check/rule_handler.go new file mode 100644 index 0000000..39b6856 --- /dev/null +++ b/check/rule_handler.go @@ -0,0 +1,36 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "context" +) + +var nopRuleHandler = RuleHandlerFunc(func(context.Context, ResponseWriter, Request) error { return nil }) + +// RuleHandler implements the check logic for a single Rule. +// +// A RuleHandler takes in a Request, and writes Annotations to the ResponseWriter. +type RuleHandler interface { + Handle(ctx context.Context, responseWriter ResponseWriter, request Request) error +} + +// RuleHandlerFunc is a function that implements RuleHandler. +type RuleHandlerFunc func(context.Context, ResponseWriter, Request) error + +// Handle implements RuleHandler. +func (r RuleHandlerFunc) Handle(ctx context.Context, responseWriter ResponseWriter, request Request) error { + return r(ctx, responseWriter, request) +} diff --git a/check/rule_spec.go b/check/rule_spec.go new file mode 100644 index 0000000..3f8e22f --- /dev/null +++ b/check/rule_spec.go @@ -0,0 +1,166 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package check + +import ( + "errors" + "fmt" + "regexp" + "sort" + + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" +) + +const ( + idMinLen = 3 + idMaxLen = 64 +) + +var ( + idRegexp = regexp.MustCompile("^[A-Z0-9][A-Z0-9_]*[A-Z0-9]$") + purposeRegexp = regexp.MustCompile("^[A-Z].*[.]$") +) + +// RuleSpec is the spec for a Rule. +// +// It is used to construct a Rule on the server-side (i.e. within the plugin). It specifies the +// ID, categories, purpose, type, and a RuleHandler to actually run the Rule logic. +// +// Generally, these are provided to Main. This library will handle Check and ListRules calls +// based on the provided RuleSpecs. +type RuleSpec struct { + // Required. + ID string + CategoryIDs []string + Default bool + // Required. + Purpose string + // Required. + Type RuleType + Deprecated bool + ReplacementIDs []string + // Required. + Handler RuleHandler +} + +// *** PRIVATE *** + +// Assumes that the RuleSpec is validated. +func ruleSpecToRule(ruleSpec *RuleSpec, idToCategory map[string]Category) (Rule, error) { + categories, err := xslices.MapError( + ruleSpec.CategoryIDs, + func(id string) (Category, error) { + category, ok := idToCategory[id] + if !ok { + return nil, fmt.Errorf("no category for id %q", id) + } + return category, nil + }, + ) + if err != nil { + return nil, err + } + return newRule( + ruleSpec.ID, + categories, + ruleSpec.Default, + ruleSpec.Purpose, + ruleSpec.Type, + ruleSpec.Deprecated, + ruleSpec.ReplacementIDs, + ) +} + +func validateRuleSpecs( + ruleSpecs []*RuleSpec, + categoryIDMap map[string]struct{}, +) error { + ruleIDs := xslices.Map(ruleSpecs, func(ruleSpec *RuleSpec) string { return ruleSpec.ID }) + if err := validateNoDuplicateRuleIDs(ruleIDs); err != nil { + return err + } + ruleIDToRuleSpec := make(map[string]*RuleSpec) + for _, ruleSpec := range ruleSpecs { + if err := validateID(ruleSpec.ID); err != nil { + return wrapValidateRuleSpecError(err) + } + ruleIDToRuleSpec[ruleSpec.ID] = ruleSpec + } + 
for _, ruleSpec := range ruleSpecs {
+		for _, categoryID := range ruleSpec.CategoryIDs {
+			if _, ok := categoryIDMap[categoryID]; !ok {
+				return newValidateRuleSpecErrorf("no category has ID %q", categoryID)
+			}
+		}
+		if err := validatePurpose(ruleSpec.ID, ruleSpec.Purpose); err != nil {
+			return wrapValidateRuleSpecError(err)
+		}
+		if ruleSpec.Type == 0 {
+			return newValidateRuleSpecErrorf("Type is not set for ID %q", ruleSpec.ID)
+		}
+		if _, ok := ruleTypeToProtoRuleType[ruleSpec.Type]; !ok {
+			return newValidateRuleSpecErrorf("Type is unknown: %q", ruleSpec.Type)
+		}
+		if ruleSpec.Handler == nil {
+			return newValidateRuleSpecErrorf("Handler is not set for ID %q", ruleSpec.ID)
+		}
+		if ruleSpec.Default && ruleSpec.Deprecated {
+			// Both flags are set here; the old message incorrectly claimed
+			// "Deprecated was false". Mirrors the check in newRule.
+			return newValidateRuleSpecErrorf("ID %q was both a default Rule and Deprecated", ruleSpec.ID)
+		}
+		if len(ruleSpec.ReplacementIDs) > 0 && !ruleSpec.Deprecated {
+			return newValidateRuleSpecErrorf("ID %q had ReplacementIDs but Deprecated was false", ruleSpec.ID)
+		}
+		for _, replacementID := range ruleSpec.ReplacementIDs {
+			replacementRuleSpec, ok := ruleIDToRuleSpec[replacementID]
+			if !ok {
+				return newValidateRuleSpecErrorf("ID %q specified replacement ID %q which was not found", ruleSpec.ID, replacementID)
+			}
+			if replacementRuleSpec.Deprecated {
+				return newValidateRuleSpecErrorf("Deprecated ID %q specified replacement ID %q which is also deprecated", ruleSpec.ID, replacementID)
+			}
+		}
+	}
+	return nil
+}
+
+func sortRuleSpecs(ruleSpecs []*RuleSpec) {
+	sort.Slice(ruleSpecs, func(i int, j int) bool { return compareRuleSpecs(ruleSpecs[i], ruleSpecs[j]) < 0 })
+}
+
+// validateID enforces the ID shape: non-empty, within [idMinLen, idMaxLen],
+// and matching idRegexp (upper-case alphanumeric with interior underscores).
+func validateID(id string) error {
+	if id == "" {
+		return errors.New("ID is empty")
+	}
+	if len(id) < idMinLen {
+		return fmt.Errorf("ID %q must be at least length %d", id, idMinLen)
+	}
+	if len(id) > idMaxLen {
+		return fmt.Errorf("ID %q must be at most length %d", id, idMaxLen)
+	}
+	if !idRegexp.MatchString(id) {
+		return fmt.Errorf("ID %q does not match %q",
id, idRegexp.String()) + } + return nil +} + +func validatePurpose(id string, purpose string) error { + if purpose == "" { + return fmt.Errorf("Purpose is empty for ID %q", id) + } + if !purposeRegexp.MatchString(purpose) { + return fmt.Errorf("Purpose %q for ID %q does not match %q", purpose, id, purposeRegexp.String()) + } + return nil +} diff --git a/check/rule_type.go b/check/rule_type.go new file mode 100644 index 0000000..b4109cc --- /dev/null +++ b/check/rule_type.go @@ -0,0 +1,54 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "strconv" + + checkv1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" +) + +const ( + // RuleTypeLint is a lint Rule. + RuleTypeLint RuleType = 1 + // RuleTypeBreaking is a breaking change Rule. + RuleTypeBreaking RuleType = 2 +) + +var ( + ruleTypeToString = map[RuleType]string{ + RuleTypeLint: "lint", + RuleTypeBreaking: "breaking", + } + ruleTypeToProtoRuleType = map[RuleType]checkv1.RuleType{ + RuleTypeLint: checkv1.RuleType_RULE_TYPE_LINT, + RuleTypeBreaking: checkv1.RuleType_RULE_TYPE_BREAKING, + } + protoRuleTypeToRuleType = map[checkv1.RuleType]RuleType{ + checkv1.RuleType_RULE_TYPE_LINT: RuleTypeLint, + checkv1.RuleType_RULE_TYPE_BREAKING: RuleTypeBreaking, + } +) + +// RuleType is the type of Rule. +type RuleType int + +// String implements fmt.Stringer. 
+func (t RuleType) String() string { + if s, ok := ruleTypeToString[t]; ok { + return s + } + return strconv.Itoa(int(t)) +} diff --git a/check/spec.go b/check/spec.go new file mode 100644 index 0000000..0df66a1 --- /dev/null +++ b/check/spec.go @@ -0,0 +1,73 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package check + +import ( + "context" + + "github.com/bufbuild/bufplugin-go/internal/pkg/xslices" +) + +// Spec is the spec for a plugin. +// +// It is used to construct a plugin on the server-side (i.e. within the plugin). +// +// Generally, this is provided to Main. This library will handle Check and ListRules calls +// based on the provided RuleSpecs. +type Spec struct { + // Required. + // + // All RuleSpecs must have Category IDs that match a CategorySpec within Categories. + // + // No IDs can overlap with Category IDs in Categories. + Rules []*RuleSpec + // Required if any RuleSpec specifies a category. + // + // All CategorySpecs must have an ID that matches at least one Category ID on a + // RuleSpec within Rules. + // + // No IDs can overlap with Rule IDs in Rules. + Categories []*CategorySpec + + // Before is a function that will be executed before any RuleHandlers are + // invoked that returns a new Context and Request. This new Context and + // Request will be passed to the RuleHandlers. This allows for any + // pre-processing that needs to occur. 
+ Before func(ctx context.Context, request Request) (context.Context, Request, error) +} + +// ValidateSpec validates all values on a Spec. +// +// This is exposed publicly so it can be run as part of plugin tests. This will verify +// that your Spec will result in a valid plugin. +func ValidateSpec(spec *Spec) error { + if len(spec.Rules) == 0 { + return newValidateSpecError("Rules is empty") + } + categoryIDs := xslices.Map(spec.Categories, func(categorySpec *CategorySpec) string { return categorySpec.ID }) + if err := validateNoDuplicateRuleOrCategoryIDs( + append( + xslices.Map(spec.Rules, func(ruleSpec *RuleSpec) string { return ruleSpec.ID }), + categoryIDs..., + ), + ); err != nil { + return wrapValidateSpecError(err) + } + categoryIDMap := xslices.ToStructMap(categoryIDs) + if err := validateRuleSpecs(spec.Rules, categoryIDMap); err != nil { + return err + } + return validateCategorySpecs(spec.Categories, spec.Rules) +} diff --git a/check/spec_test.go b/check/spec_test.go new file mode 100644 index 0000000..d3f85d9 --- /dev/null +++ b/check/spec_test.go @@ -0,0 +1,170 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package check + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateSpec(t *testing.T) { + t.Parallel() + + validateRuleSpecError := &validateRuleSpecError{} + validateCategorySpecError := &validateCategorySpecError{} + validateSpecError := &validateSpecError{} + + // Simple spec that passes validation. + spec := &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + testNewSimpleLintRuleSpec("RULE2", []string{"CATEGORY1"}, true, false, nil), + testNewSimpleLintRuleSpec("RULE3", []string{"CATEGORY1", "CATEGORY2"}, true, false, nil), + }, + Categories: []*CategorySpec{ + testNewSimpleCategorySpec("CATEGORY1", false, nil), + testNewSimpleCategorySpec("CATEGORY2", false, nil), + }, + } + require.NoError(t, ValidateSpec(spec)) + + // More complicated spec with deprecated rules and categories that passes validation. + spec = &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + testNewSimpleLintRuleSpec("RULE2", []string{"CATEGORY1"}, true, false, nil), + testNewSimpleLintRuleSpec("RULE3", []string{"CATEGORY1", "CATEGORY2"}, true, false, nil), + testNewSimpleLintRuleSpec("RULE4", []string{"CATEGORY1"}, false, true, []string{"RULE1"}), + testNewSimpleLintRuleSpec("RULE5", []string{"CATEGORY3", "CATEGORY4"}, false, true, []string{"RULE2", "RULE3"}), + }, + Categories: []*CategorySpec{ + testNewSimpleCategorySpec("CATEGORY1", false, nil), + testNewSimpleCategorySpec("CATEGORY2", false, nil), + testNewSimpleCategorySpec("CATEGORY3", true, []string{"CATEGORY1"}), + testNewSimpleCategorySpec("CATEGORY4", true, []string{"CATEGORY1", "CATEGORY2"}), + }, + } + require.NoError(t, ValidateSpec(spec)) + + // Spec that has rules with categories with no resulting category spec. 
+ spec = &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + testNewSimpleLintRuleSpec("RULE2", []string{"CATEGORY1"}, true, false, nil), + testNewSimpleLintRuleSpec("RULE3", []string{"CATEGORY1", "CATEGORY2"}, true, false, nil), + }, + Categories: []*CategorySpec{ + testNewSimpleCategorySpec("CATEGORY1", false, nil), + }, + } + require.ErrorAs(t, ValidateSpec(spec), &validateRuleSpecError) + + // Spec that has categories with no rules with those categories. + spec = &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + testNewSimpleLintRuleSpec("RULE2", []string{"CATEGORY1"}, true, false, nil), + testNewSimpleLintRuleSpec("RULE3", []string{"CATEGORY1", "CATEGORY2"}, true, false, nil), + }, + Categories: []*CategorySpec{ + testNewSimpleCategorySpec("CATEGORY1", false, nil), + testNewSimpleCategorySpec("CATEGORY2", false, nil), + testNewSimpleCategorySpec("CATEGORY3", false, nil), + testNewSimpleCategorySpec("CATEGORY4", false, nil), + }, + } + require.ErrorAs(t, ValidateSpec(spec), &validateCategorySpecError) + + // Spec that has overlapping rules and categories. + spec = &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + testNewSimpleLintRuleSpec("RULE2", []string{"CATEGORY1"}, true, false, nil), + testNewSimpleLintRuleSpec("RULE3", []string{"CATEGORY1", "CATEGORY2"}, true, false, nil), + }, + Categories: []*CategorySpec{ + testNewSimpleCategorySpec("CATEGORY1", false, nil), + testNewSimpleCategorySpec("CATEGORY2", false, nil), + testNewSimpleCategorySpec("RULE3", false, nil), + }, + } + require.ErrorAs(t, ValidateSpec(spec), &validateSpecError) + + // Spec that has deprecated rules that point to deprecated rules. 
+ spec = &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + testNewSimpleLintRuleSpec("RULE2", nil, false, true, []string{"RULE1"}), + testNewSimpleLintRuleSpec("RULE3", nil, false, true, []string{"RULE2"}), + }, + } + require.ErrorAs(t, ValidateSpec(spec), &validateRuleSpecError) + + // Spec that has deprecated rules that are defaults. + spec = &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, true, nil), + }, + } + require.ErrorAs(t, ValidateSpec(spec), &validateRuleSpecError) + + // Spec that has deprecated categories that point to deprecated categories. + spec = &Spec{ + Rules: []*RuleSpec{ + testNewSimpleLintRuleSpec("RULE1", nil, true, false, nil), + testNewSimpleLintRuleSpec("RULE2", []string{"CATEGORY1"}, true, false, nil), + testNewSimpleLintRuleSpec("RULE3", []string{"CATEGORY1", "CATEGORY2", "CATEGORY3"}, true, false, nil), + }, + Categories: []*CategorySpec{ + testNewSimpleCategorySpec("CATEGORY1", false, nil), + testNewSimpleCategorySpec("CATEGORY2", true, []string{"CATEGORY1"}), + testNewSimpleCategorySpec("CATEGORY3", true, []string{"CATEGORY2"}), + }, + } + require.ErrorAs(t, ValidateSpec(spec), &validateCategorySpecError) +} + +func testNewSimpleLintRuleSpec( + id string, + categoryIDs []string, + isDefault bool, + deprecated bool, + replacementIDs []string, +) *RuleSpec { + return &RuleSpec{ + ID: id, + CategoryIDs: categoryIDs, + Default: isDefault, + Purpose: "Checks " + id + ".", + Type: RuleTypeLint, + Deprecated: deprecated, + ReplacementIDs: replacementIDs, + Handler: nopRuleHandler, + } +} + +func testNewSimpleCategorySpec( + id string, + deprecated bool, + replacementIDs []string, +) *CategorySpec { + return &CategorySpec{ + ID: id, + Purpose: "Checks " + id + ".", + Deprecated: deprecated, + ReplacementIDs: replacementIDs, + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d6add0b --- /dev/null +++ b/go.mod @@ -0,0 +1,33 @@ +module 
github.com/bufbuild/bufplugin-go + +go 1.21 + +toolchain go1.23.0 + +require ( + buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go v1.34.2-20240904175752-5f2ce91228e8.2 + github.com/bufbuild/protocompile v0.14.1 + github.com/bufbuild/protovalidate-go v0.6.5 + github.com/stretchr/testify v1.9.0 + google.golang.org/protobuf v1.34.2 + pluginrpc.com/pluginrpc v0.2.0 +) + +require ( + buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.2-20240717164558-a6c49f84cc0f.2 // indirect + buf.build/gen/go/pluginrpc/pluginrpc/protocolbuffers/go v1.34.2-20240828222655-5345c0a56177.2 // indirect + github.com/antlr4-go/antlr/v4 v4.13.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/cel-go v0.21.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/stoewer/go-strcase v1.3.0 // indirect + golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f341670 --- /dev/null +++ b/go.sum @@ -0,0 +1,64 @@ +buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go v1.34.2-20240904175752-5f2ce91228e8.2 h1:EuphpPzJKitRQFq4KtGm0ie55KQhMYeG6QLQeNWMfWk= +buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go v1.34.2-20240904175752-5f2ce91228e8.2/go.mod h1:B+9TKHRYqoAUW57pLjhkLOnBCu0DQYMV+f7imQ9nXwI= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.2-20240717164558-a6c49f84cc0f.2 h1:SZRVx928rbYZ6hEKUIN+vtGDkl7uotABRWGY4OAg5gM= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go 
v1.34.2-20240717164558-a6c49f84cc0f.2/go.mod h1:ylS4c28ACSI59oJrOdW4pHS4n0Hw4TgSPHn8rpHl4Yw= +buf.build/gen/go/pluginrpc/pluginrpc/protocolbuffers/go v1.34.2-20240828222655-5345c0a56177.2 h1:oSi+Adw4xvIjXrW8eY8QGR3sBdfWeY5HN/RefnRt52M= +buf.build/gen/go/pluginrpc/pluginrpc/protocolbuffers/go v1.34.2-20240828222655-5345c0a56177.2/go.mod h1:GjH0gjlY/ns16X8d6eaXV2W+6IFwsO5Ly9WVnzyd1E0= +github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= +github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= +github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= +github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= +github.com/bufbuild/protovalidate-go v0.6.5 h1:WucDKXIbK22WjkO8A8J6Yyxxy0jl91Oe9LSMduq3YEE= +github.com/bufbuild/protovalidate-go v0.6.5/go.mod h1:LHDiGCWSM3GagZEnyEZ1sPtFwi6Ja4tVTi/DCc+iDFI= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= +github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= +github.com/google/cel-go v0.21.0 h1:cl6uW/gxN+Hy50tNYvI691+sXxioCnstFzLp2WO4GCI= +github.com/google/cel-go v0.21.0/go.mod h1:rHUlWCcBKgyEk+eV03RPdZUekPp6YcJwV0FxuUksYxc= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stoewer/go-strcase v1.3.0 h1:g0eASXYtp+yvN9fK8sH94oCIk0fau9uV1/ZdJ0AVEzs= +github.com/stoewer/go-strcase v1.3.0/go.mod h1:fAH5hQ5pehh+j3nZfvwdk2RgEgQjAoM8wodgtPmh1xo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.25.0 
h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 h1:hjSy6tcFQZ171igDaN5QHOw2n6vx40juYbC/x67CEhc= +google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:qpvKtACPCQhAdu3PyQgV4l3LMXZEtft7y8QcarRsp9I= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +pluginrpc.com/pluginrpc v0.2.0 h1:mUuxA2Vtt1/buDsnR1HscuAu56Y/3ax5oPPy+9q/Zr4= +pluginrpc.com/pluginrpc v0.2.0/go.mod h1:rX3qwV56YEwfayfyfEovbQ+KMVDjgJ8icHy0WTaUXRY= diff --git a/internal/gen/buf/plugin/check/v1/v1pluginrpc/check_service.pluginrpc.go b/internal/gen/buf/plugin/check/v1/v1pluginrpc/check_service.pluginrpc.go new file mode 100644 index 0000000..d9c901e --- /dev/null +++ 
b/internal/gen/buf/plugin/check/v1/v1pluginrpc/check_service.pluginrpc.go @@ -0,0 +1,219 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-pluginrpc-go. DO NOT EDIT. +// +// Source: buf/plugin/check/v1/check_service.proto + +package v1pluginrpc + +import ( + v1 "buf.build/gen/go/bufbuild/bufplugin/protocolbuffers/go/buf/plugin/check/v1" + context "context" + fmt "fmt" + pluginrpc "pluginrpc.com/pluginrpc" +) + +// This is a compile-time assertion to ensure that this generated file and the pluginrpc package are +// compatible. If you get a compiler error that this constant is not defined, this code was +// generated with a version of pluginrpc newer than the one compiled into your binary. You can fix +// the problem by either regenerating this code with an older version of pluginrpc or updating the +// pluginrpc version compiled into your binary. +const _ = pluginrpc.IsAtLeastVersion0_1_0 + +const ( + // CheckServiceCheckPath is the path of the CheckService's Check RPC. + CheckServiceCheckPath = "/buf.plugin.check.v1.CheckService/Check" + // CheckServiceListRulesPath is the path of the CheckService's ListRules RPC. + CheckServiceListRulesPath = "/buf.plugin.check.v1.CheckService/ListRules" + // CheckServiceListCategoriesPath is the path of the CheckService's ListCategories RPC. 
+ CheckServiceListCategoriesPath = "/buf.plugin.check.v1.CheckService/ListCategories" +) + +// CheckServiceSpecBuilder builds a Spec for the buf.plugin.check.v1.CheckService service. +type CheckServiceSpecBuilder struct { + Check []pluginrpc.ProcedureOption + ListRules []pluginrpc.ProcedureOption + ListCategories []pluginrpc.ProcedureOption +} + +// Build builds a Spec for the buf.plugin.check.v1.CheckService service. +func (s CheckServiceSpecBuilder) Build() (pluginrpc.Spec, error) { + procedures := make([]pluginrpc.Procedure, 0, 3) + procedure, err := pluginrpc.NewProcedure(CheckServiceCheckPath, s.Check...) + if err != nil { + return nil, err + } + procedures = append(procedures, procedure) + procedure, err = pluginrpc.NewProcedure(CheckServiceListRulesPath, s.ListRules...) + if err != nil { + return nil, err + } + procedures = append(procedures, procedure) + procedure, err = pluginrpc.NewProcedure(CheckServiceListCategoriesPath, s.ListCategories...) + if err != nil { + return nil, err + } + procedures = append(procedures, procedure) + return pluginrpc.NewSpec(procedures) +} + +// CheckServiceClient is a client for the buf.plugin.check.v1.CheckService service. +type CheckServiceClient interface { + // Check a set of Files for failures. + // + // All Annotations returned will have an ID that is contained within a Rule listed by ListRules. + Check(context.Context, *v1.CheckRequest, ...pluginrpc.CallOption) (*v1.CheckResponse, error) + // List all rules that this service implements. + ListRules(context.Context, *v1.ListRulesRequest, ...pluginrpc.CallOption) (*v1.ListRulesResponse, error) + // List all categories that this service implements. + ListCategories(context.Context, *v1.ListCategoriesRequest, ...pluginrpc.CallOption) (*v1.ListCategoriesResponse, error) +} + +// NewCheckServiceClient constructs a client for the buf.plugin.check.v1.CheckService service. 
+func NewCheckServiceClient(client pluginrpc.Client) (CheckServiceClient, error) { + return &checkServiceClient{ + client: client, + }, nil +} + +// CheckServiceHandler is an implementation of the buf.plugin.check.v1.CheckService service. +type CheckServiceHandler interface { + // Check a set of Files for failures. + // + // All Annotations returned will have an ID that is contained within a Rule listed by ListRules. + Check(context.Context, *v1.CheckRequest) (*v1.CheckResponse, error) + // List all rules that this service implements. + ListRules(context.Context, *v1.ListRulesRequest) (*v1.ListRulesResponse, error) + // List all categories that this service implements. + ListCategories(context.Context, *v1.ListCategoriesRequest) (*v1.ListCategoriesResponse, error) +} + +// CheckServiceServer serves the buf.plugin.check.v1.CheckService service. +type CheckServiceServer interface { + // Check a set of Files for failures. + // + // All Annotations returned will have an ID that is contained within a Rule listed by ListRules. + Check(context.Context, pluginrpc.HandleEnv, ...pluginrpc.HandleOption) error + // List all rules that this service implements. + ListRules(context.Context, pluginrpc.HandleEnv, ...pluginrpc.HandleOption) error + // List all categories that this service implements. + ListCategories(context.Context, pluginrpc.HandleEnv, ...pluginrpc.HandleOption) error +} + +// NewCheckServiceServer constructs a server for the buf.plugin.check.v1.CheckService service. +func NewCheckServiceServer(handler pluginrpc.Handler, checkServiceHandler CheckServiceHandler) CheckServiceServer { + return &checkServiceServer{ + handler: handler, + checkServiceHandler: checkServiceHandler, + } +} + +// RegisterCheckServiceServer registers the server for the buf.plugin.check.v1.CheckService service. 
+func RegisterCheckServiceServer(serverRegistrar pluginrpc.ServerRegistrar, checkServiceServer CheckServiceServer) { + serverRegistrar.Register(CheckServiceCheckPath, checkServiceServer.Check) + serverRegistrar.Register(CheckServiceListRulesPath, checkServiceServer.ListRules) + serverRegistrar.Register(CheckServiceListCategoriesPath, checkServiceServer.ListCategories) +} + +// *** PRIVATE *** + +// checkServiceClient implements CheckServiceClient. +type checkServiceClient struct { + client pluginrpc.Client +} + +// Check calls buf.plugin.check.v1.CheckService.Check. +func (c *checkServiceClient) Check(ctx context.Context, req *v1.CheckRequest, opts ...pluginrpc.CallOption) (*v1.CheckResponse, error) { + res := &v1.CheckResponse{} + if err := c.client.Call(ctx, CheckServiceCheckPath, req, res, opts...); err != nil { + return nil, err + } + return res, nil +} + +// ListRules calls buf.plugin.check.v1.CheckService.ListRules. +func (c *checkServiceClient) ListRules(ctx context.Context, req *v1.ListRulesRequest, opts ...pluginrpc.CallOption) (*v1.ListRulesResponse, error) { + res := &v1.ListRulesResponse{} + if err := c.client.Call(ctx, CheckServiceListRulesPath, req, res, opts...); err != nil { + return nil, err + } + return res, nil +} + +// ListCategories calls buf.plugin.check.v1.CheckService.ListCategories. +func (c *checkServiceClient) ListCategories(ctx context.Context, req *v1.ListCategoriesRequest, opts ...pluginrpc.CallOption) (*v1.ListCategoriesResponse, error) { + res := &v1.ListCategoriesResponse{} + if err := c.client.Call(ctx, CheckServiceListCategoriesPath, req, res, opts...); err != nil { + return nil, err + } + return res, nil +} + +// checkServiceServer implements CheckServiceServer. +type checkServiceServer struct { + handler pluginrpc.Handler + checkServiceHandler CheckServiceHandler +} + +// Check calls buf.plugin.check.v1.CheckService.Check. 
+func (c *checkServiceServer) Check(ctx context.Context, handleEnv pluginrpc.HandleEnv, options ...pluginrpc.HandleOption) error { + return c.handler.Handle( + ctx, + handleEnv, + &v1.CheckRequest{}, + func(ctx context.Context, anyReq any) (any, error) { + req, ok := anyReq.(*v1.CheckRequest) + if !ok { + return nil, fmt.Errorf("could not cast %T to a *v1.CheckRequest", anyReq) + } + return c.checkServiceHandler.Check(ctx, req) + }, + options..., + ) +} + +// ListRules calls buf.plugin.check.v1.CheckService.ListRules. +func (c *checkServiceServer) ListRules(ctx context.Context, handleEnv pluginrpc.HandleEnv, options ...pluginrpc.HandleOption) error { + return c.handler.Handle( + ctx, + handleEnv, + &v1.ListRulesRequest{}, + func(ctx context.Context, anyReq any) (any, error) { + req, ok := anyReq.(*v1.ListRulesRequest) + if !ok { + return nil, fmt.Errorf("could not cast %T to a *v1.ListRulesRequest", anyReq) + } + return c.checkServiceHandler.ListRules(ctx, req) + }, + options..., + ) +} + +// ListCategories calls buf.plugin.check.v1.CheckService.ListCategories. +func (c *checkServiceServer) ListCategories(ctx context.Context, handleEnv pluginrpc.HandleEnv, options ...pluginrpc.HandleOption) error { + return c.handler.Handle( + ctx, + handleEnv, + &v1.ListCategoriesRequest{}, + func(ctx context.Context, anyReq any) (any, error) { + req, ok := anyReq.(*v1.ListCategoriesRequest) + if !ok { + return nil, fmt.Errorf("could not cast %T to a *v1.ListCategoriesRequest", anyReq) + } + return c.checkServiceHandler.ListCategories(ctx, req) + }, + options..., + ) +} diff --git a/internal/pkg/cache/singleton.go b/internal/pkg/cache/singleton.go new file mode 100644 index 0000000..6b58cf2 --- /dev/null +++ b/internal/pkg/cache/singleton.go @@ -0,0 +1,67 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "errors" + "sync" +) + +// Singleton is a singleton. +// +// It must be constructed with NewSingleton. +type Singleton[V any] struct { + get func(context.Context) (V, error) + value V + err error + // Storing a bool to not deal with generic zero/nil comparisons. + called bool + lock sync.RWMutex +} + +// NewSingleton returns a new Singleton. +// +// The get function must only return the zero value of V on error. +func NewSingleton[V any](get func(context.Context) (V, error)) *Singleton[V] { + return &Singleton[V]{ + get: get, + } +} + +// Get gets the value, or returns the error in loading the value. +// +// The given context will be used to load the value if not already loaded. +// +// If Singletons call Singletons, lock ordering must be respected. +func (s *Singleton[V]) Get(ctx context.Context) (V, error) { + if s.get == nil { + var zero V + return zero, errors.New("must create singleton with NewSingleton and a non-nil get function") + } + s.lock.RLock() + if s.called { + s.lock.RUnlock() + return s.value, s.err + } + s.lock.RUnlock() + s.lock.Lock() + defer s.lock.Unlock() + if !s.called { + s.value, s.err = s.get(ctx) + s.called = true + } + return s.value, s.err +} diff --git a/internal/pkg/cache/singleton_test.go b/internal/pkg/cache/singleton_test.go new file mode 100644 index 0000000..b625d9c --- /dev/null +++ b/internal/pkg/cache/singleton_test.go @@ -0,0 +1,57 @@ +// Copyright 2024 Buf Technologies, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBasic(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + var count int + singleton := NewSingleton( + func(context.Context) (int, error) { + count++ + return count, nil + }, + ) + value, err := singleton.Get(ctx) + require.NoError(t, err) + require.Equal(t, 1, value) + value, err = singleton.Get(ctx) + require.NoError(t, err) + require.Equal(t, 1, value) + + count = 0 + singleton = NewSingleton( + func(context.Context) (int, error) { + count++ + return 0, fmt.Errorf("%d", count) + }, + ) + _, err = singleton.Get(ctx) + require.Error(t, err) + require.Equal(t, "1", err.Error()) + _, err = singleton.Get(ctx) + require.Error(t, err) + require.Equal(t, "1", err.Error()) +} diff --git a/internal/pkg/thread/thread.go b/internal/pkg/thread/thread.go new file mode 100644 index 0000000..ccd858b --- /dev/null +++ b/internal/pkg/thread/thread.go @@ -0,0 +1,126 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package thread + +import ( + "context" + "errors" + "runtime" + "sync" +) + +var defaultParallelism = runtime.GOMAXPROCS(0) + +// Parallelize runs the jobs in parallel. +// +// Returns the combined error from the jobs. +func Parallelize(ctx context.Context, jobs []func(context.Context) error, options ...ParallelizeOption) error { + parallelizeOptions := newParallelizeOptions() + for _, option := range options { + option(parallelizeOptions) + } + switch len(jobs) { + case 0: + return nil + case 1: + return jobs[0](ctx) + } + parallelism := parallelizeOptions.parallelism + if parallelism < 1 { + parallelism = defaultParallelism + } + var cancel context.CancelFunc + if parallelizeOptions.cancelOnFailure { + ctx, cancel = context.WithCancel(ctx) + defer cancel() + } + semaphoreC := make(chan struct{}, parallelism) + var retErr error + var wg sync.WaitGroup + var lock sync.Mutex + var stop bool + for _, job := range jobs { + if stop { + break + } + // We always want context cancellation/deadline expiration to take + // precedence over the semaphore unblocking, but select statements choose + // among the unblocked non-default cases pseudorandomly. To correctly + // enforce precedence, use a similar pattern to the check-lock-check + // pattern common with sync.RWMutex: check the context twice, and only do + // the semaphore-protected work in the innermost default case. 
+ select { + case <-ctx.Done(): + stop = true + retErr = errors.Join(retErr, ctx.Err()) + case semaphoreC <- struct{}{}: + select { + case <-ctx.Done(): + stop = true + retErr = errors.Join(retErr, ctx.Err()) + default: + job := job + wg.Add(1) + go func() { + if err := job(ctx); err != nil { + lock.Lock() + retErr = errors.Join(retErr, err) + lock.Unlock() + if cancel != nil { + cancel() + } + } + // This will never block. + <-semaphoreC + wg.Done() + }() + } + } + } + wg.Wait() + return retErr +} + +// ParallelizeOption is an option to Parallelize. +type ParallelizeOption func(*parallelizeOptions) + +// WithParallelism returns a new ParallelizeOption that will run up to the given +// number of goroutines simultaneously. +// +// Values less than 1 are ignored. +// +// The default is runtime.GOMAXPROCS(0). +func WithParallelism(parallelism int) ParallelizeOption { + return func(parallelizeOptions *parallelizeOptions) { + parallelizeOptions.parallelism = parallelism + } +} + +// ParallelizeWithCancelOnFailure returns a new ParallelizeOption that will attempt +// to cancel all other jobs via context cancellation if any job fails. +func ParallelizeWithCancelOnFailure() ParallelizeOption { + return func(parallelizeOptions *parallelizeOptions) { + parallelizeOptions.cancelOnFailure = true + } +} + +type parallelizeOptions struct { + parallelism int + cancelOnFailure bool +} + +func newParallelizeOptions() *parallelizeOptions { + return ¶llelizeOptions{} +} diff --git a/internal/pkg/thread/thread_test.go b/internal/pkg/thread/thread_test.go new file mode 100644 index 0000000..4360634 --- /dev/null +++ b/internal/pkg/thread/thread_test.go @@ -0,0 +1,67 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package thread + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +// The bulk of the code relies on subtle timing that's difficult to +// reproduce, but we can test the most basic use cases. + +func TestParallelizeSimple(t *testing.T) { + t.Parallel() + + numJobs := 10 + var executed atomic.Int64 + jobs := make([]func(context.Context) error, 0, numJobs) + for i := 0; i < numJobs; i++ { + jobs = append( + jobs, + func(context.Context) error { + executed.Add(1) + return nil + }, + ) + } + ctx := context.Background() + assert.NoError(t, Parallelize(ctx, jobs)) + assert.Equal(t, int64(numJobs), executed.Load()) +} + +func TestParallelizeImmediateCancellation(t *testing.T) { + t.Parallel() + + numJobs := 10 + var executed atomic.Int64 + jobs := make([]func(context.Context) error, 0, numJobs) + for i := 0; i < numJobs; i++ { + jobs = append( + jobs, + func(context.Context) error { + executed.Add(1) + return nil + }, + ) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + assert.Error(t, Parallelize(ctx, jobs)) + assert.Equal(t, int64(0), executed.Load()) +} diff --git a/internal/pkg/xslices/xslices.go b/internal/pkg/xslices/xslices.go new file mode 100644 index 0000000..f9e8b29 --- /dev/null +++ b/internal/pkg/xslices/xslices.go @@ -0,0 +1,103 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xslices + +import ( + "cmp" + "slices" +) + +// Filter filters the slice to only the values where f returns true. +func Filter[T any](s []T, f func(T) bool) []T { + sf := make([]T, 0, len(s)) + for _, e := range s { + if f(e) { + sf = append(sf, e) + } + } + return sf +} + +// FilterError filters the slice to only the values where f returns true. +// +// Returns error the first time f returns error. +func FilterError[T any](s []T, f func(T) (bool, error)) ([]T, error) { + sf := make([]T, 0, len(s)) + for _, e := range s { + ok, err := f(e) + if err != nil { + return nil, err + } + if ok { + sf = append(sf, e) + } + } + return sf, nil +} + +// Map maps the slice. +func Map[T1, T2 any](s []T1, f func(T1) T2) []T2 { + if s == nil { + return nil + } + sm := make([]T2, len(s)) + for i, e := range s { + sm[i] = f(e) + } + return sm +} + +// MapError maps the slice. +// +// Returns error the first time f returns error. +func MapError[T1, T2 any](s []T1, f func(T1) (T2, error)) ([]T2, error) { + if s == nil { + return nil, nil + } + sm := make([]T2, len(s)) + for i, e := range s { + em, err := f(e) + if err != nil { + return nil, err + } + sm[i] = em + } + return sm, nil +} + +// MapKeysToSortedSlice converts the map's keys to a sorted slice. +func MapKeysToSortedSlice[M ~map[K]V, K cmp.Ordered, V any](m M) []K { + s := MapKeysToSlice(m) + slices.Sort(s) + return s +} + +// MapKeysToSlice converts the map's keys to a slice. 
+func MapKeysToSlice[K comparable, V any](m map[K]V) []K { + s := make([]K, 0, len(m)) + for k := range m { + s = append(s, k) + } + return s +} + +// ToStructMap converts the slice to a map with struct{} values. +func ToStructMap[T comparable](s []T) map[T]struct{} { + m := make(map[T]struct{}, len(s)) + for _, e := range s { + m[e] = struct{}{} + } + return m +}