diff --git a/etc/smartdns/smartdns.conf b/etc/smartdns/smartdns.conf index 31969eadbc..524dd5d057 100644 --- a/etc/smartdns/smartdns.conf +++ b/etc/smartdns/smartdns.conf @@ -265,6 +265,11 @@ log-level info # specific cname to domain # cname /domain/target +# add srv record, support multiple srv record. +# srv-record /domain/[target][,port][,priority][,weight] +# srv-record /_ldap._tcp.example.com/ldapserver.example.com,389 +# srv-record /_ldap._tcp.example.com/ + # enalbe DNS64 feature # dns64 [ip/subnet] # dns64 64:ff9b::/96 diff --git a/src/dns_conf.c b/src/dns_conf.c index 9da3adfe82..fae389400f 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -66,6 +66,9 @@ int dns_hosts_record_num; /* DNS64 */ struct dns_dns64 dns_conf_dns_dns64; +/* SRV-HOST */ +struct dns_srv_record_table dns_conf_srv_record_table; + /* server ip/port */ struct dns_bind_ip dns_conf_bind_ip[DNS_MAX_BIND_IP]; int dns_conf_bind_ip_num = 0; @@ -505,6 +508,26 @@ static void _config_proxy_table_destroy(void) } } +static void _config_srv_record_table_destroy(void) +{ + struct dns_srv_records *srv_records = NULL; + struct dns_srv_record *srv_record, *tmp1 = NULL; + struct hlist_node *tmp = NULL; + unsigned int i; + + hash_for_each_safe(dns_conf_srv_record_table.srv, i, tmp, srv_records, node) + { + list_for_each_entry_safe(srv_record, tmp1, &srv_records->list, list) + { + list_del(&srv_record->list); + free(srv_record); + } + + hlist_del_init(&srv_records->node); + free(srv_records); + } +} + static int _config_server(int argc, char *argv[], dns_server_type_t type, int default_port) { int index = dns_conf_server_num; @@ -1860,6 +1883,122 @@ static int _config_cname(void *data, int argc, char *argv[]) return 0; } +struct dns_srv_records *dns_server_get_srv_record(const char *domain) +{ + uint32_t key = 0; + + key = hash_string(domain); + struct dns_srv_records *srv_records = NULL; + hash_for_each_possible(dns_conf_srv_record_table.srv, srv_records, node, key) + { + if (strncmp(srv_records->domain, domain, DNS_MAX_CONF_CNAME_LEN) == 0) { + return srv_records; + } + } + + return NULL; +} + +static int _confg_srv_record_add(const char *domain, const char *host, unsigned short priority, unsigned short weight, + unsigned short port) +{ + struct dns_srv_records *srv_records = NULL; + struct dns_srv_record *srv_record = NULL; + uint32_t key = 0; + + srv_records = dns_server_get_srv_record(domain); + if (srv_records == NULL) { + srv_records = malloc(sizeof(*srv_records)); + if (srv_records == NULL) { + goto errout; + } + memset(srv_records, 0, sizeof(*srv_records)); + safe_strncpy(srv_records->domain, domain, DNS_MAX_CONF_CNAME_LEN); + INIT_LIST_HEAD(&srv_records->list); + key = hash_string(domain); + hash_add(dns_conf_srv_record_table.srv, &srv_records->node, key); + } + + srv_record = malloc(sizeof(*srv_record)); + if (srv_record == NULL) { + goto errout; + } + memset(srv_record, 0, sizeof(*srv_record)); + safe_strncpy(srv_record->host, host, DNS_MAX_CONF_CNAME_LEN); + srv_record->priority = priority; + srv_record->weight = weight; + srv_record->port = port; + list_add_tail(&srv_record->list, &srv_records->list); + + return 0; +errout: + if (srv_record != NULL) { + free(srv_record); + } + return -1; +} + +static int _config_srv_record(void *data, int argc, char *argv[]) +{ + char *value = NULL; + char domain[DNS_MAX_CONF_CNAME_LEN]; + char buff[DNS_MAX_CONF_CNAME_LEN]; + char *ptr = NULL; + int ret = -1; + + char *host_s; + char *priority_s; + char *weight_s; + char *port_s; + + unsigned short priority = 0; + unsigned short weight = 0; + unsigned short port = 1; + + if (argc < 2) { + goto errout; + } + + value = argv[1]; + if (_get_domain(value, domain, DNS_MAX_CONF_CNAME_LEN, &value) != 0) { + goto errout; + } + + safe_strncpy(buff, value, sizeof(buff)); + + host_s = strtok_r(buff, ",", &ptr); + if (host_s == NULL) { + host_s = ""; + goto out; + } + + port_s = strtok_r(NULL, ",", &ptr); + if (port_s != NULL) { + port = atoi(port_s); + } + + priority_s = strtok_r(NULL, ",", &ptr); + if (priority_s != NULL) { + priority = atoi(priority_s); + } + + weight_s = strtok_r(NULL, ",", &ptr); + if (weight_s != NULL) { + weight = atoi(weight_s); + } +out: + ret = _confg_srv_record_add(domain, host_s, priority, weight, port); + if (ret != 0) { + goto errout; + } + + return 0; + +errout: + tlog(TLOG_ERROR, "add srv-record %s:%s failed", domain, value); + return -1; +} + static void _config_speed_check_mode_clear(struct dns_domain_check_orders *check_orders) { memset(check_orders->orders, 0, sizeof(check_orders->orders)); @@ -4154,6 +4293,7 @@ static struct config_item _config_item[] = { CONF_YESNO("expand-ptr-from-address", &dns_conf_expand_ptr_from_address), CONF_CUSTOM("address", _config_address, NULL), CONF_CUSTOM("cname", _config_cname, NULL), + CONF_CUSTOM("srv-record", _config_srv_record, NULL), CONF_CUSTOM("proxy-server", _config_proxy_server, NULL), CONF_YESNO("ipset-timeout", &dns_conf_ipset_timeout_enable), CONF_CUSTOM("ipset", _config_ipset, NULL), @@ -4406,6 +4546,7 @@ static int _dns_server_load_conf_init(void) hash_init(dns_ptr_table.ptr); hash_init(dns_domain_set_name_table.names); hash_init(dns_ip_set_name_table.names); + hash_init(dns_conf_srv_record_table.srv); return 0; } @@ -4456,6 +4597,7 @@ void dns_server_load_exit(void) _config_host_table_destroy(0); _config_qtype_soa_table_destroy(); _config_proxy_table_destroy(); + _config_srv_record_table_destroy(); dns_conf_server_num = 0; dns_server_bind_destroy(); diff --git a/src/dns_conf.h b/src/dns_conf.h index d9cbe44686..127c2b510c 100644 --- a/src/dns_conf.h +++ b/src/dns_conf.h @@ -488,6 +488,25 @@ struct dns_dns64 { uint32_t prefix_len; }; +struct dns_srv_record { + struct list_head list; + char host[DNS_MAX_CNAME_LEN]; + unsigned short priority; + unsigned short weight; + unsigned short port; +}; + +struct dns_srv_records { + char domain[DNS_MAX_CNAME_LEN]; + struct hlist_node node; + struct list_head list; +}; + +struct dns_srv_record_table { + DECLARE_HASHTABLE(srv, 4); +}; +extern struct dns_srv_record_table dns_conf_srv_record_table; + extern struct dns_dns64 dns_conf_dns_dns64; extern struct dns_bind_ip dns_conf_bind_ip[DNS_MAX_BIND_IP]; @@ -584,6 +603,8 @@ int dns_server_check_update_hosts(void); struct dns_proxy_names *dns_server_get_proxy_nams(const char *proxyname); +struct dns_srv_records *dns_server_get_srv_record(const char *domain); + extern int config_additional_file(void *data, int argc, char *argv[]); const char *dns_conf_get_cache_dir(void); diff --git a/src/dns_server.c b/src/dns_server.c index b0ee9e5d90..5ba497a934 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -276,6 +276,8 @@ struct dns_request { int has_soa; int force_soa; + struct dns_srv_records *srv_records; + atomic_t notified; atomic_t do_callback; atomic_t adblock; @@ -949,6 +951,29 @@ static void _dns_server_setup_soa(struct dns_request *request) soa->minimum = 86400; } +static int _dns_server_add_srv(struct dns_server_post_context *context) +{ + struct dns_request *request = context->request; + struct dns_srv_records *srv_records = request->srv_records; + struct dns_srv_record *srv_record = NULL; + int ret = 0; + + if (srv_records == NULL) { + return 0; + } + + list_for_each_entry(srv_record, &srv_records->list, list) + { + ret = dns_add_SRV(context->packet, DNS_RRS_AN, request->domain, request->ip_ttl, srv_record->priority, + srv_record->weight, srv_record->port, srv_record->host); + if (ret != 0) { + return -1; + } + } + + return 0; +} + static int _dns_add_rrs(struct dns_server_post_context *context) { struct dns_request *request = context->request; @@ -1011,6 +1036,10 @@ static int _dns_add_rrs(struct dns_server_post_context *context) ret |= dns_add_OPT_ECS(context->packet, &request->ecs); } + if (request->srv_records != NULL) { + ret |= _dns_server_add_srv(context); + } + if (request->rcode != DNS_RC_NOERROR) { tlog(TLOG_INFO, "result: %s, qtype: %d, rtcode: %d, id: %d", domain, context->qtype, request->rcode, request->id); @@ -4159,6 +4188,28 @@ static int _dns_server_process_DDR(struct dns_request *request) } static int _dns_server_process_srv(struct dns_request *request) +{ + struct dns_srv_records *srv_records = dns_server_get_srv_record(request->domain); + if (srv_records == NULL) { + return -1; + } + + request->rcode = DNS_RC_NOERROR; + request->ip_ttl = _dns_server_get_local_ttl(request); + request->srv_records = srv_records; + + struct dns_server_post_context context; + _dns_server_post_context_init(&context, request); + context.do_audit = 1; + context.do_reply = 1; + context.do_cache = 0; + context.do_force_soa = 0; + _dns_request_post(&context); + + return 0; +} + +static int _dns_server_process_svcb(struct dns_request *request) { if (strncmp("_dns.resolver.arpa", request->domain, DNS_MAX_CNAME_LEN) == 0) { return _dns_server_process_DDR(request); @@ -5268,7 +5319,7 @@ static int _dns_server_process_special_query(struct dns_request *request) switch (request->qtype) { case DNS_T_PTR: break; - case DNS_T_SVCB: + case DNS_T_SRV: ret = _dns_server_process_srv(request); if (ret == 0) { goto clean_exit; @@ -5277,6 +5328,15 @@ static int _dns_server_process_special_query(struct dns_request *request) request->passthrough = 1; } break; + case DNS_T_SVCB: + ret = _dns_server_process_svcb(request); + if (ret == 0) { + goto clean_exit; + } else { + /* pass to upstream server */ + request->passthrough = 1; + } + break; case DNS_T_A: break; case DNS_T_AAAA: diff --git a/src/smartdns.c b/src/smartdns.c index 6334058a1a..8e8a6807d7 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -1112,7 +1112,7 @@ int main(int argc, char *argv[]) errout: if (is_run_as_daemon) { daemon_kickoff(ret, dns_conf_log_console | verbose_screen); - } else { + } else if (dns_conf_log_console == 0 && verbose_screen == 0) { _smartdns_print_error_tip(); } smartdns_test_notify(2); diff --git a/test/cases/test-srv.cc b/test/cases/test-srv.cc new file mode 100644 index 0000000000..b249684346 --- /dev/null +++ b/test/cases/test-srv.cc @@ -0,0 +1,105 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class SRV : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(SRV, query) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_SRV) { + return smartdns::SERVER_REQUEST_SOA; + } + + struct dns_packet *packet = request->response_packet; + dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www.example.com"); + dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www1.example.com"); + + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +speed-check-mode none +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("_ldap._tcp.local.com SRV", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "_ldap._tcp.local.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 603); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "SRV"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1 1 443 www.example.com."); +} + +TEST_F(SRV, match) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_SRV) { + return smartdns::SERVER_REQUEST_SOA; + } + + struct dns_packet *packet = request->response_packet; + dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www.example.com"); + dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www1.example.com"); + + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +srv-record /_ldap._tcp.local.com/www.a.com,443,1,1 +srv-record /_ldap._tcp.local.com/www1.a.com,443,1,1 +srv-record /_ldap._tcp.local.com/www2.a.com,443,1,1 +speed-check-mode none +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("_ldap._tcp.local.com SRV", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 3); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "_ldap._tcp.local.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "SRV"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1 1 443 www.a.com."); +}