diff --git a/talpid-core/src/firewall/windows.rs b/talpid-core/src/firewall/windows.rs index 91bdc7d206eb..e16d7f17e0d7 100644 --- a/talpid-core/src/firewall/windows.rs +++ b/talpid-core/src/firewall/windows.rs @@ -153,16 +153,17 @@ impl Firewall { protocol: WinFwProt::from(endpoint.endpoint.protocol), }; - let relay_client_wstr = endpoint + let relay_client_wstrs: Vec<_> = endpoint .clients - .last() - .as_ref() - .map(|client| WideCString::from_os_str_truncate(client)); - let relay_client_wstr_ptr: *const u16 = if let Some(ref wstr) = relay_client_wstr { - wstr.as_ptr() - } else { - ptr::null() - }; + .iter() + .map(|client| WideCString::from_os_str_truncate(client)).collect(); + let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs.iter().map(|wstr| wstr.as_ptr()).collect(); + let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len(); + //let relay_client_wstr_ptr: *const u16 = if let Some(ref wstr) = relay_client_wstr { + // wstr.as_ptr() + //} else { + // ptr::null() + //}; let interface_wstr = tunnel_metadata .as_ref() @@ -223,7 +224,8 @@ impl Firewall { WinFw_ApplyPolicyConnecting( winfw_settings, &winfw_relay, - relay_client_wstr_ptr, + relay_client_wstr_ptrs.as_ptr(), + relay_client_wstr_ptrs_len, interface_wstr_ptr, allowed_endpoint, &allowed_tunnel_traffic, @@ -271,15 +273,16 @@ impl Firewall { None => ptr::null(), }; - let relay_client_wstr = endpoint.clients - .last() - .as_ref() - .map(|client| WideCString::from_os_str_truncate(client)); - let relay_client_wstr_ptr: *const u16 = if let Some(ref wstr) = relay_client_wstr { - wstr.as_ptr() - } else { - ptr::null() - }; + let relay_client_wstrs: Vec<_> = endpoint.clients + .iter() + .map(|client| WideCString::from_os_str_truncate(client)).collect(); + let relay_client_wstr_ptrs: Vec<*const u16> = relay_client_wstrs.iter().map(|wstr| wstr.as_ptr()).collect(); + let relay_client_wstr_ptrs_len = relay_client_wstr_ptrs.len(); + //let relay_client_wstr_ptrs: *const u16 = if let Some(ref wstr) = relay_client_wstr { + // wstr.as_ptr() + //} else { + // ptr::null() + //}; let dns_servers: Vec = dns_servers.iter().cloned().map(widestring_ip).collect(); @@ -289,7 +292,8 @@ impl Firewall { WinFw_ApplyPolicyConnected( winfw_settings, &winfw_relay, - relay_client_wstr_ptr, + relay_client_wstr_ptrs.as_ptr(), + relay_client_wstr_ptrs_len, tunnel_alias.as_ptr(), v4_gateway.as_ptr(), v6_gateway_ptr, @@ -610,7 +614,8 @@ mod winfw { pub fn WinFw_ApplyPolicyConnecting( settings: &WinFwSettings, relay: &WinFwEndpoint, - relayClient: *const libc::wchar_t, + relayClient: *const *const libc::wchar_t, + relayClientLen: usize, tunnelIfaceAlias: *const libc::wchar_t, allowedEndpoint: *const WinFwAllowedEndpoint<'_>, allowedTunnelTraffic: &WinFwAllowedTunnelTraffic, @@ -620,7 +625,8 @@ mod winfw { pub fn WinFw_ApplyPolicyConnected( settings: &WinFwSettings, relay: &WinFwEndpoint, - relayClient: *const libc::wchar_t, + relayClient: *const *const libc::wchar_t, + relayClientLen: usize, tunnelIfaceAlias: *const libc::wchar_t, v4Gateway: *const libc::wchar_t, v6Gateway: *const libc::wchar_t, diff --git a/windows/winfw/src/winfw/fwcontext.cpp b/windows/winfw/src/winfw/fwcontext.cpp index 1a69ffc2cd16..4ed22737fcd4 100644 --- a/windows/winfw/src/winfw/fwcontext.cpp +++ b/windows/winfw/src/winfw/fwcontext.cpp @@ -81,7 +81,7 @@ void AppendRelayRules ( FwContext::Ruleset &ruleset, const WinFwEndpoint &relay, - const std::optional &relayClient + const std::vector &relayClients ) { auto sublayer = @@ -95,7 +95,7 @@ void AppendRelayRules wfp::IpAddress(relay.ip), relay.port, relay.protocol, - relayClient, + relayClients, sublayer )); } @@ -185,7 +185,7 @@ bool FwContext::applyPolicyConnecting ( const WinFwSettings &settings, const WinFwEndpoint &relay, - const std::optional &relayClient, + const std::vector &relayClients, const std::optional &tunnelInterfaceAlias, const std::optional &allowedEndpoint, const WinFwAllowedTunnelTraffic &allowedTunnelTraffic @@ -195,7 +195,7 @@ bool FwContext::applyPolicyConnecting AppendNetBlockedRules(ruleset); AppendSettingsRules(ruleset, settings); - AppendRelayRules(ruleset, relay, relayClient); + AppendRelayRules(ruleset, relay, relayClients); if (allowedEndpoint.has_value()) { @@ -280,7 +280,7 @@ bool FwContext::applyPolicyConnected ( const WinFwSettings &settings, const WinFwEndpoint &relay, - const std::optional &relayClient, + const std::vector &relayClient, const std::wstring &tunnelInterfaceAlias, const std::vector &tunnelDnsServers, const std::vector &nonTunnelDnsServers diff --git a/windows/winfw/src/winfw/fwcontext.h b/windows/winfw/src/winfw/fwcontext.h index af9320961cc5..92ecce4f4fe0 100644 --- a/windows/winfw/src/winfw/fwcontext.h +++ b/windows/winfw/src/winfw/fwcontext.h @@ -28,7 +28,7 @@ class FwContext ( const WinFwSettings &settings, const WinFwEndpoint &relay, - const std::optional &relayClient, + const std::vector &relayClients, const std::optional &tunnelInterfaceAlias, const std::optional &allowedEndpoint, const WinFwAllowedTunnelTraffic &allowedTunnelTraffic @@ -38,7 +38,7 @@ class FwContext ( const WinFwSettings &settings, const WinFwEndpoint &relay, - const std::optional &relayClient, + const std::vector &relayClients, const std::wstring &tunnelInterfaceAlias, const std::vector &tunnelDnsServers, const std::vector &nonTunnelDnsServers diff --git a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp index 1e25f1f5b7ed..19ce09571b29 100644 --- a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp +++ b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.cpp @@ -52,13 +52,13 @@ PermitVpnRelay::PermitVpnRelay const wfp::IpAddress &relay, uint16_t relayPort, WinFwProtocol protocol, - const std::optional &relayClient, + const std::vector &relayClients, Sublayer sublayer ) : m_relay(relay) , m_relayPort(relayPort) , m_protocol(protocol) - , m_relayClient(relayClient) + , m_relayClients(relayClients) , m_sublayer(sublayer) { } @@ -87,9 +87,8 @@ bool PermitVpnRelay::apply(IObjectInstaller &objectInstaller) conditionBuilder.add_condition(ConditionPort::Remote(m_relayPort)); conditionBuilder.add_condition(CreateProtocolCondition(m_protocol)); - if (m_relayClient.has_value()) - { - conditionBuilder.add_condition(std::make_unique(m_relayClient.value())); + for(auto relayClient : m_relayClients) { + conditionBuilder.add_condition(std::make_unique(relayClient)); } return objectInstaller.addFilter(filterBuilder, conditionBuilder); diff --git a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h index 580bb71a2df6..cf1fb0241c36 100644 --- a/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h +++ b/windows/winfw/src/winfw/rules/multi/permitvpnrelay.h @@ -24,7 +24,7 @@ class PermitVpnRelay : public IFirewallRule const wfp::IpAddress &relay, uint16_t relayPort, WinFwProtocol protocol, - const std::optional &relayClient, + const std::vector &relayClients, Sublayer sublayer ); @@ -35,7 +35,7 @@ class PermitVpnRelay : public IFirewallRule const wfp::IpAddress m_relay; const uint16_t m_relayPort; const WinFwProtocol m_protocol; - const std::optional m_relayClient; + const std::vector m_relayClients; const Sublayer m_sublayer; }; diff --git a/windows/winfw/src/winfw/winfw.cpp b/windows/winfw/src/winfw/winfw.cpp index b0ab5af0e00f..34d64971cc73 100644 --- a/windows/winfw/src/winfw/winfw.cpp +++ b/windows/winfw/src/winfw/winfw.cpp @@ -231,7 +231,8 @@ WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings *settings, const WinFwEndpoint *relay, - const wchar_t *relayClient, + const wchar_t **relayClients, + size_t relayClientsLen, const wchar_t *tunnelInterfaceAlias, const WinFwAllowedEndpoint *allowedEndpoint, const WinFwAllowedTunnelTraffic *allowedTunnelTraffic @@ -259,10 +260,16 @@ WinFw_ApplyPolicyConnecting( THROW_ERROR("Invalid argument: allowedTunnelTraffic"); } + auto relayClientWstrings = std::vector(); + relayClientWstrings.reserve(relayClientsLen); + for(int i = 0; i < relayClientsLen; i++) { + relayClientWstrings.push_back(relayClients[i]); + } + return g_fwContext->applyPolicyConnecting( *settings, *relay, - relayClient != nullptr ? std::make_optional(relayClient) : std::nullopt, + relayClientWstrings, tunnelInterfaceAlias != nullptr ? std::make_optional(tunnelInterfaceAlias) : std::nullopt, MakeOptional(allowedEndpoint), *allowedTunnelTraffic @@ -293,7 +300,8 @@ WINFW_API WinFw_ApplyPolicyConnected( const WinFwSettings *settings, const WinFwEndpoint *relay, - const wchar_t *relayClient, + const wchar_t **relayClients, + size_t relayClientsLen, const wchar_t *tunnelInterfaceAlias, const wchar_t *v4Gateway, const wchar_t *v6Gateway, @@ -397,10 +405,16 @@ WinFw_ApplyPolicyConnected( g_logSink(MULLVAD_LOG_LEVEL_DEBUG, ss.str().c_str(), g_logSinkContext); } + auto relayClientWstrings = std::vector(); + relayClientWstrings.reserve(relayClientsLen); + for(int i = 0; i < relayClientsLen; i++) { + relayClientWstrings.push_back(relayClients[i]); + } + return g_fwContext->applyPolicyConnected( *settings, *relay, - relayClient != nullptr ? std::make_optional(relayClient) : std::nullopt, + relayClientWstrings, tunnelInterfaceAlias, tunnelDnsServers, nonTunnelDnsServers diff --git a/windows/winfw/src/winfw/winfw.h b/windows/winfw/src/winfw/winfw.h index 5d61f1029d4f..b786d943d399 100644 --- a/windows/winfw/src/winfw/winfw.h +++ b/windows/winfw/src/winfw/winfw.h @@ -164,7 +164,8 @@ WINFW_API WinFw_ApplyPolicyConnecting( const WinFwSettings *settings, const WinFwEndpoint *relay, - const wchar_t *relayClient, + const wchar_t **relayClient, + size_t relayClientLen, const wchar_t *tunnelInterfaceAlias, const WinFwAllowedEndpoint *allowedEndpoint, const WinFwAllowedTunnelTraffic *allowedTunnelTraffic @@ -194,7 +195,8 @@ WINFW_API WinFw_ApplyPolicyConnected( const WinFwSettings *settings, const WinFwEndpoint *relay, - const wchar_t *relayClient, + const wchar_t **relayClient, + size_t relayClientLen, const wchar_t *tunnelInterfaceAlias, const wchar_t *v4Gateway, const wchar_t *v6Gateway,