diff --git a/pkg/ccl/utilccl/license_check.go b/pkg/ccl/utilccl/license_check.go index 35c42759b982..df57558094d5 100644 --- a/pkg/ccl/utilccl/license_check.go +++ b/pkg/ccl/utilccl/license_check.go @@ -10,6 +10,8 @@ package utilccl import ( "context" + "fmt" + "strconv" "strings" "sync/atomic" "time" @@ -35,8 +37,36 @@ var enterpriseLicense = settings.RegisterStringSetting( "", settings.WithValidateString( func(sv *settings.Values, s string) error { - _, err := decode(s) - return err + // lookup diagnostic reporting setting + reportingSetting, ok, _ := settings.LookupForLocalAccess("diagnostics.reporting.enabled", true /* forSystemTenant */) + if !ok { + return fmt.Errorf("unable to find setting for diagnostic reporting") + } + reportingStr, err := reportingSetting.DecodeToString(reportingSetting.Encoded(sv)) + if err != nil { + return err + } + + reporting, err := strconv.ParseBool(reportingStr) + if err != nil { + return err + } + + // decode license + license, err := decode(s) + if err != nil { + return err + } + if license == nil { + return nil + } + + // if the license is limited and reporting is disabled, do not allow it to be set + isLimited := license.Type == licenseccl.License_Free || license.Type == licenseccl.License_Trial + if !reporting && isLimited { + return errors.New("diagnostics.reporting.enabled must be true to use this license") + } + return nil }, ), // Even though string settings are non-reportable by default, we diff --git a/pkg/ccl/utilccl/license_check_test.go b/pkg/ccl/utilccl/license_check_test.go index f29ea2d636f8..7c6d7691f30a 100644 --- a/pkg/ccl/utilccl/license_check_test.go +++ b/pkg/ccl/utilccl/license_check_test.go @@ -11,6 +11,7 @@ package utilccl import ( "context" "fmt" + "strconv" "testing" "time" @@ -69,6 +70,82 @@ func TestSettingAndCheckingLicense(t *testing.T) { } } +// test setting a license with a specific diagnostics setting +func TestSetLicenseWithDiagnosticsReporting(t *testing.T) { + ctx := context.Background() + st := cluster.MakeTestingClusterSettings() + t0 := timeutil.Unix(0, 0) + + for _, tc := range []struct { + lit licenseccl.License_Type + diagnostics bool + err string + }{ + {licenseccl.License_Free, false, "unable to disable diagnostics with license type Free"}, + {licenseccl.License_Free, true, ""}, + {licenseccl.License_Trial, false, "unable to disable diagnostics with license type Trial"}, + {licenseccl.License_Trial, true, ""}, + {licenseccl.License_NonCommercial, false, ""}, + {licenseccl.License_NonCommercial, true, ""}, + {licenseccl.License_Enterprise, false, ""}, + {licenseccl.License_Enterprise, true, ""}, + {licenseccl.License_Evaluation, false, ""}, + {licenseccl.License_Evaluation, true, ""}, + } { + lic, _ := (&licenseccl.License{ + Type: tc.lit, + ValidUntilUnixSec: t0.AddDate(0, 1, 0).Unix(), + }).Encode() + updater := st.MakeUpdater() + if err := setDiagnosticsReporting(ctx, updater, tc.diagnostics); err != nil { + t.Fatal(err) + } + if err := setLicense(ctx, updater, lic); !testutils.IsError( + err, tc.err, + ) { + t.Fatalf("%s %t: expected err %q, got %v", tc.lit, tc.diagnostics, tc.err, err) + } + + } +} + +// test setting the diagnostics setting with a specific license (reverse of the above) +func TestSetDiagnosticsReportingWithLicense(t *testing.T) { + ctx := context.Background() + st := cluster.MakeTestingClusterSettings() + t0 := timeutil.Unix(0, 0) + for _, tc := range []struct { + lit licenseccl.License_Type + diagnostics bool + err string + }{ + {licenseccl.License_Free, false, "unable to disable diagnostics with license type Free"}, + {licenseccl.License_Free, true, ""}, + {licenseccl.License_Trial, false, "unable to disable diagnostics with license type Trial"}, + {licenseccl.License_Trial, true, ""}, + {licenseccl.License_NonCommercial, false, ""}, + {licenseccl.License_NonCommercial, true, ""}, + {licenseccl.License_Enterprise, false, ""}, + {licenseccl.License_Enterprise, true, ""}, + {licenseccl.License_Evaluation, false, ""}, + {licenseccl.License_Evaluation, true, ""}, + } { + lic, _ := (&licenseccl.License{ + Type: tc.lit, + ValidUntilUnixSec: t0.AddDate(0, 1, 0).Unix(), + }).Encode() + updater := st.MakeUpdater() + if err := setLicense(ctx, updater, lic); err != nil { + t.Fatal(err) + } + if err := setDiagnosticsReporting(ctx, updater, tc.diagnostics); !testutils.IsError( + err, tc.err, + ) { + t.Fatalf("%s %t: expected err %q, got %v", tc.lit, tc.diagnostics, tc.err, err) + } + } +} + func TestGetLicenseTypePresent(t *testing.T) { defer leaktest.AfterTest(t)() @@ -266,6 +343,13 @@ func setLicense(ctx context.Context, updater settings.Updater, val string) error }) } +func setDiagnosticsReporting(ctx context.Context, updater settings.Updater, val bool) error { + return updater.Set(ctx, "diagnostics.reporting.enabled", settings.EncodedValue{ + Value: strconv.FormatBool(val), + Type: "b", + }) +} + func TestRefreshLicenseEnforcerOnLicenseChange(t *testing.T) { defer leaktest.AfterTest(t)() diff --git a/pkg/cmd/cockroach-oss/BUILD.bazel b/pkg/cmd/cockroach-oss/BUILD.bazel index 731a7b5a6f75..5c1c3c33b571 100644 --- a/pkg/cmd/cockroach-oss/BUILD.bazel +++ b/pkg/cmd/cockroach-oss/BUILD.bazel @@ -23,7 +23,6 @@ disallowed_imports_test( "cockroach-oss", disallowed_list = [], disallowed_prefixes = [ - "pkg/ccl", "pkg/ui/distccl", ], ) diff --git a/pkg/settings/bool.go b/pkg/settings/bool.go index 23efccc02e3e..02f8ffcca9fe 100644 --- a/pkg/settings/bool.go +++ b/pkg/settings/bool.go @@ -13,6 +13,8 @@ package settings import ( "context" "strconv" + + "github.com/cockroachdb/errors" ) // BoolSetting is the interface of a setting variable that will be @@ -21,6 +23,7 @@ import ( type BoolSetting struct { common defaultValue bool + validateFn func(*Values, bool) error } var _ internalSetting = &BoolSetting{} @@ -82,11 +85,29 @@ var _ = (*BoolSetting).Default // For testing usage only. func (b *BoolSetting) Override(ctx context.Context, sv *Values, v bool) { sv.setValueOrigin(ctx, b.slot, OriginOverride) - b.set(ctx, sv, v) + b.setOnValues(ctx, sv, v) sv.setDefaultOverride(b.slot, v) } -func (b *BoolSetting) set(ctx context.Context, sv *Values, v bool) { +// Validate that a value conforms with the validation function. +func (b *BoolSetting) Validate(sv *Values, v bool) error { + if b.validateFn != nil { + if err := b.validateFn(sv, v); err != nil { + return err + } + } + return nil +} + +func (b *BoolSetting) set(ctx context.Context, sv *Values, v bool) error { + if err := b.Validate(sv, v); err != nil { + return err + } + b.setOnValues(ctx, sv, v) + return nil +} + +func (b *BoolSetting) setOnValues(ctx context.Context, sv *Values, v bool) { vInt := int64(0) if v { vInt = 1 @@ -99,8 +120,7 @@ func (b *BoolSetting) decodeAndSet(ctx context.Context, sv *Values, encoded stri if err != nil { return err } - b.set(ctx, sv, v) - return nil + return b.set(ctx, sv, v) } func (b *BoolSetting) decodeAndSetDefaultOverride( @@ -117,17 +137,40 @@ func (b *BoolSetting) decodeAndSetDefaultOverride( func (b *BoolSetting) setToDefault(ctx context.Context, sv *Values) { // See if the default value was overridden. if val := sv.getDefaultOverride(b.slot); val != nil { - b.set(ctx, sv, val.(bool)) + // As per the semantics of override, these values don't go through + // validation. + _ = b.set(ctx, sv, val.(bool)) return } - b.set(ctx, sv, b.defaultValue) + if err := b.set(ctx, sv, b.defaultValue); err != nil { + panic(err) + } } // RegisterBoolSetting defines a new setting with type bool. func RegisterBoolSetting( class Class, key InternalKey, desc string, defaultValue bool, opts ...SettingOption, ) *BoolSetting { - setting := &BoolSetting{defaultValue: defaultValue} + validateFn := func(sv *Values, val bool) error { + for _, opt := range opts { + switch { + case opt.commonOpt != nil: + continue + case opt.validateBoolFn != nil: + default: + panic(errors.AssertionFailedf("wrong validator type")) + } + if err := opt.validateBoolFn(sv, val); err != nil { + return err + } + } + return nil + } + // what to put here? + if err := validateFn(&Values{}, defaultValue); err != nil { + panic(errors.Wrap(err, "invalid default")) + } + setting := &BoolSetting{defaultValue: defaultValue, validateFn: validateFn} register(class, key, desc, setting) setting.apply(opts) return setting diff --git a/pkg/settings/options.go b/pkg/settings/options.go index 5c28eab00ec5..a2f2b67497c2 100644 --- a/pkg/settings/options.go +++ b/pkg/settings/options.go @@ -19,6 +19,7 @@ import ( // SettingOption is the type of an option that can be passed to Register. type SettingOption struct { commonOpt func(*common) + validateBoolFn func(*Values, bool) error validateDurationFn func(time.Duration) error validateInt64Fn func(int64) error validateFloat64Fn func(float64) error @@ -108,6 +109,11 @@ func WithValidateFloat(fn func(float64) error) SettingOption { return SettingOption{validateFloat64Fn: fn} } +// WithValidateBool adds a validation function for a boolean setting. +func WithValidateBool(fn func(*Values, bool) error) SettingOption { + return SettingOption{validateBoolFn: fn} +} + // WithValidateString adds a validation function for a string setting. func WithValidateString(fn func(*Values, string) error) SettingOption { return SettingOption{validateStringFn: fn} diff --git a/pkg/settings/validation_test.go b/pkg/settings/validation_test.go index 994dd2b75e21..264f841da437 100644 --- a/pkg/settings/validation_test.go +++ b/pkg/settings/validation_test.go @@ -20,6 +20,22 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils" ) +var cantBeTrue = settings.WithValidateBool(func(sv *settings.Values, b bool) error { + fmt.Println("testing it cant be true") + if b { + return fmt.Errorf("it cant be true") + } + return nil +}) + +var cantBeFalse = settings.WithValidateBool(func(sv *settings.Values, b bool) error { + fmt.Println("testing it cant be false") + if !b { + return fmt.Errorf("it cant be false") + } + return nil +}) + func TestValidationOptions(t *testing.T) { type subTest struct { val interface{} @@ -182,6 +198,20 @@ func TestValidationOptions(t *testing.T) { {val: 11, opt: settings.ByteSizeWithMinimum(10), expectedErr: ""}, }, }, + { + testLabel: "bool", + settingFn: func(n int, bval interface{}, opt settings.SettingOption) settings.Setting { + val := bval.(bool) + return settings.RegisterBoolSetting(settings.SystemOnly, settings.InternalKey(fmt.Sprintf("test-%d", n)), "desc", + val, opt) + }, + subTests: []subTest{ + {val: true, opt: cantBeTrue, expectedErr: "it cant be true"}, + {val: false, opt: cantBeTrue, expectedErr: ""}, + {val: true, opt: cantBeFalse, expectedErr: ""}, + {val: false, opt: cantBeFalse, expectedErr: "it cant be false"}, + }, + }, } for _, tc := range testCases { diff --git a/pkg/sql/set_cluster_setting.go b/pkg/sql/set_cluster_setting.go index 05660ed5bdf9..251d5834a865 100644 --- a/pkg/sql/set_cluster_setting.go +++ b/pkg/sql/set_cluster_setting.go @@ -782,6 +782,9 @@ func toSettingString( return "", errors.Errorf("cannot use %s %T value for string setting", d.ResolvedType(), d) case *settings.BoolSetting: if b, ok := d.(*tree.DBool); ok { + if err := setting.Validate(&st.SV, bool(*b)); err != nil { + return "", err + } return settings.EncodeBool(bool(*b)), nil } return "", errors.Errorf("cannot use %s %T value for bool setting", d.ResolvedType(), d) diff --git a/pkg/util/log/logcrash/BUILD.bazel b/pkg/util/log/logcrash/BUILD.bazel index b8b5ec29e571..919af41aa5a0 100644 --- a/pkg/util/log/logcrash/BUILD.bazel +++ b/pkg/util/log/logcrash/BUILD.bazel @@ -10,6 +10,7 @@ go_library( }, deps = [ "//pkg/build", + "//pkg/ccl/utilccl/licenseccl", "//pkg/settings", "//pkg/util/envutil", "//pkg/util/log", diff --git a/pkg/util/log/logcrash/crash_reporting.go b/pkg/util/log/logcrash/crash_reporting.go index d5ea9cdc0f97..162f4c603e87 100644 --- a/pkg/util/log/logcrash/crash_reporting.go +++ b/pkg/util/log/logcrash/crash_reporting.go @@ -18,6 +18,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/build" + "github.com/cockroachdb/cockroach/pkg/ccl/utilccl/licenseccl" "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/util/envutil" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -56,8 +57,44 @@ var ( settings.ApplicationLevel, "diagnostics.reporting.enabled", "enable reporting diagnostic metrics to cockroach labs", - false, - settings.WithPublic) + true, + settings.WithPublic, + settings.WithValidateBool(func(sv *settings.Values, b bool) error { + // if the user wants to turn on diagnostics, short circuit to no error + if b { + return nil + } + + // attempt to get the license to verify ability to disable reporting + licenseSetting, ok, _ := settings.LookupForLocalAccess("enterprise.license", true /* forSystemTenant */) + if !ok { + log.Warning(context.Background(), "unable to find license configuring diagnostic reporting") + return nil + } + lic, err := licenseSetting.DecodeToString(licenseSetting.Encoded(sv)) + if err != nil { + log.Errorf(context.Background(), "error configuring diagnostics: %s", err) + return nil + } + + license, err := licenseccl.Decode(lic) + if err != nil { + log.Errorf(context.Background(), "error configuring diagnostics: %s", err) + return nil + } + if license == nil { + log.Warning(context.Background(), "unable to read license while setting diagnostics.reporting.enabled") + return nil + } + + // prevent user from disabling diagnostics if license is limited + isLimited := license.Type == licenseccl.License_Free || license.Type == licenseccl.License_Trial + if isLimited { + return fmt.Errorf("unable to disable diagnostics with license type %s", license.Type) + } + return nil + }), + ) // CrashReports wraps "diagnostics.reporting.send_crash_reports.enabled". CrashReports = settings.RegisterBoolSetting(