Skip to content

Commit

Permalink
Do not modify MID of Security Policy in OSS
Browse files Browse the repository at this point in the history
Signed-off-by: Burak Sekili <[email protected]>
  • Loading branch information
buraksekili committed Jan 11, 2024
1 parent c84ece1 commit 416bf6b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
fail-fast: false
matrix:
k8s-version: ${{ fromJSON(needs.matrix-generator.outputs.k8s-matrix) }}
tyk-version: ['v3.2', 'v4.0', 'v4.3','v5.0.0']
tyk-version: ['v3.2', 'v4.0', 'v4.3', 'v5.0.0', 'v5.2.3']
mode: ['ce','pro']

env:
Expand Down
68 changes: 42 additions & 26 deletions controllers/securitypolicy_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,7 @@ func (r *SecurityPolicyReconciler) update(ctx context.Context,
) (*model.SecurityPolicySpec, error) {
r.Log.Info("Updating SecurityPolicy", "Policy ID", policy.Status.PolID)

if policy.Spec.MID == nil {
policy.Spec.MID = new(string)
}

*policy.Spec.MID = policy.Status.PolID
updatePolicyMID(ctx, &policy.Spec, &policy.Status.PolID)

spec, err := r.spec(ctx, &policy.Spec)
if err != nil {
Expand Down Expand Up @@ -288,12 +284,8 @@ func (r *SecurityPolicyReconciler) update(ctx context.Context,
return nil, err
}

if policy.Spec.MID == nil {
policy.Spec.MID = new(string)
}

*policy.Spec.MID = *spec.MID
policy.Status.PolID = *spec.MID
updatePolicyMID(ctx, &policy.Spec, spec.MID)
setPolicyStatusPolID(ctx, policy, spec)
} else {
r.Log.Error(err, "Failed to get Policy from Tyk", err)

Expand All @@ -315,7 +307,7 @@ func (r *SecurityPolicyReconciler) update(ctx context.Context,
return nil, err
}

polOnTyk, _ := klient.Universal.Portal().Policy().Get(ctx, *policy.Spec.MID) //nolint:errcheck
polOnTyk, _ := klient.Universal.Portal().Policy().Get(ctx, policy.Status.PolID) //nolint:errcheck

r.Log.Info("Successfully updated Policy")

Expand All @@ -325,6 +317,24 @@ func (r *SecurityPolicyReconciler) update(ctx context.Context,
})
}

func setPolicyStatusPolID(ctx context.Context, policy *tykv1.SecurityPolicy, spec *tykv1.SecurityPolicySpec) {
if policy == nil || spec == nil {
return
}

if env := opclient.GetTykMode(ctx); env.Mode == "ce" {
if spec.ID != nil {
policy.Status.PolID = *spec.ID
}

return
}

if spec.MID != nil {
policy.Status.PolID = *spec.MID
}
}

func (r *SecurityPolicyReconciler) create(ctx context.Context, policy *tykv1.SecurityPolicy) error {
r.Log.Info("Creating a policy")

Expand All @@ -351,11 +361,7 @@ func (r *SecurityPolicyReconciler) create(ctx context.Context, policy *tykv1.Sec
return err
}
} else {
if spec.MID == nil {
spec.MID = new(string)
}

*spec.MID = *existingSpec.MID
updatePolicyMID(ctx, spec, existingSpec.MID)

err = klient.Universal.Portal().Policy().Update(ctx, spec)
if err != nil {
Expand All @@ -380,11 +386,7 @@ func (r *SecurityPolicyReconciler) create(ctx context.Context, policy *tykv1.Sec

r.Log.Info("Successfully created Policy")

if policy.Spec.MID == nil {
policy.Spec.MID = new(string)
}

*policy.Spec.MID = *spec.MID
updatePolicyMID(ctx, &policy.Spec, spec.MID)

err = r.updateStatusOfLinkedAPIs(ctx, policy, false)
if err != nil {
Expand All @@ -396,14 +398,30 @@ func (r *SecurityPolicyReconciler) create(ctx context.Context, policy *tykv1.Sec
return err
}

polOnTyk, _ := klient.Universal.Portal().Policy().Get(ctx, *spec.MID) //nolint:errcheck
polOnTyk, _ := klient.Universal.Portal().Policy().Get(ctx, policy.Status.PolID) //nolint:errcheck

return r.updatePolicyStatus(ctx, policy, func(status *tykv1.SecurityPolicyStatus) {
status.LatestTykSpecHash = calculateHash(polOnTyk)
status.LatestCRDSpecHash = calculateHash(spec)
})
}

func updatePolicyMID(ctx context.Context, policy *tykv1.SecurityPolicySpec, mId *string) {
if env := opclient.GetTykMode(ctx); env.Mode == "ce" {
return
}

if mId == nil {
return
}

if policy.MID == nil {
policy.MID = new(string)
}

*policy.MID = *mId
}

// updatePolicyStatus updates the status of the policy.
func (r *SecurityPolicyReconciler) updatePolicyStatus(
ctx context.Context,
Expand All @@ -412,9 +430,7 @@ func (r *SecurityPolicyReconciler) updatePolicyStatus(
) error {
r.Log.Info("Updating policy status")

if policy.Spec.MID != nil {
policy.Status.PolID = *policy.Spec.MID
}
setPolicyStatusPolID(ctx, policy, &policy.Spec)

if policy.Spec.AccessRightsArray != nil && len(policy.Spec.AccessRightsArray) > 0 {
policy.Status.LinkedAPIs = make([]model.Target, 0)
Expand Down
10 changes: 10 additions & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ func SetContext(ctx context.Context, rctx Context) context.Context {
return context.WithValue(ctx, contextKey{}, rctx)
}

func GetTykMode(ctx context.Context) environment.Env {
if c := ctx.Value(contextKey{}); c != nil {
if a, ok := c.(Context); ok {
return a.Env
}
}

return environment.Env{}
}

func GetContext(ctx context.Context) Context {
if c := ctx.Value(contextKey{}); c != nil {
return c.(Context)
Expand Down
8 changes: 1 addition & 7 deletions pkg/client/gateway/security_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,14 @@ func (a SecurityPolicy) Create(ctx context.Context, def *v1.SecurityPolicySpec)
def.MID = new(string)
}

*def.MID = msg.Key

return nil
default:
return client.Error(res)
}
}

func (a SecurityPolicy) Update(ctx context.Context, def *v1.SecurityPolicySpec) error {
if def.MID == nil || *def.MID == "" {
return client.ErrMissingPolicyID
}

res, err := client.PutJSON(ctx, client.Join(endpointPolicies, *def.MID), def)
res, err := client.PutJSON(ctx, client.Join(endpointPolicies, *def.ID), def)
if err != nil {
return err
}
Expand Down

0 comments on commit 416bf6b

Please sign in to comment.