diff --git a/controllers/cert.go b/controllers/cert.go index 8bb1c04..60ef62b 100644 --- a/controllers/cert.go +++ b/controllers/cert.go @@ -17,7 +17,9 @@ package controllers import ( "encoding/json" + "github.com/beego/beego/utils/pagination" "github.com/casbin/caswaf/object" + "github.com/casbin/caswaf/util" ) func (c *ApiController) GetGlobalCerts() { @@ -44,13 +46,38 @@ func (c *ApiController) GetCerts() { owner = "" } - certs, err := object.GetCerts(owner) - if err != nil { - c.ResponseError(err.Error()) - return + limit := c.Input().Get("pageSize") + page := c.Input().Get("p") + field := c.Input().Get("field") + value := c.Input().Get("value") + sortField := c.Input().Get("sortField") + sortOrder := c.Input().Get("sortOrder") + + if limit == "" || page == "" { + certs, err := object.GetCerts(owner) + if err != nil { + c.ResponseError(err.Error()) + return + } + + c.ResponseOk(object.GetMaskedCerts(certs)) + } else { + limit := util.ParseInt(limit) + count, err := object.GetCertCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + certs, err := object.GetPaginationCerts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + c.ResponseOk(object.GetMaskedCerts(certs), paginator.Nums()) } - - c.ResponseOk(object.GetMaskedCerts(certs)) } func (c *ApiController) GetCert() { @@ -142,11 +169,6 @@ func (c *ApiController) UpdateCertDomainExpire() { } cert.DomainExpireTime = domainExpireTime - _, err = object.UpdateCert(id, cert) - if err != nil { - c.ResponseError(err.Error()) - return - } - - c.ResponseOk(object.GetMaskedCert(cert)) + c.Data["json"] = wrapActionResponse(object.UpdateCert(id, cert)) + c.ServeJSON() } diff --git a/controllers/site.go b/controllers/site.go index f2861f6..8730d10 100644 --- a/controllers/site.go +++ b/controllers/site.go @@ -17,6 +17,7 @@ package controllers import ( "encoding/json" + "github.com/beego/beego/utils/pagination" "github.com/casbin/caswaf/object" "github.com/casbin/caswaf/util" ) @@ -45,13 +46,38 @@ func (c *ApiController) GetSites() { owner = "" } - sites, err := object.GetSites(owner) + limit := c.Input().Get("pageSize") + page := c.Input().Get("p") + field := c.Input().Get("field") + value := c.Input().Get("value") + sortField := c.Input().Get("sortField") + sortOrder := c.Input().Get("sortOrder") + + if limit == "" || page == "" { + sites, err := object.GetSites(owner) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(object.GetMaskedSites(sites, util.GetHostname())) + return + } + + limitInt := util.ParseInt(limit) + count, err := object.GetSiteCount(owner, field, value) if err != nil { c.ResponseError(err.Error()) return } - c.ResponseOk(object.GetMaskedSites(sites, util.GetHostname())) + paginator := pagination.SetPaginator(c.Ctx, limitInt, count) + sites, err := object.GetPaginationSites(owner, paginator.Offset(), limitInt, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + c.ResponseOk(object.GetMaskedSites(sites, util.GetHostname()), paginator.Nums()) } func (c *ApiController) GetSite() { diff --git a/object/cert.go b/object/cert.go index 7f7675b..3ec9b14 100644 --- a/object/cert.go +++ b/object/cert.go @@ -206,3 +206,19 @@ func (cert *Cert) isCertNearExpire() (bool, error) { return res, nil } + +func GetCertCount(owner, field, value string) (int64, error) { + session := GetSession(owner, -1, -1, field, value, "", "") + return session.Count(&Cert{}) +} + +func GetPaginationCerts(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Cert, error) { + certs := []*Cert{} + session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) + err := session.Where("owner = ? or owner = ?", "admin", owner).Find(&certs) + if err != nil { + return certs, err + } + + return certs, nil +} diff --git a/object/ormer_session.go b/object/ormer_session.go new file mode 100644 index 0000000..9c1f65b --- /dev/null +++ b/object/ormer_session.go @@ -0,0 +1,46 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package object + +import ( + "fmt" + + "github.com/casbin/caswaf/util" + "github.com/xorm-io/xorm" +) + +func GetSession(owner string, offset, limit int, field, value, sortField, sortOrder string) *xorm.Session { + session := ormer.Engine.Prepare() + if offset != -1 && limit != -1 { + session.Limit(limit, offset) + } + if owner != "" { + session = session.And("owner=?", owner) + } + if field != "" && value != "" { + if util.FilterField(field) { + session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) + } + } + if sortField == "" || sortOrder == "" { + sortField = "created_time" + } + if sortOrder == "ascend" { + session = session.Asc(util.SnakeString(sortField)) + } else { + session = session.Desc(util.SnakeString(sortField)) + } + return session +} diff --git a/object/site.go b/object/site.go index 5806082..9ee9e33 100644 --- a/object/site.go +++ b/object/site.go @@ -333,3 +333,19 @@ func (site *Site) checkNodes() error { return nil } + +func GetSiteCount(owner, field, value string) (int64, error) { + session := GetSession(owner, -1, -1, field, value, "", "") + return session.Count(&Site{}) +} + +func GetPaginationSites(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Site, error) { + sites := []*Site{} + session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) + err := session.Where("owner = ? or owner = ?", "admin", owner).Find(&sites) + if err != nil { + return sites, err + } + + return sites, nil +} diff --git a/util/string.go b/util/string.go index 14d3458..1a59488 100644 --- a/util/string.go +++ b/util/string.go @@ -163,3 +163,21 @@ func GenerateTwoUniqueRandomStrings() (string, string, error) { } return str1, str2, nil } + +func SnakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + result := strings.ToLower(string(data[:])) + return strings.ReplaceAll(result, " ", "") +} diff --git a/util/validation.go b/util/validation.go new file mode 100644 index 0000000..ec9d32b --- /dev/null +++ b/util/validation.go @@ -0,0 +1,17 @@ +package util + +import "regexp" + +var ( + ReWhiteSpace *regexp.Regexp + ReFieldWhiteList *regexp.Regexp +) + +func init() { + ReWhiteSpace, _ = regexp.Compile(`\s`) + ReFieldWhiteList, _ = regexp.Compile(`^[A-Za-z0-9]+$`) +} + +func FilterField(field string) bool { + return ReFieldWhiteList.MatchString(field) +} diff --git a/web/src/BaseListPage.js b/web/src/BaseListPage.js index d6bb90b..2af20f2 100644 --- a/web/src/BaseListPage.js +++ b/web/src/BaseListPage.js @@ -32,11 +32,19 @@ class BaseListPage extends React.Component { } UNSAFE_componentWillMount() { - this.fetch(); + const {pagination} = this.state; + this.fetch({pagination}); } - handleTableChange = () => { - this.fetch(); + handleTableChange = (pagination, filters, sorter) => { + this.fetch({ + sortField: sorter.field, + sortOrder: sorter.order, + pagination, + ...filters, + searchText: this.state.searchText, + searchedColumn: this.state.searchedColumn, + }); }; render() { diff --git a/web/src/CertListPage.js b/web/src/CertListPage.js index 828f6ed..4bb9f4f 100644 --- a/web/src/CertListPage.js +++ b/web/src/CertListPage.js @@ -23,28 +23,14 @@ import copy from "copy-to-clipboard"; import BaseListPage from "./BaseListPage"; class CertListPage extends BaseListPage { + constructor(props) { + super(props); + } UNSAFE_componentWillMount() { this.fetch(); } - fetch = (params = {}) => { - this.setState({loading: true}); - CertBackend.getCerts(this.props.account.name) - .then((res) => { - this.setState({ - loading: false, - }); - if (res.status === "ok") { - this.setState({ - data: res.data, - }); - } else { - Setting.showMessage("error", `Failed to get certs: ${res.msg}`); - } - }); - }; - newCert() { const randomName = Setting.getRandomName(); return { @@ -105,11 +91,7 @@ class CertListPage extends BaseListPage { Setting.showMessage("error", `Failed to refresh domain expire: ${res.msg}`); } else { Setting.showMessage("success", "Domain expire refresh successfully"); - const newData = [...this.state.data]; - newData[i] = res.data; - this.setState({ - data: newData, - }); + this.fetch(); } } ) @@ -263,9 +245,8 @@ class CertListPage extends BaseListPage { width: "260px", render: (text, record, index) => { return ( -