diff --git a/pkg/ccl/utilccl/license_check.go b/pkg/ccl/utilccl/license_check.go index 35c42759b982..fb02e5189fa9 100644 --- a/pkg/ccl/utilccl/license_check.go +++ b/pkg/ccl/utilccl/license_check.go @@ -10,6 +10,7 @@ package utilccl import ( "context" + "strconv" "strings" "sync/atomic" "time" @@ -35,8 +36,37 @@ var enterpriseLicense = settings.RegisterStringSetting( "", settings.WithValidateString( func(sv *settings.Values, s string) error { - _, err := decode(s) - return err + // the validator looks for a valid diagnostics setting. + reportingSetting, ok, _ := settings.LookupForLocalAccess("diagnostics.reporting.enabled", true /* forSystemTenant */) + if !ok { + log.Warning(context.Background(), "unable to find setting for diagnostic reporting") + return nil + } + reportingStr, err := reportingSetting.DecodeToString(reportingSetting.Encoded(sv)) + if err != nil { + return err + } + + reporting, err := strconv.ParseBool(reportingStr) + if err != nil { + return err + } + + license, err := decode(s) + if err != nil { + return err + } + if license == nil { + return nil + } + + // if the cluster license is limited and the reporting value passed in is + // disabled, then do not allow diagnostics to be set. + isLimited := license.Type == licenseccl.License_Free || license.Type == licenseccl.License_Trial + if !reporting && isLimited { + return errors.New("diagnostics.reporting.enabled must be true to use this license") + } + return nil }, ), // Even though string settings are non-reportable by default, we diff --git a/pkg/ccl/utilccl/license_check_test.go b/pkg/ccl/utilccl/license_check_test.go index f29ea2d636f8..746f4f20050b 100644 --- a/pkg/ccl/utilccl/license_check_test.go +++ b/pkg/ccl/utilccl/license_check_test.go @@ -11,6 +11,7 @@ package utilccl import ( "context" "fmt" + "strconv" "testing" "time" @@ -69,6 +70,82 @@ func TestSettingAndCheckingLicense(t *testing.T) { } } +// test setting a license with a specific diagnostics setting. +func TestSetLicenseWithDiagnosticsReporting(t *testing.T) { + ctx := context.Background() + st := cluster.MakeTestingClusterSettings() + t0 := timeutil.Unix(0, 0) + + for _, tc := range []struct { + lit licenseccl.License_Type + diagnostics bool + err string + }{ + {licenseccl.License_Free, false, "unable to disable diagnostics with license type Free"}, + {licenseccl.License_Free, true, ""}, + {licenseccl.License_Trial, false, "unable to disable diagnostics with license type Trial"}, + {licenseccl.License_Trial, true, ""}, + {licenseccl.License_NonCommercial, false, ""}, + {licenseccl.License_NonCommercial, true, ""}, + {licenseccl.License_Enterprise, false, ""}, + {licenseccl.License_Enterprise, true, ""}, + {licenseccl.License_Evaluation, false, ""}, + {licenseccl.License_Evaluation, true, ""}, + } { + lic, _ := (&licenseccl.License{ + Type: tc.lit, + ValidUntilUnixSec: t0.AddDate(0, 1, 0).Unix(), + }).Encode() + updater := st.MakeUpdater() + if err := setDiagnosticsReporting(ctx, updater, tc.diagnostics); err != nil { + t.Fatal(err) + } + if err := setLicense(ctx, updater, lic); !testutils.IsError( + err, tc.err, + ) { + t.Fatalf("%s %t: expected err %q, got %v", tc.lit, tc.diagnostics, tc.err, err) + } + + } +} + +// test setting the diagnostics setting with a specific license. +func TestSetDiagnosticsReportingWithLicense(t *testing.T) { + ctx := context.Background() + st := cluster.MakeTestingClusterSettings() + t0 := timeutil.Unix(0, 0) + for _, tc := range []struct { + lit licenseccl.License_Type + diagnostics bool + err string + }{ + {licenseccl.License_Free, false, "unable to disable diagnostics with license type Free"}, + {licenseccl.License_Free, true, ""}, + {licenseccl.License_Trial, false, "unable to disable diagnostics with license type Trial"}, + {licenseccl.License_Trial, true, ""}, + {licenseccl.License_NonCommercial, false, ""}, + {licenseccl.License_NonCommercial, true, ""}, + {licenseccl.License_Enterprise, false, ""}, + {licenseccl.License_Enterprise, true, ""}, + {licenseccl.License_Evaluation, false, ""}, + {licenseccl.License_Evaluation, true, ""}, + } { + lic, _ := (&licenseccl.License{ + Type: tc.lit, + ValidUntilUnixSec: t0.AddDate(0, 1, 0).Unix(), + }).Encode() + updater := st.MakeUpdater() + if err := setLicense(ctx, updater, lic); err != nil { + t.Fatal(err) + } + if err := setDiagnosticsReporting(ctx, updater, tc.diagnostics); !testutils.IsError( + err, tc.err, + ) { + t.Fatalf("%s %t: expected err %q, got %v", tc.lit, tc.diagnostics, tc.err, err) + } + } +} + func TestGetLicenseTypePresent(t *testing.T) { defer leaktest.AfterTest(t)() @@ -220,9 +297,10 @@ func TestTimeToEnterpriseLicenseExpiry(t *testing.T) { func TestApplyTenantLicenseWithLicense(t *testing.T) { defer leaktest.AfterTest(t)() - license, _ := (&licenseccl.License{ + license, err := (&licenseccl.License{ Type: licenseccl.License_Enterprise, }).Encode() + require.NoError(t, err) defer TestingDisableEnterprise()() defer envutil.TestSetEnv(t, "COCKROACH_TENANT_LICENSE", license)() @@ -266,6 +344,13 @@ func setLicense(ctx context.Context, updater settings.Updater, val string) error }) } +func setDiagnosticsReporting(ctx context.Context, updater settings.Updater, val bool) error { + return updater.Set(ctx, "diagnostics.reporting.enabled", settings.EncodedValue{ + Value: strconv.FormatBool(val), + Type: "b", + }) +} + func TestRefreshLicenseEnforcerOnLicenseChange(t *testing.T) { defer leaktest.AfterTest(t)() diff --git a/pkg/settings/bool.go b/pkg/settings/bool.go index 23efccc02e3e..9585b7e54fd0 100644 --- a/pkg/settings/bool.go +++ b/pkg/settings/bool.go @@ -13,6 +13,8 @@ package settings import ( "context" "strconv" + + "github.com/cockroachdb/errors" ) // BoolSetting is the interface of a setting variable that will be @@ -21,6 +23,7 @@ import ( type BoolSetting struct { common defaultValue bool + validateFn func(*Values, bool) error } var _ internalSetting = &BoolSetting{} @@ -86,6 +89,16 @@ func (b *BoolSetting) Override(ctx context.Context, sv *Values, v bool) { sv.setDefaultOverride(b.slot, v) } +// Validate that a value conforms with the validation function. +func (b *BoolSetting) Validate(sv *Values, v bool) error { + if b.validateFn != nil { + if err := b.validateFn(sv, v); err != nil { + return err + } + } + return nil +} + func (b *BoolSetting) set(ctx context.Context, sv *Values, v bool) { vInt := int64(0) if v { @@ -127,7 +140,22 @@ func (b *BoolSetting) setToDefault(ctx context.Context, sv *Values) { func RegisterBoolSetting( class Class, key InternalKey, desc string, defaultValue bool, opts ...SettingOption, ) *BoolSetting { - setting := &BoolSetting{defaultValue: defaultValue} + validateFn := func(sv *Values, val bool) error { + for _, opt := range opts { + switch { + case opt.commonOpt != nil: + continue + case opt.validateBoolFn != nil: + default: + panic(errors.AssertionFailedf("wrong validator type")) + } + if err := opt.validateBoolFn(sv, val); err != nil { + return err + } + } + return nil + } + setting := &BoolSetting{defaultValue: defaultValue, validateFn: validateFn} register(class, key, desc, setting) setting.apply(opts) return setting diff --git a/pkg/settings/options.go b/pkg/settings/options.go index 5c28eab00ec5..a2f2b67497c2 100644 --- a/pkg/settings/options.go +++ b/pkg/settings/options.go @@ -19,6 +19,7 @@ import ( // SettingOption is the type of an option that can be passed to Register. type SettingOption struct { commonOpt func(*common) + validateBoolFn func(*Values, bool) error validateDurationFn func(time.Duration) error validateInt64Fn func(int64) error validateFloat64Fn func(float64) error @@ -108,6 +109,11 @@ func WithValidateFloat(fn func(float64) error) SettingOption { return SettingOption{validateFloat64Fn: fn} } +// WithValidateBool adds a validation function for a boolean setting. +func WithValidateBool(fn func(*Values, bool) error) SettingOption { + return SettingOption{validateBoolFn: fn} +} + // WithValidateString adds a validation function for a string setting. func WithValidateString(fn func(*Values, string) error) SettingOption { return SettingOption{validateStringFn: fn} diff --git a/pkg/settings/validation_test.go b/pkg/settings/validation_test.go index 994dd2b75e21..2aa6e1b33ee0 100644 --- a/pkg/settings/validation_test.go +++ b/pkg/settings/validation_test.go @@ -20,6 +20,20 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils" ) +var cantBeTrue = settings.WithValidateBool(func(sv *settings.Values, b bool) error { + if b { + return fmt.Errorf("it cant be true") + } + return nil +}) + +var cantBeFalse = settings.WithValidateBool(func(sv *settings.Values, b bool) error { + if !b { + return fmt.Errorf("it cant be false") + } + return nil +}) + func TestValidationOptions(t *testing.T) { type subTest struct { val interface{} @@ -182,6 +196,26 @@ func TestValidationOptions(t *testing.T) { {val: 11, opt: settings.ByteSizeWithMinimum(10), expectedErr: ""}, }, }, + { + testLabel: "bool", + settingFn: func(n int, bval interface{}, opt settings.SettingOption) settings.Setting { + val := bval.(bool) + b := settings.RegisterBoolSetting(settings.SystemOnly, settings.InternalKey(fmt.Sprintf("test-%d", n)), "desc", + val, opt) + // We explicitly check here to test validation which does not happen on initialization. + err := b.Validate(&settings.Values{}, val) + if err != nil { + panic(err) + } + return b + }, + subTests: []subTest{ + {val: true, opt: cantBeTrue, expectedErr: "it cant be true"}, + {val: false, opt: cantBeTrue, expectedErr: ""}, + {val: true, opt: cantBeFalse, expectedErr: ""}, + {val: false, opt: cantBeFalse, expectedErr: "it cant be false"}, + }, + }, } for _, tc := range testCases { diff --git a/pkg/sql/set_cluster_setting.go b/pkg/sql/set_cluster_setting.go index 05660ed5bdf9..251d5834a865 100644 --- a/pkg/sql/set_cluster_setting.go +++ b/pkg/sql/set_cluster_setting.go @@ -782,6 +782,9 @@ func toSettingString( return "", errors.Errorf("cannot use %s %T value for string setting", d.ResolvedType(), d) case *settings.BoolSetting: if b, ok := d.(*tree.DBool); ok { + if err := setting.Validate(&st.SV, bool(*b)); err != nil { + return "", err + } return settings.EncodeBool(bool(*b)), nil } return "", errors.Errorf("cannot use %s %T value for bool setting", d.ResolvedType(), d) diff --git a/pkg/util/log/logcrash/BUILD.bazel b/pkg/util/log/logcrash/BUILD.bazel index b8b5ec29e571..919af41aa5a0 100644 --- a/pkg/util/log/logcrash/BUILD.bazel +++ b/pkg/util/log/logcrash/BUILD.bazel @@ -10,6 +10,7 @@ go_library( }, deps = [ "//pkg/build", + "//pkg/ccl/utilccl/licenseccl", "//pkg/settings", "//pkg/util/envutil", "//pkg/util/log", diff --git a/pkg/util/log/logcrash/crash_reporting.go b/pkg/util/log/logcrash/crash_reporting.go index d5ea9cdc0f97..9ffea95de47c 100644 --- a/pkg/util/log/logcrash/crash_reporting.go +++ b/pkg/util/log/logcrash/crash_reporting.go @@ -18,6 +18,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/build" + "github.com/cockroachdb/cockroach/pkg/ccl/utilccl/licenseccl" "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/util/envutil" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -56,8 +57,47 @@ var ( settings.ApplicationLevel, "diagnostics.reporting.enabled", "enable reporting diagnostic metrics to cockroach labs", - false, - settings.WithPublic) + true, + settings.WithPublic, + settings.WithValidateBool(func(sv *settings.Values, b bool) error { + // If the user wants to turn on diagnostics, no validation is needed. + if b { + return nil + } + + // The validator looks for a valid license, but fails gracefully if one is + // not found. It's possible at this point one is not set, and because + // failure will panic on startup, we allow the setting of any value. + licenseSetting, ok, _ := settings.LookupForLocalAccess("enterprise.license", true /* forSystemTenant */) + if !ok { + log.Warning(context.Background(), "unable to find license configuring diagnostic reporting") + return nil + } + lic, err := licenseSetting.DecodeToString(licenseSetting.Encoded(sv)) + if err != nil { + log.Errorf(context.Background(), "error configuring diagnostics: %s", err) + return nil + } + + license, err := licenseccl.Decode(lic) + if err != nil { + log.Errorf(context.Background(), "error configuring diagnostics: %s", err) + return nil + } + if license == nil { + log.Warning(context.Background(), "unable to read license while setting diagnostics.reporting.enabled") + return nil + } + + // If the license is limited and diagnostics are off, we prevent the user + // from disabling diagnostics reporting. + isLimited := license.Type == licenseccl.License_Free || license.Type == licenseccl.License_Trial + if isLimited { + return fmt.Errorf("unable to disable diagnostics with license type %s", license.Type) + } + return nil + }), + ) // CrashReports wraps "diagnostics.reporting.send_crash_reports.enabled". CrashReports = settings.RegisterBoolSetting(