diff --git a/internal/bzlmod/go_mod.bzl b/internal/bzlmod/go_mod.bzl index 9d37fcf5a..283f128da 100644 --- a/internal/bzlmod/go_mod.bzl +++ b/internal/bzlmod/go_mod.bzl @@ -57,7 +57,7 @@ def parse_go_mod(content, path): continue if not current_directive: - if tokens[0] not in ["module", "go", "require", "replace", "exclude", "retract"]: + if tokens[0] not in ["module", "go", "require", "replace", "exclude", "retract", "toolchain"]: fail("{}:{}: unexpected token '{}' at start of line".format(path, line_no, tokens[0])) if len(tokens) == 1: fail("{}:{}: expected another token after '{}'".format(path, line_no, tokens[0])) @@ -98,7 +98,9 @@ def parse_go_mod(content, path): if not go: # "As of the Go 1.17 release, if the go directive is missing, go 1.16 is assumed." go = "1.16" - major, minor = go.split(".") + + # The go directive can contain patch and pre-release versions, but we omit them. + major, minor = go.split(".")[:2] return struct( module = module, diff --git a/tests/bzlmod/go_mod_test.bzl b/tests/bzlmod/go_mod_test.bzl index a1fcedeb1..ba7519ee2 100644 --- a/tests/bzlmod/go_mod_test.bzl +++ b/tests/bzlmod/go_mod_test.bzl @@ -46,6 +46,27 @@ def _go_mod_test_impl(ctx): go_mod_test = unittest.make(_go_mod_test_impl) +_GO_MOD_21_CONTENT = """go 1.21.0rc1 + +module example.com + +toolchain go1.22.2 +""" + +_EXPECTED_GO_MOD_21_PARSE_RESULT = struct( + go = (1, 21), + module = "example.com", + replace_map = {}, + require = (), +) + +def _go_mod_21_test_impl(ctx): + env = unittest.begin(ctx) + asserts.equals(env, _EXPECTED_GO_MOD_21_PARSE_RESULT, parse_go_mod(_GO_MOD_21_CONTENT, "/go.mod")) + return unittest.end(env) + +go_mod_21_test = unittest.make(_go_mod_21_test_impl) + _GO_SUM_CONTENT = """cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/bazelbuild/buildtools v0.0.0-20220531122519-a43aed7014c8 h1:fmdo+fvvWlhldUcqkhAMpKndSxMN3vH5l7yow5cEaiQ= @@ -70,5 +91,6 @@ def go_mod_test_suite(name): unittest.suite( name, go_mod_test, + go_mod_21_test, go_sum_test, )