mirror of
https://github.com/kovidgoyal/kitty.git
synced 2025-12-13 20:36:22 +01:00
342 lines
12 KiB
Python
342 lines
12 KiB
Python
#!/usr/bin/env python
|
|
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
|
|
|
|
|
|
import os
|
|
import subprocess
|
|
import traceback
|
|
from collections.abc import Iterator, Sequence
|
|
from contextlib import suppress
|
|
from typing import Any
|
|
|
|
from kitty.types import run_once
|
|
from kitty.utils import SSHConnectionData
|
|
|
|
|
|
@run_once
def ssh_options() -> dict[str, str]:
    """Return the short options understood by the local ``ssh`` binary.

    ``ssh`` is run without arguments and the usage text it prints on stderr is
    scanned for bracketed groups. A group such as ``[-46AaCf]`` contributes one
    entry per letter with an empty description (flags that take no value); a
    group such as ``[-p port]`` contributes ``{'p': 'port'}``.

    If the ``ssh`` executable cannot be found, a built-in table is returned
    instead.
    """
    try:
        p = subprocess.run(['ssh'], stderr=subprocess.PIPE, encoding='utf-8')
        raw = p.stderr or ''
    except FileNotFoundError:
        return {
            '4': '', '6': '', 'A': '', 'a': '', 'C': '', 'f': '', 'G': '', 'g': '', 'K': '', 'k': '',
            'M': '', 'N': '', 'n': '', 'q': '', 's': '', 'T': '', 't': '', 'V': '', 'v': '', 'X': '',
            'x': '', 'Y': '', 'y': '', 'B': 'bind_interface', 'b': 'bind_address', 'c': 'cipher_spec',
            'D': '[bind_address:]port', 'E': 'log_file', 'e': 'escape_char', 'F': 'configfile', 'I': 'pkcs11',
            'i': 'identity_file', 'J': '[user@]host[:port]', 'L': 'address', 'l': 'login_name', 'm': 'mac_spec',
            'O': 'ctl_cmd', 'o': 'option', 'p': 'port', 'Q': 'query_option', 'R': 'address',
            'S': 'ctl_path', 'W': 'host:port', 'w': 'local_tun[:remote_tun]'
        }

    ans: dict[str, str] = {}
    end = len(raw)
    pos = 0
    while True:
        pos = raw.find('[', pos)
        if pos < 0:
            break
        # Walk forward to the bracket that closes this group; groups can nest,
        # e.g. ``[-D [bind_address:]port]``.
        depth = 1
        epos = pos
        while depth > 0:
            epos += 1
            if epos >= end:
                # Unbalanced '[' in the usage text: there is no complete group
                # left to parse, so keep what has been collected so far rather
                # than indexing past the end of the string.
                return ans
            ch = raw[epos]
            if ch == '[':
                depth += 1
            elif ch == ']':
                depth -= 1
        q = raw[pos+1:epos]
        pos = epos
        if len(q) < 2 or q[0] != '-':
            continue
        if ' ' in q:
            opt, desc = q.split(' ', 1)
            ans[opt[1:]] = desc
        else:
            ans.update(dict.fromkeys(q[1:], ''))
    return ans
|
|
|
|
|
|
def is_kitten_cmdline(q: Sequence[str]) -> bool:
    """Report whether the command line *q* invokes the ssh kitten.

    Recognised forms (the executable is matched by basename, case-insensitively):

    * ``kitten ssh ...``
    * ``kitty +kitten ssh ...`` and ``kitty + kitten ssh ...``
    * ``kitty +runpy 'from kittens.runner import main; main()' ... ssh`` with
      ``ssh`` as the sixth word

    The ``kitty`` forms additionally require at least four words in total.
    """
    if not q:
        return False
    # Work on a list: a slice of a tuple never compares equal to the list
    # literals below, which would make every tuple command line look foreign.
    argv = list(q)
    exe_name = os.path.basename(argv[0]).lower()
    if exe_name == 'kitten' and argv[1:2] == ['ssh']:
        return True
    if len(argv) < 4 or exe_name != 'kitty':
        return False
    if argv[1:3] == ['+kitten', 'ssh'] or argv[1:4] == ['+', 'kitten', 'ssh']:
        return True
    return argv[1:3] == ['+runpy', 'from kittens.runner import main; main()'] and len(argv) >= 6 and argv[5] == 'ssh'
