Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add country code exclusion in connect command #215

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,16 @@ You can use the `--random` or `-r` flag to connect to a random server:

`protonvpn c -r`

There are several other variables to keep in mind when you want to connect to the “fastest” server. You can connect to the fastest server in a country, the fastest Secure Core server, the fastest P2P-enabled server, or the fastest Tor server.
There are several other variables to keep in mind when you want to connect to the “fastest” server. You can connect to the fastest server in a country, the fastest server outside a country, the fastest Secure Core server, the fastest P2P-enabled server, or the fastest Tor server.

Fastest server in a country (replace UK with the code of the desired country, e.g. `US` for USA, `JP` for Japan, `AU` for Australia, etc.):

`protonvpn c --cc UK`

Fastest server outside a country (replace UK with the code of the desired country, e.g. `US` for USA, `JP` for Japan, `AU` for Australia, etc.):

`protonvpn c --not-cc UK`

Fastest Secure Core server:

`protonvpn c --sc`
Expand Down
9 changes: 8 additions & 1 deletion protonvpn_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
protonvpn (c | connect) [<servername>] [-p <protocol>]
protonvpn (c | connect) [-f | --fastest] [-p <protocol>]
protonvpn (c | connect) [--cc <code>] [-p <protocol>]
protonvpn (c | connect) [--not-cc <code>] [-p <protocol>]
protonvpn (c | connect) [--sc] [-p <protocol>]
protonvpn (c | connect) [--p2p] [-p <protocol>]
protonvpn (c | connect) [--tor] [-p <protocol>]
Expand All @@ -23,6 +24,7 @@
-f, --fastest Select the fastest ProtonVPN server.
-r, --random Select a random ProtonVPN server.
--cc CODE Determine the country for fastest connect.
--not-cc CODE Determine the country to exclude for fastest connect.
--sc Connect to the fastest Secure-Core server.
--p2p Connect to the fastest torrent server.
--tor Connect to the fastest Tor server.
Expand Down Expand Up @@ -123,6 +125,8 @@ def cli():
connection.direct(args.get("<servername>"), protocol)
elif args.get("--cc") is not None:
connection.country_f(args.get("--cc"), protocol)
elif args.get("--not-cc") is not None:
connection.country_f(args.get("--not-cc"), protocol, True)
# Features: 1: Secure-Core, 2: Tor, 4: P2P
elif args.get("--p2p"):
connection.feature_f(4, protocol)
Expand Down Expand Up @@ -291,6 +295,9 @@ def print_examples():
"protonvpn connect --cc AU\n"
" Connect to the fastest Australian server\n"
" with the default protocol.\n\n"
"protonvpn connect --not-cc AU\n"
" Connect to the fastest server outside Australia\n"
" with the default protocol.\n\n"
"protonvpn c --p2p -p tcp\n"
" Connect to the fastest torrent server with TCP.\n\n"
"protonvpn c --sc\n"
Expand Down Expand Up @@ -563,7 +570,7 @@ def set_killswitch():
"The Kill Switch will block all network traffic\n"
"if the VPN connection drops unexpectedly.\n"
"\n"
"Please note that the Kill Switch assumes only one network interface being active.\n" # noqa
"Please note that the Kill Switch assumes only one network interface being active.\n" # noqa
"\n"
"1) Enable Kill Switch (Block access to/from LAN)\n"
"2) Enable Kill Switch (Allow access to/from LAN)\n"
Expand Down
175 changes: 97 additions & 78 deletions protonvpn_cli/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def fastest(protocol=None):
openvpn_connect(fastest_server, protocol)


def country_f(country_code, protocol=None):
def country_f(country_code, protocol=None, excluded=False):
"""Connect to the fastest server in a specific country."""
logger.debug("Starting fastest country connect")

Expand All @@ -185,23 +185,33 @@ def country_f(country_code, protocol=None):
# Filter out excluded features and countries
server_pool = []
for server in servers:
if server["Features"] not in excluded_features and server["ExitCountry"] == country_code:
server_pool.append(server)
if server["Features"] not in excluded_features:
if (excluded and server["ExitCountry"] != country_code) or (
not excluded and server["ExitCountry"] == country_code
):
server_pool.append(server)

