Skip to content

Commit

Permalink
增加访问限制、优化代码结构
Browse files Browse the repository at this point in the history
  • Loading branch information
baiqll committed Oct 26, 2024
1 parent 3438875 commit bf9311f
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 90 deletions.
134 changes: 44 additions & 90 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (
"io/ioutil"
"net/http"
"path/filepath"
"regexp"
"strings"

"github.com/baiqll/src-http/src/cert"
"github.com/baiqll/src-http/src/httpx"
"github.com/baiqll/src-http/src/lib"
)

Expand All @@ -25,7 +25,7 @@ func main() {
/____/_/ |_|\____/_/ /_/\__/\__/ .___/
/_/ v1.0
Enabling https service dedicated to SRC testing
Enabling https service dedicated to SRC testing
`
fmt.Println(string(banner))

Expand All @@ -35,14 +35,8 @@ func main() {
var payload string
var enable_tls bool
var default_file string
var tls_path = filepath.Join(lib.HomeDir(), ".config/src-http")
var internet_ip = lib.GetInternetIP()
var domain string
var port string
var method string
var web_server string
var is_new_domain = false
var show_internet_server = true
var config_path = filepath.Join(lib.HomeDir(), ".config/src-http")


flag.StringVar(&server, "server", "", "https 服务")
flag.BoolVar(&enable_tls, "tls", false, "是否开启tls,默认关闭")
Expand All @@ -52,96 +46,42 @@ func main() {
// 解析命令行参数写入注册的flag里
flag.Parse()

// 判断域名是否合规
if server != "" {


server_split := strings.Split(server, ":")
domain = server_split[0]
if len(server_split) > 1 {
port = server_split[1]
}

if is_host, _ := regexp.MatchString(`[a-zA-Z0-9][-a-zA-Z0-9]{0,62}(\.[a-zA-Z0-9][-a-zA-Z0-9]{0,62})+\.?`, domain); !is_host {

return
}

if domain!= "0.0.0.0"{
/*
设置本地域名解析
*/
is_new_domain = lib.NewDNS(domain)
show_internet_server = false

}

}

lib.NewDNS(internet_ip)

if port == ""{
if enable_tls{
port = "443"
}else{
port = "80"
}
}

server = "0.0.0.0:"+ port
uri := httpx.Parse(server, enable_tls)

fmt.Println(uri.ServerBanner)

if enable_tls{
method = "https"
}else{
method = "http"
}

if domain !=""{
web_server = method + "://" + domain + ":" + port
}else{
web_server = method + "://127.0.0.1:" + port
}


// 开始启动服务
fmt.Println("[*] Starting server ",web_server, "...")
if show_internet_server{
fmt.Println("[*] Internet server ", method + "://" + internet_ip + ":" + port )
}

fmt.Println("[*] Listening ", server)

err := cert.CreateTlsCert(tls_path,[]string{domain},internet_ip, is_new_domain)
err := cert.CreateTlsCert(config_path, []string{uri.Host}, uri.InternetIp, uri.IsNewDomain)
if err != nil {
fmt.Println("TLS Cert Error")
}

http_server(server, filepath.Join(tls_path, "server.pem"), filepath.Join(tls_path, "server.key"), payload, default_file, enable_tls)
HttpServer(httpx.Config{
Server: fmt.Sprintf("0.0.0.0:%s",uri.Port),
Host: uri.Host,
ConfigPath: config_path,
Payload: payload,
DefaultFile: default_file,
EnableTLS: enable_tls,
})

}

func set_response_header(w http.ResponseWriter){
func SetResponseHeader(w http.ResponseWriter){
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, Authorization")
w.Header().Set("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Cache-Control, Content-Language, Content-Type")
w.Header().Set("Access-Control-Allow-Credentials", "true")
}

func http_write(w http.ResponseWriter, res_data []byte){

set_response_header(w)
func HttpWrite(w http.ResponseWriter, res_data []byte){
SetResponseHeader(w)
w.Write(res_data)
}

// 开启文件类型模式
func http_server(server string, tls_crt string, tls_key string, payload string, default_file string, enable_tls bool) {

mux := http.NewServeMux()

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {

// 定义一个可以接受额外参数的HTTP处理函数
func SRCHandler(config httpx.Config) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("%s %s %s\n",r.Method, r.URL,r.Proto)
fmt.Printf("Host: %s\n",r.Host)
fmt.Printf("From: %s\n",lib.GetRemoteIp(r))
Expand Down Expand Up @@ -172,13 +112,13 @@ func http_server(server string, tls_crt string, tls_key string, payload string,
}else if(strings.HasPrefix(r.URL.String(), "/default")){
// 设置默认信息

data, err := ioutil.ReadFile(default_file)
data, err := ioutil.ReadFile(config.DefaultFile)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

http_write(w,data)
HttpWrite(w,data)

}else if(strings.HasPrefix(r.URL.String(), "/payload")){
// 自定义返固定内容
Expand Down Expand Up @@ -212,34 +152,48 @@ func http_server(server string, tls_crt string, tls_key string, payload string,
}

default:
http_write(w,[]byte(payload))
HttpWrite(w,[]byte(config.Payload))
}


}else if(strings.HasPrefix(r.URL.String(), "/message")){
// 返回全内容(接收消息)

http_write(w,[]byte(`{"message": "OK"}`))
HttpWrite(w,[]byte(`{"message": "OK"}`))

}else{
// 文件系统
set_response_header(w)
SetResponseHeader(w)
http.FileServer(http.Dir("./")).ServeHTTP(w, r)

}
}
}

})
// 开启文件类型模式
func HttpServer(config httpx.Config) {

if enable_tls {
mux := http.NewServeMux()

// 创建一个中间件链
handlerWithMiddleware := httpx.HostMiddleware(config.Host, http.HandlerFunc(SRCHandler(config)))

// 设置路由并启动服务器
mux.Handle("/", handlerWithMiddleware)

if config.EnableTLS {
// 使用https
err := http.ListenAndServeTLS(server, tls_crt, tls_key, mux)
tls_crt := filepath.Join(config.ConfigPath, "server.pem")
tls_key := filepath.Join(config.ConfigPath, "server.key")

err := http.ListenAndServeTLS(config.Server, tls_crt, tls_key, mux)
if err != nil {
fmt.Println("TLS Cert Error:", err.Error())
}

} else {
// 使用http
http.ListenAndServe(server, mux)
http.ListenAndServe(config.Server, mux)
}
}

11 changes: 11 additions & 0 deletions src/httpx/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package httpx

type Config struct{
Server string
Host string
ConfigPath string
Payload string
DefaultFile string
EnableTLS bool

}
20 changes: 20 additions & 0 deletions src/httpx/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package httpx

import (
"net/http"
"strings"

"github.com/baiqll/src-http/src/lib"
)

func HostMiddleware(host string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req_host := strings.Split(r.Host, ":")[0]
if !lib.IsIP(host) && req_host != host{
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("Not Found!"))
return
}
next.ServeHTTP(w, r)
})
}
105 changes: 105 additions & 0 deletions src/httpx/uri.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package httpx

import (
"fmt"
"regexp"
"strings"

"github.com/baiqll/src-http/src/lib"
)

type Uri struct{
Host string
Port string
ServerBanner string
IsNewDomain bool
InternetIp string
EnableTls bool
ShowInternetServer bool
}

func Parse(str string, enable_tls bool)Uri{

httpx := Uri{
Host: "127.0.0.1",
InternetIp: lib.GetInternetIP(),
ShowInternetServer: true,
EnableTls: enable_tls,
}

httpx.init(str)

return httpx

}

func(u *Uri) init_port(){

if u.Port == ""{
if u.EnableTls{
u.Port = "443"
}else{
u.Port = "80"
}
}
}

func(u *Uri) init_server_banner(){
method := "http"
banner := ""
Listening_port := ":"+ u.Port

if u.EnableTls{
method = "https"
}
if u.Port == "443" || u.Port == "80"{
Listening_port = ""

}

banner += fmt.Sprintf("[*] Starting server %s://%s%s ...\n", method, u.Host, Listening_port)
if u.ShowInternetServer{
banner += fmt.Sprintf("[*] Internet server %s://%s%s ...\n", method, u.InternetIp, Listening_port)
}

banner += fmt.Sprintf("[*] Listening 0.0.0.0:%s",Listening_port)

u.ServerBanner = banner

}

func (u *Uri)init(str string){

// 判断域名是否合规
if str != "" {

server_split := strings.Split(str, ":")
u.Host = server_split[0]
if len(server_split) > 1 {
u.Port = server_split[1]
}

if is_host, _ := regexp.MatchString(`[a-zA-Z0-9][-a-zA-Z0-9]{0,62}(\.[a-zA-Z0-9][-a-zA-Z0-9]{0,62})+\.?`, u.Host); !is_host {

return
}

if u.Host!= "0.0.0.0"{
/*
设置本地域名解析
*/
u.IsNewDomain = lib.NewDNS(u.Host)
// show_internet_server = false
u.ShowInternetServer = false

}

}

lib.NewDNS(u.InternetIp)

u.init_port()

u.init_server_banner()

}
6 changes: 6 additions & 0 deletions src/lib/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ func SetHosts(host string) {

}

// isIP 尝试将字符串解析为IP地址
func IsIP(str string) bool {
ip := net.ParseIP(str)
return ip != nil
}

// 取消hosts域名绑定
func UnloadHosts(host string) {

Expand Down

0 comments on commit bf9311f

Please sign in to comment.