|
|
|
|
|
|
def patch_cmdline(key: str, val: str, argv: list[str]) -> None:
    """Set the kitten option *key* to *val* in *argv*, modifying it in place.

    The first word of the form ``--kitten=key=...`` is replaced by
    ``--kitten=key=val``; a word starting with ``key=`` or ``key `` that
    directly follows a bare ``--kitten`` is replaced by ``key=val``. If neither
    is present, ``--kitten=key=val`` is inserted right after the first ``'ssh'``
    word; ``list.index`` raises ``ValueError`` when *argv* has no such word.
    """
    joined_prefix = f'--kitten={key}='
    split_prefixes = (f'{key}=', f'{key} ')
    previous = None
    for position, word in enumerate(argv):
        if word.startswith(joined_prefix):
            argv[position] = f'{joined_prefix}{val}'
            return
        if previous == '--kitten' and word.startswith(split_prefixes):
            argv[position] = f'{key}={val}'
            return
        previous = word
    argv.insert(argv.index('ssh') + 1, f'{joined_prefix}{val}')
|
|
|
|
|
|
def set_cwd_in_cmdline(cwd: str, argv: list[str]) -> None:
    """Set the ssh kitten's ``cwd`` option to *cwd* in *argv*, in place."""
    patch_cmdline(key='cwd', val=cwd, argv=argv)
|
|
|
|
|
|
def create_shared_memory(data: Any, prefix: str) -> str:
    """Store *data*, JSON-encoded as UTF-8, in a new shared memory segment.

    The segment is sized to hold the payload plus the size header that
    ``write_data_with_size`` needs, and its name starts with *prefix*.
    Closing the segment is registered with ``atexit``, and the segment name is
    handed to the boss object's ``atexit.shm_unlink`` (presumably so it is
    unlinked when kitty exits -- confirm against the boss implementation).

    Returns the name of the segment.
    """
    import atexit
    import json

    from kitty.fast_data_types import get_boss
    from kitty.shm import SharedMemory

    payload = json.dumps(data).encode('utf-8')
    segment_size = SharedMemory.num_bytes_for_size + len(payload)
    with SharedMemory(size=segment_size, prefix=prefix) as segment:
        segment.write_data_with_size(payload)
        segment.flush()
        atexit.register(segment.close)  # keeps shm alive till exit
    get_boss().atexit.shm_unlink(segment.name)
    return segment.name
|
|
|
|
|
|
def read_data_from_shared_memory(shm_name: str) -> Any:
    """Read and JSON-decode the size-prefixed payload of segment *shm_name*.

    The segment is opened read-only and unlinked straight away, before any
    validation. ``ValueError`` is raised if the segment is not owned by the
    effective uid and gid of this process, or if its permission bits are
    anything other than read/write for the owner.
    """
    import json
    import stat

    from kitty.shm import SharedMemory

    owner_read_write = stat.S_IREAD | stat.S_IWRITE
    with SharedMemory(shm_name, readonly=True) as shm:
        shm.unlink()
        owned_by_us = shm.stats.st_uid == os.geteuid() and shm.stats.st_gid == os.getegid()
        if not owned_by_us:
            raise ValueError(f'Incorrect owner on pwfile: uid={shm.stats.st_uid} gid={shm.stats.st_gid}')
        mode = stat.S_IMODE(shm.stats.st_mode)
        if mode != owner_read_write:
            raise ValueError(f'Incorrect permissions on pwfile: 0o{mode:03o}')
        return json.loads(shm.read_data_with_size())