if len(server_pool) == 0:
print(
"[!] No Server in country {0} found\n".format(country_code)
+ "[!] Please choose a valid country"
)
logger.debug("No server in country {0}".format(country_code))
if not excluded:
print(
"[!] No Server in country {0} found\n".format(country_code)
+ "[!] Please choose a valid country"
)
logger.debug("No server in country {0}".format(country_code))
else:
print(
"[!] No Server found outside country {0}\n".format(country_code)
+ "[!] Please choose a valid country"
)
logger.debug("No server outside country {0}".format(country_code))
sys.exit(1)

fastest_server = get_fastest_server(server_pool)
openvpn_connect(fastest_server, protocol)


def feature_f(feature, protocol=None):
"""Connect to the fastest server in a specific country."""
"""Connect to the fastest server with have a specific feature"""
logger.debug(
"Starting fastest feature connect with feature {0}".format(feature)
)
Expand Down Expand Up @@ -554,75 +564,84 @@ def manage_dns(mode, dns_server=False):

if mode == "leak_protection":
logger.debug("Leak Protection initiated")
# Restore original resolv.conf if it exists
if os.path.isfile(backupfile):
logger.debug("resolv.conf.backup exists")
manage_dns("restore")
# Check for custom DNS Server
if not int(get_config_value("USER", "dns_leak_protection")):
if get_config_value("USER", "custom_dns") == "None":
logger.debug("DNS Leak Protection is disabled")
return
else:
dns_server = get_config_value("USER", "custom_dns")
logger.debug("Using custom DNS")
if shutil.which("resolvectl") and os.path.islink("/etc/resolv.conf"):
logger.debug("Running resolvectl command for leak_protection")
cmd_args = ["resolvectl", "dns", "proton0", dns_server]
pipes = subprocess.Popen(cmd_args, stderr=subprocess.PIPE)
_, std_err = pipes.communicate()
if pipes.returncode != 0:
raise Exception(f"{' '.join(cmd_args)} failed with code {pipes.returncode} -> {std_err.strip()}")
else:
logger.debug("DNS Leak Protection is enabled")
# Make sure DNS Server has been provided
if not dns_server:
raise Exception("No DNS Server has been provided.")

shutil.copy2(resolvconf_path, backupfile)
logger.debug("{0} (resolv.conf) backed up".format(resolvconf_path))

# Remove previous nameservers
dns_regex = re.compile(r"^nameserver .*$")

with open(backupfile, 'r') as backup_handle:
with open(resolvconf_path, 'w') as resolvconf_handle:
for line in backup_handle:
if not dns_regex.search(line):
resolvconf_handle.write(line)

logger.debug("Removed existing DNS Servers")

# Add ProtonVPN managed DNS Server to resolv.conf
dns_server = dns_server.split()
with open(resolvconf_path, "a") as f:
f.write("# ProtonVPN DNS Servers. Managed by ProtonVPN-CLI.\n")
for dns in dns_server[:3]:
f.write("nameserver {0}\n".format(dns))
logger.debug("Added ProtonVPN or custom DNS")

# Write the hash of the edited file in the configuration
#
# This is so it doesn't restore an old DNS configuration
# if the configuration changes during a VPN session
# (e.g. by switching networks)

with open(resolvconf_path, "rb") as f:
filehash = zlib.crc32(f.read())
set_config_value("metadata", "resolvconf_hash", filehash)
# Restore original resolv.conf if it exists
if os.path.isfile(backupfile):
logger.debug("resolv.conf.backup exists")
manage_dns("restore")
# Check for custom DNS Server
if not int(get_config_value("USER", "dns_leak_protection")):
if get_config_value("USER", "custom_dns") == "None":
logger.debug("DNS Leak Protection is disabled")
return
else:
dns_server = get_config_value("USER", "custom_dns")
logger.debug("Using custom DNS")
else:
logger.debug("DNS Leak Protection is enabled")
# Make sure DNS Server has been provided
if not dns_server:
raise Exception("No DNS Server has been provided.")

shutil.copy2(resolvconf_path, backupfile)
logger.debug("{0} (resolv.conf) backed up".format(resolvconf_path))

# Remove previous nameservers
dns_regex = re.compile(r"^nameserver .*$")

with open(backupfile, 'r') as backup_handle:
with open(resolvconf_path, 'w') as resolvconf_handle:
for line in backup_handle:
if not dns_regex.search(line):
resolvconf_handle.write(line)

logger.debug("Removed existing DNS Servers")

# Add ProtonVPN managed DNS Server to resolv.conf
dns_server = dns_server.split()
with open(resolvconf_path, "a") as f:
f.write("# ProtonVPN DNS Servers. Managed by ProtonVPN-CLI.\n")
for dns in dns_server[:3]:
f.write("nameserver {0}\n".format(dns))
logger.debug("Added ProtonVPN or custom DNS")

