Skip to content

Commit

Permalink
feat: add UA rule (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
love98ooo authored Jul 26, 2024
1 parent 4b5d42d commit 49d0dba
Show file tree
Hide file tree
Showing 13 changed files with 453 additions and 122 deletions.
29 changes: 17 additions & 12 deletions controllers/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ func (c *ApiController) GetRules() {
if c.RequireSignedIn() {
return
}
owner := c.Input().Get("owner")
if owner == "admin" {
owner = ""
}

rules, err := object.GetRules()
rules, err := object.GetRules(owner)
if err != nil {
c.ResponseError(err.Error())
return
Expand Down Expand Up @@ -76,7 +80,7 @@ func (c *ApiController) AddRule() {
return
}
c.Data["json"] = wrapActionResponse(object.AddRule(&rule))
go service.UpdateWAF()
go service.UpdateWafs()
c.ServeJSON()
}

Expand All @@ -100,7 +104,7 @@ func (c *ApiController) UpdateRule() {

id := c.Input().Get("id")
c.Data["json"] = wrapActionResponse(object.UpdateRule(id, &rule))
go service.UpdateWAF()
go service.UpdateWafs()
c.ServeJSON()
}

Expand All @@ -117,25 +121,25 @@ func (c *ApiController) DeleteRule() {
}

c.Data["json"] = wrapActionResponse(object.DeleteRule(&rule))
go service.UpdateWAF()
go service.UpdateWafs()
c.ServeJSON()
}

func checkExpressions(expressions []object.Expression, ruleType string) error {
func checkExpressions(expressions []*object.Expression, ruleType string) error {
values := make([]string, len(expressions))
for i, expression := range expressions {
values[i] = expression.Value
}
switch ruleType {
case "waf":
return checkWAFRule(values)
case "ip":
return checkIPRule(values)
case "WAF":
return checkWafRule(values)
case "IP":
return checkIpRule(values)
}
return nil
}

func checkWAFRule(rules []string) error {
func checkWafRule(rules []string) error {
for _, rule := range rules {
scanner := parser.NewSecLangScannerFromString(rule)
_, err := scanner.AllDirective()
Expand All @@ -146,10 +150,11 @@ func checkWAFRule(rules []string) error {
return nil
}

func checkIPRule(ipLists []string) error {
func checkIpRule(ipLists []string) error {
for _, ipList := range ipLists {
for _, ip := range strings.Split(ipList, " ") {
if net.ParseIP(ip) == nil {
_, _, err := net.ParseCIDR(ip)
if net.ParseIP(ip) == nil && err != nil {
return errors.New("Invalid IP address: " + ip)
}
}
Expand Down
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func main() {
casdoor.InitCasdoorConfig()
proxy.InitHttpClient()
object.InitSiteMap()
object.InitRuleMap()
run.InitAppMap()
run.InitSelfStart()
object.StartMonitorSitesLoop()
Expand Down
56 changes: 31 additions & 25 deletions object/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,28 @@ type Expression struct {
}

type Rule struct {
Owner string `xorm:"varchar(100) notnull pk" json:"owner"`
Name string `xorm:"varchar(100) notnull pk" json:"name"`
Type string `xorm:"varchar(100) notnull" json:"type"`
Expressions []Expression `xorm:"mediumtext" json:"expressions"`
CreatedTime string `xorm:"varchar(100) notnull" json:"createdTime"`
UpdatedTime string `xorm:"varchar(100) notnull" json:"updatedTime"`
Owner string `xorm:"varchar(100) notnull pk" json:"owner"`
Name string `xorm:"varchar(100) notnull pk" json:"name"`
Type string `xorm:"varchar(100) notnull" json:"type"`
Expressions []*Expression `xorm:"mediumtext" json:"expressions"`
Action string `xorm:"varchar(100) notnull" json:"action"`
Reason string `xorm:"varchar(100) notnull" json:"reason"`
CreatedTime string `xorm:"varchar(100) notnull" json:"createdTime"`
UpdatedTime string `xorm:"varchar(100) notnull" json:"updatedTime"`
}

func GetRules() ([]*Rule, error) {
func GetGlobalRules() ([]*Rule, error) {
rules := []*Rule{}
err := ormer.Engine.Asc("owner").Desc("created_time").Find(&rules)
return rules, err
}

func GetRules(owner string) ([]*Rule, error) {
rules := []*Rule{}
err := ormer.Engine.Desc("updated_time").Find(&rules, &Rule{Owner: owner})
return rules, err
}

func getRule(owner string, name string) (*Rule, error) {
rule := Rule{Owner: owner, Name: name}
existed, err := ormer.Engine.Get(&rule)
Expand Down Expand Up @@ -70,15 +78,25 @@ func UpdateRule(id string, rule *Rule) (bool, error) {
if err != nil {
return false, err
}
err = refreshRuleMap()
if err != nil {
return false, err
}
return true, nil
}

func AddRule(rule *Rule) (bool, error) {
if _, err := ormer.Engine.Insert(rule); err != nil {
affected, err := ormer.Engine.Insert(rule)
if err != nil {
return false, err
} else {
return true, nil
}
if affected != 0 {
err = refreshRuleMap()
if err != nil {
return false, err
}
}
return affected != 0, nil
}

func DeleteRule(rule *Rule) (bool, error) {
Expand All @@ -90,20 +108,8 @@ func DeleteRule(rule *Rule) (bool, error) {
return affected != 0, nil
}

func GetWAFRules() string {
// Get all rules of type "waf".
func getWafRules() ([]*Rule, error) {
rules := []*Rule{}
err := ormer.Engine.Where("type = ?", "waf").Find(&rules)
if err != nil {
return ""
}

res := ""
// get all expressions from rules
for _, rule := range rules {
for _, expression := range rule.Expressions {
res += expression.Value + "\n"
}
}
return res
err := ormer.Engine.Where("type = ?", "WAF").Find(&rules)
return rules, err
}
53 changes: 53 additions & 0 deletions object/rule_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package object

import (
"github.com/casbin/caswaf/util"
)

var ruleMap = map[string]*Rule{}

func InitRuleMap() {
err := refreshRuleMap()
if err != nil {
panic(err)
}
}

func refreshRuleMap() error {
newRuleMap := map[string]*Rule{}
rules, err := GetGlobalRules()
if err != nil {
return err
}

for _, rule := range rules {
newRuleMap[util.GetIdFromOwnerAndName(rule.Owner, rule.Name)] = rule
}

ruleMap = newRuleMap
return nil
}

func GetRulesByRuleIds(ids []string) []*Rule {
var res []*Rule
for _, id := range ids {
if rule, ok := ruleMap[id]; ok {
res = append(res, rule)
}
}
return res
}
1 change: 1 addition & 0 deletions object/site.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Site struct {
OtherDomains []string `xorm:"varchar(500)" json:"otherDomains"`
NeedRedirect bool `json:"needRedirect"`
EnableWaf bool `json:"enableWaf"`
Rules []string `xorm:"varchar(500)" json:"wafRuleIds"`
Waf coraza.WAF `xorm:"-" json:"-"`
Challenges []string `xorm:"mediumtext" json:"challenges"`
Host string `xorm:"varchar(100)" json:"host"`
Expand Down
12 changes: 12 additions & 0 deletions object/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,15 @@ func GetCertByDomain(domain string) (*Cert, error) {

return nil, nil
}

func GetWafRulesByIds(ids []string) string {
var res string
for _, id := range ids {
if rule, ok := ruleMap[id]; ok {
for _, expression := range rule.Expressions {
res += expression.Value + "\n"
}
}
}
return res
}
34 changes: 33 additions & 1 deletion service/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,27 @@ func redirectToHost(w http.ResponseWriter, r *http.Request, host string) {
http.Redirect(w, r, targetUrl, http.StatusMovedPermanently)
}

func checkRules(wafRuleIds []string, r *http.Request) (bool, string, error) {
rules := object.GetRulesByRuleIds(wafRuleIds)
for _, rule := range rules {
switch rule.Type {
case "User-Agent":
uaRule := &UaRule{Rule: *rule}
action, reason, err := uaRule.checkRule(rule.Expressions, r)
if err != nil {
return false, "Internal Server Error", err
}
if action == "Block" {
return false, reason, nil
}
if action == "Allow" {
return true, "", nil
}
}
}
return true, "", nil
}

func handleRequest(w http.ResponseWriter, r *http.Request) {
clientIp := getClientIp(r)
logRequest(clientIp, r)
Expand Down Expand Up @@ -212,7 +233,18 @@ func handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
if site.EnableWaf {
site.Waf = getWAF()
isAllowed, reason, err := checkRules(site.Rules, r)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
responseError(w, "Internal Server Error: %v", err)
return
}
if !isAllowed {
w.WriteHeader(http.StatusForbidden)
responseError(w, "Blocked by CasWAF: %s", reason)
return
}
getWaf(site)
httptx.WrapHandler(site.Waf, http.HandlerFunc(nextHandle)).ServeHTTP(w, r)
} else {
nextHandle(w, r)
Expand Down
53 changes: 53 additions & 0 deletions service/rule.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package service

import (
"net/http"
"regexp"
"strings"

"github.com/casbin/caswaf/object"
)

type UaRule struct {
Rule object.Rule
check CheckRule
}

type CheckRule interface {
checkRule(expressions []*object.Expression, req *http.Request) (string, string, error)
}

func (r *UaRule) checkRule(expressions []*object.Expression, req *http.Request) (string, string, error) {
userAgent := req.UserAgent()
for _, expression := range expressions {
ua := expression.Value
switch expression.Operator {
case "contains":
if strings.Contains(userAgent, ua) {
return r.Rule.Action, r.Rule.Reason, nil
}
case "does not contain":
if !strings.Contains(userAgent, ua) {
return r.Rule.Action, r.Rule.Reason, nil
}
case "equals":
if userAgent == ua {
return r.Rule.Action, r.Rule.Reason, nil
}
case "does not equal":
if strings.Compare(userAgent, ua) != 0 {
return r.Rule.Action, r.Rule.Reason, nil
}
case "match":
// regex match
isMatched, err := regexp.MatchString(ua, userAgent)
if err != nil {
return "", "", err
}
if isMatched {
return r.Rule.Action, r.Rule.Reason, nil
}
}
}
return "", "", nil
}
Loading

0 comments on commit 49d0dba

Please sign in to comment.