|
|
|
|
|
|
def get_ssh_data(msgb: memoryview, request_id: str) -> Iterator[bytes|memoryview]:
|
|
from base64 import standard_b64decode
|
|
yield b'\nKITTY_DATA_START\n' # to discard leading data
|
|
try:
|
|
msg = standard_b64decode(msgb).decode('utf-8')
|
|
md = dict(x.split('=', 1) for x in msg.split(':'))
|
|
pw = md['pw']
|
|
pwfilename = md['pwfile']
|
|
rq_id = md['id']
|
|
except Exception:
|
|
traceback.print_exc()
|
|
yield b'invalid ssh data request message\n'
|
|
else:
|
|
try:
|
|
env_data = read_data_from_shared_memory(pwfilename)
|
|
if pw != env_data['pw']:
|
|
raise ValueError('Incorrect password')
|
|
if rq_id != request_id:
|
|
raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window')
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
yield f'{e}\n'.encode()
|
|
else:
|
|
yield b'OK\n'
|
|
encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
|
|
# macOS has a 255 byte limit on its input queue as per man stty.
|
|
# Not clear if that applies to canonical mode input as well, but
|
|
# better to be safe.
|
|
line_sz = 254
|
|
while encoded_data:
|
|
yield encoded_data[:line_sz]
|
|
yield b'\n'
|
|
encoded_data = encoded_data[line_sz:]
|
|
yield b'KITTY_DATA_END\n'
|
|
|
|
|
|
def set_env_in_cmdline(env: dict[str, str], argv: list[str], clone: bool = True) -> None:
    """Pass the environment *env* to the ssh kitten through *argv*, in place.

    With *clone* true, *env* is stored in shared memory (name prefix ``ksse-``)
    and the ``clone_env`` kitten option is set to the segment name.

    Otherwise one ``--kitten=env=NAME=VALUE`` word is added per variable, or
    ``--kitten=env=NAME`` when the value is the ``DELETE_ENV_VAR`` sentinel.
    The new words go after the last existing ``--kitten`` option found at or
    after the first ``'ssh'`` word, or directly after ``'ssh'`` if there is
    none. ``ValueError`` propagates from ``list.index`` when *argv* contains no
    ``'ssh'`` word.
    """
    from kitty.options.utils import DELETE_ENV_VAR

    if clone:
        patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
        return

    insert_after = argv.index('ssh')
    for position in range(insert_after, len(argv)):
        word = argv[position]
        if word == '--kitten':
            # The option's value is the following word, so skip past it too.
            insert_after = position + 1
        elif word.startswith('--kitten='):
            insert_after = position

    directives = [
        f'--kitten=env={name}' if value is DELETE_ENV_VAR else f'--kitten=env={name}={value}'
        for name, value in env.items()
    ]
    where = insert_after + 1
    argv[where:where] = directives
|
|
|
|
|
|
def get_ssh_cli() -> tuple[set[str], set[str]]:
    """Split the options reported by ``ssh_options()`` into two sets.

    Returns ``(boolean_ssh_args, other_ssh_args)``. Every entry is a dash
    followed by the option key; options with an empty description go in the
    first set, and options with a description (those taking a value) go in the
    second.
    """
    options = ssh_options()
    boolean_ssh_args = {f'-{key}' for key, description in options.items() if not description}
    other_ssh_args = {f'-{key}' for key, description in options.items() if description}
    return boolean_ssh_args, other_ssh_args
|
|
|
|
|
|
def is_extra_arg(arg: str, extra_args: tuple[str, ...]) -> str:
    """Return the first entry of *extra_args* that *arg* names, else ``''``.

    *arg* names an entry when it equals it exactly or starts with the entry
    followed by ``=``.
    """
    return next(
        (candidate for candidate in extra_args if arg == candidate or arg.startswith(f'{candidate}=')),
        '',
    )