# Write the hash of the edited file in the configuration
#
# This is so it doesn't restore an old DNS configuration
# if the configuration changes during a VPN session
# (e.g. by switching networks)

with open(resolvconf_path, "rb") as f:
filehash = zlib.crc32(f.read())
set_config_value("metadata", "resolvconf_hash", filehash)

elif mode == "restore":
logger.debug("Restoring DNS")
if os.path.isfile(backupfile):
if not (shutil.which("resolvectl") and os.path.islink("/etc/resolv.conf")):
if os.path.isfile(backupfile):

# Check if the file changed since connection
oldhash = get_config_value("metadata", "resolvconf_hash")
with open(resolvconf_path, "rb") as f:
filehash = zlib.crc32(f.read())
# Check if the file changed since connection
oldhash = get_config_value("metadata", "resolvconf_hash")
with open(resolvconf_path, "rb") as f:
filehash = zlib.crc32(f.read())

if filehash == int(oldhash):
shutil.copy2(backupfile, resolvconf_path)
logger.debug("resolv.conf restored from backup")
else:
logger.debug("resolv.conf changed. Not restoring.")
if filehash == int(oldhash):
shutil.copy2(backupfile, resolvconf_path)
logger.debug("resolv.conf restored from backup")
else:
logger.debug("resolv.conf changed. Not restoring.")

os.remove(backupfile)
logger.debug("resolv.conf.backup removed")
else:
logger.debug("No Backupfile found")
os.remove(backupfile)
logger.debug("resolv.conf.backup removed")
else:
logger.debug("No Backupfile found")
else:
raise Exception("Invalid argument provided. "
"Mode must be 'restore' or 'leak_protection'")
Expand Down Expand Up @@ -715,7 +734,7 @@ def manage_ipv6(mode):
ipv6_addr = lines[1].strip()

ipv6_info = subprocess.run(
"ip addr show dev {0} | grep '\<inet6.*global\>'".format(default_nic), # noqa
"ip addr show dev {0} | grep '\<inet6.*global\>'".format(default_nic), # noqa
shell=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE
)

Expand Down Expand Up @@ -833,10 +852,10 @@ def manage_killswitch(mode, proto=None, port=None):
"iptables -A INPUT -i lo -j ACCEPT",
"iptables -A OUTPUT -o {0} -j ACCEPT".format(device),
"iptables -A INPUT -i {0} -j ACCEPT".format(device),
"iptables -A OUTPUT -o {0} -m state --state ESTABLISHED,RELATED -j ACCEPT".format(device), # noqa
"iptables -A INPUT -i {0} -m state --state ESTABLISHED,RELATED -j ACCEPT".format(device), # noqa
"iptables -A OUTPUT -p {0} -m {1} --dport {2} -j ACCEPT".format(proto.lower(), proto.lower(), port), # noqa
"iptables -A INPUT -p {0} -m {1} --sport {2} -j ACCEPT".format(proto.lower(), proto.lower(), port), # noqa
"iptables -A OUTPUT -o {0} -m state --state ESTABLISHED,RELATED -j ACCEPT".format(device), # noqa
"iptables -A INPUT -i {0} -m state --state ESTABLISHED,RELATED -j ACCEPT".format(device), # noqa
"iptables -A OUTPUT -p {0} -m {1} --dport {2} -j ACCEPT".format(proto.lower(), proto.lower(), port), # noqa
"iptables -A INPUT -p {0} -m {1} --sport {2} -j ACCEPT".format(proto.lower(), proto.lower(), port), # noqa
]

if int(get_config_value("USER", "killswitch")) == 2:
Expand All @@ -849,8 +868,8 @@ def manage_killswitch(mode, proto=None, port=None):
local_network = local_network.stdout.decode().strip().split()[1]

exclude_lan_commands = [
"iptables -A OUTPUT -o {0} -d {1} -j ACCEPT".format(default_nic, local_network), # noqa
"iptables -A INPUT -i {0} -s {1} -j ACCEPT".format(default_nic, local_network), # noqa
"iptables -A OUTPUT -o {0} -d {1} -j ACCEPT".format(default_nic, local_network), # noqa
"iptables -A INPUT -i {0} -s {1} -j ACCEPT".format(default_nic, local_network), # noqa
]

for lan_command in exclude_lan_commands:
Expand Down