Refactored IPv4/6 preference logic to fix pylint warnings.

This commit is contained in:
Joe Testa
2021-02-23 16:05:01 -05:00
parent 1bbc3feb57
commit b300ad1252
5 changed files with 47 additions and 86 deletions

View File

@@ -1,7 +1,7 @@
"""
The MIT License (MIT)
Copyright (C) 2017-2020 Joe Testa (jtesta@positronsecurity.com)
Copyright (C) 2017-2021 Joe Testa (jtesta@positronsecurity.com)
Copyright (C) 2017 Andris Raugulis (moo@arthepsy.eu)
Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -52,7 +52,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
SM_BANNER_SENT = 1
def __init__(self, host: Optional[str], port: int, ipvo: Optional[Sequence[int]] = None, timeout: Union[int, float] = 5, timeout_set: bool = False) -> None:
def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value
super(SSH_Socket, self).__init__()
self.__sock: Optional[socket.socket] = None
self.__sock_map: Dict[int, socket.socket] = {}
@@ -67,32 +67,27 @@ class SSH_Socket(ReadBuf, WriteBuf):
raise ValueError('invalid port: {}'.format(port))
self.__host = host
self.__port = nport
if ipvo is not None:
self.__ipvo = ipvo
else:
self.__ipvo = ()
self.__ip_version_preference = ip_version_preference # Holds only 5 possible values: [] (no preference), [4] (use IPv4 only), [6] (use IPv6 only), [46] (use both IPv4 and IPv6, but prioritize v4), and [64] (use both IPv4 and IPv6, but prioritize v6).
self.__timeout = timeout
self.__timeout_set = timeout_set
self.client_host: Optional[str] = None
self.client_port = None
def _resolve(self, ipvo: Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]]:
ipvo = tuple([x for x in Utils.unique_seq(ipvo) if x in (4, 6)])
ipvo_len = len(ipvo)
prefer_ipvo = ipvo_len > 0
prefer_ipv4 = prefer_ipvo and ipvo[0] == 4
if ipvo_len == 1:
family = socket.AF_INET if ipvo[0] == 4 else socket.AF_INET6
def _resolve(self) -> Iterable[Tuple[int, Tuple[Any, ...]]]:
# If __ip_version_preference has only one entry, then it means that ONLY that IP version should be used.
if len(self.__ip_version_preference) == 1:
family = socket.AF_INET if self.__ip_version_preference[0] == 4 else socket.AF_INET6
else:
family = socket.AF_UNSPEC
try:
stype = socket.SOCK_STREAM
r = socket.getaddrinfo(self.__host, self.__port, family, stype)
if prefer_ipvo:
r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4)
check = any(stype == rline[2] for rline in r)
# If the user has a preference for using IPv4 over IPv6 (or vice-versa), then sort the list returned by getaddrinfo() so that the preferred address type comes first.
if len(self.__ip_version_preference) == 2:
r = sorted(r, key=lambda x: x[0], reverse=(self.__ip_version_preference[0] == 6))
for af, socktype, _proto, _canonname, addr in r:
if not check or socktype == socket.SOCK_STREAM:
if socktype == socket.SOCK_STREAM:
yield af, addr
except socket.error as e:
OutputBuffer().fail('[exception] {}'.format(e)).write()
@@ -156,7 +151,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
def connect(self) -> Optional[str]:
'''Returns None on success, or an error string.'''
err = None
for af, addr in self._resolve(self.__ipvo):
for af, addr in self._resolve():
s = None
try:
s = socket.socket(af, socket.SOCK_STREAM)