|
|
|
|
|
|
# The ssh flags -N, -n, -f, -G and -T, each spelled with its leading dash.
passthrough_args = {'-N', '-n', '-f', '-G', '-T'}
|
|
|
|
|
|
def set_server_args_in_cmdline(
    server_args: list[str], argv: list[str],
    extra_args: tuple[str, ...] = ('--kitten',),
    allocate_tty: bool = False
) -> None:
    """Replace whatever follows the ssh destination in *argv* with *server_args*.

    *argv* is modified in place. Words up to and including the first ``'ssh'``
    are kept untouched. The options after it are walked to locate the
    destination: the first word that is neither an option nor the value of one,
    or the word following ``--``. Everything after the destination is dropped
    and *server_args* is appended. If no destination is found, *server_args* is
    simply appended to the unchanged command line.

    *extra_args* lists non-ssh options (given as ``--opt=value`` or
    ``--opt value``) that are skipped over together with their value.
    With *allocate_tty*, a ``-t`` word is inserted in front of the destination
    (in front of ``--`` when that is present and not already preceded by
    ``-t``).

    Raises ``KeyError`` for a short option letter that ``get_ssh_cli()`` does
    not know about.
    """
    boolean_ssh_args, other_ssh_args = get_ssh_cli()
    expecting_option_val = False
    ans = list(argv)
    found_ssh = False
    for i, argument in enumerate(argv):
        if not found_ssh:
            found_ssh = argument == 'ssh'
            continue
        if expecting_option_val:
            # This word is the value of the option that preceded it.
            expecting_option_val = False
            continue
        if argument.startswith('-'):
            if argument == '--':
                # Keep `--` and the destination that follows it.
                del ans[i+2:]
                if allocate_tty and ans[i-1] != '-t':
                    ans.insert(i, '-t')
                break
            if extra_args and is_extra_arg(argument, extra_args):
                # `--opt=value` is complete; a bare `--opt` takes the next word.
                expecting_option_val = '=' not in argument
                continue
            # could be a multi-character option
            letters = argument[1:]
            for offset, letter in enumerate(letters):
                flag = f'-{letter}'
                if flag in boolean_ssh_args:
                    continue
                if flag in other_ssh_args:
                    # The value is either glued onto the cluster or is the
                    # next word on the command line.
                    expecting_option_val = not letters[offset+1:]
                    break
                raise KeyError(f'unknown option -- {letter}')
            continue
        # First plain word after the options: the destination.
        del ans[i+1:]
        # NOTE(review): ans[i] is the destination itself, so this test is
        # always true and `-t` is inserted even when the preceding word is
        # already `-t` (the `--` branch above checks ans[i-1] instead).
        # Behaviour kept as is -- confirm which one is intended.
        if allocate_tty and ans[i] != '-t':
            ans.insert(i, '-t')
        break
    argv[:] = ans + server_args
|
|
|
|
|
|
def _split_short_option_cluster(arg: str, boolean_ssh_args: set[str], other_ssh_args: set[str]) -> tuple[str, str]:
    """Find the value-taking option in a short option word such as ``-Ap22``.

    Returns ``(flag, glued_value)``:

    * ``('', '')`` when the word consists only of known boolean flags;
    * ``('-p', '22')`` style when a value-taking option is found, with
      whatever follows its letter in the same word (possibly ``''``);
    * ``(arg, '')`` when an unrecognised letter is met (or the word is a lone
      ``-``), leaving it to the caller to decide how to treat it.
    """
    letters = arg[1:]
    for offset, letter in enumerate(letters):
        flag = f'-{letter}'
        if flag in boolean_ssh_args:
            continue
        # -p and -i are always treated as taking a value, whatever the
        # option table of the local ssh binary says.
        if flag in other_ssh_args or flag in ('-p', '-i'):
            return flag, letters[offset+1:]
        return arg, ''
    return ('', '') if letters else (arg, '')


def get_connection_data(args: list[str], cwd: str = '', extra_args: tuple[str, ...] = ()) -> SSHConnectionData | None:
    """Extract the connection parameters from an ssh command line.

    Words before the ssh executable (basename ``ssh`` or ``ssh.exe``,
    case-insensitive) are ignored. From the words after it the function
    collects the destination, the port (``-p``), the identity file (``-i``) and
    every option listed in *extra_args* (``--opt=value`` or ``--opt value``).

    Short options follow the usual getopt conventions: boolean flags may be
    clustered (``-At``) and a value may be glued to its option (``-p22``,
    ``-oName=value``). An option that is not recognised is assumed to take the
    following word as its value.

    A destination of the form ``ssh://[user@]host[:port]`` is split up; a port
    given with ``-p`` takes precedence over the one in the URL. A relative
    identity file has ``~`` expanded and is then resolved against *cwd* (or the
    current directory when *cwd* is empty).

    Returns ``None`` when no destination is found, otherwise
    ``SSHConnectionData(ssh_executable, destination, port, identity_file,
    extra_args_found)`` with the extra options as a tuple of
    ``(option, value)`` pairs.
    """
    boolean_ssh_args, other_ssh_args = get_ssh_cli()
    port: int | None = None
    expecting_port = expecting_identity = False
    expecting_option_val = False
    expecting_hostname = False
    expecting_extra_val = ''
    host_name = identity_file = found_ssh = ''
    found_extra_args: list[tuple[str, str]] = []

    for arg in args:
        if not found_ssh:
            if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
                found_ssh = arg
            continue
        if expecting_hostname:
            # Only the first word after `--` can be the destination; later
            # words belong to the remote command and must not replace it.
            expecting_hostname = False
            if not host_name:
                host_name = arg
            continue
        if arg.startswith('-') and not expecting_option_val:
            if arg == '--':
                expecting_hostname = True
                continue
            if arg.startswith('--'):
                matching_ex = is_extra_arg(arg, extra_args) if extra_args else ''
                if matching_ex:
                    if '=' in arg:
                        found_extra_args.append((matching_ex, arg.partition('=')[-1]))
                        continue
                    expecting_extra_val = matching_ex
                expecting_option_val = True
                continue

            flag, glued_value = _split_short_option_cluster(arg, boolean_ssh_args, other_ssh_args)
            if not flag:
                # Nothing but boolean flags in this word.
                continue
            if glued_value:
                if flag == '-p':
                    if glued_value.isdigit():
                        with suppress(ValueError):
                            port = int(glued_value)
                elif flag == '-i':
                    identity_file = glued_value
                continue
            # The value, if any, is the next word.
            expecting_port = flag == '-p'
            expecting_identity = flag == '-i'
            expecting_option_val = True
            continue

        if expecting_option_val:
            if expecting_port:
                with suppress(ValueError):
                    port = int(arg)
            elif expecting_identity:
                identity_file = arg
            elif expecting_extra_val:
                found_extra_args.append((expecting_extra_val, arg))
            # Clear every pending expectation so that the value of a later,
            # unrelated option is not mistaken for a port or identity file.
            expecting_port = expecting_identity = False
            expecting_extra_val = ''
            expecting_option_val = False
            continue

        if not host_name:
            host_name = arg

    if not host_name:
        return None
    if host_name.startswith('ssh://'):
        from urllib.parse import urlparse
        try:
            purl = urlparse(host_name)
        except ValueError:
            # Malformed URL: keep the destination exactly as it was typed.
            purl = None
        if purl is not None:
            if purl.hostname:
                host_name = purl.hostname
            if purl.username:
                host_name = f'{purl.username}@{host_name}'
            if port is None:
                # Reading .port raises ValueError for an invalid port.
                with suppress(ValueError):
                    port = purl.port or None
    if identity_file:
        if not os.path.isabs(identity_file):
            identity_file = os.path.expanduser(identity_file)
        if not os.path.isabs(identity_file):
            identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))

    return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))
|