Files
kitty-mirror/kitty/file_transmission.py
2025-05-30 10:06:38 +05:30

1253 lines
49 KiB
Python

#!/usr/bin/env python
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
import errno
import inspect
import io
import json
import os
import re
import stat
import tempfile
from base64 import b85decode
from collections import defaultdict, deque
from collections.abc import Callable, Iterable, Iterator
from contextlib import suppress
from dataclasses import Field, dataclass, field, fields
from enum import Enum, auto
from functools import partial
from gettext import gettext as _
from itertools import count
from time import time_ns
from typing import IO, Any, DefaultDict, Deque, Union
from kittens.transfer.utils import IdentityCompressor, ZlibCompressor, abspath, expand_home, home_path
from kitty.fast_data_types import ESC_OSC, FILE_TRANSFER_CODE, AES256GCMDecrypt, add_timer, base64_decode, base64_encode, get_boss, get_options, monotonic
from kitty.types import run_once
from kitty.typing_compat import ReadableBuffer, WriteableBuffer
from .utils import log_error
# Inactivity timeout, in minutes, after which an active send/receive is
# considered expired (see the is_expired properties and prune_expired()).
EXPIRE_TIME = 10
# Upper bounds on the number of concurrently tracked transfers per direction.
MAX_ACTIVE_SENDS = 10
MAX_ACTIVE_RECEIVES = MAX_ACTIVE_SENDS
# Text of the OSC code, emitted first when serializing with prefix_with_osc_code=True.
ftc_prefix = str(FILE_TRANSFER_CODE)
@run_once
def safe_string_pat() -> 're.Pattern[str]':
    '''
    Return the compiled pattern matching every character that is NOT allowed
    in an unencoded protocol string value (run_once presumably memoizes it).
    '''
    allowed = '0-9a-zA-Z_:./@-'
    return re.compile('[^' + allowed + ']')
def safe_string(x: str) -> str:
    '''Return *x* with every character outside ``[0-9a-zA-Z_:./@-]`` removed.'''
    pattern = safe_string_pat()
    cleaned = pattern.sub('', x)
    return cleaned
def as_unicode(x: str | bytes) -> str:
if isinstance(x, bytes):
x = x.decode('ascii')
return x
def encode_bypass(request_id: str, bypass: str) -> str:
    '''
    Build the ``sha256:<hexdigest>`` proof for a confirmation-bypass password.

    The digest covers ``"<request_id>;<bypass>"`` encoded as UTF-8 (with
    unencodable characters replaced), so the proof is bound to one request id.
    '''
    from hashlib import sha256
    digest = sha256(f'{request_id};{bypass}'.encode('utf-8', 'replace')).hexdigest()
    return f'sha256:{digest}'
def split_for_transfer(
    data: bytes | bytearray | memoryview,
    session_id: str = '', file_id: str = '',
    mark_last: bool = False,
    chunk_size: int = 4096
) -> Iterator['FileTransmissionCommand']:
    '''
    Lazily split *data* into ``data`` commands of at most *chunk_size* bytes.

    bytes/bytearray input is wrapped in a memoryview so that every chunk is a
    zero-copy slice. When *mark_last* is true the final chunk is sent with
    ``Action.end_data`` instead of ``Action.data``. Empty *data* yields nothing.
    '''
    remaining = memoryview(data) if isinstance(data, (bytes, bytearray)) else data
    while len(remaining):
        is_final = mark_last and len(remaining) <= chunk_size
        yield FileTransmissionCommand(
            action=Action.end_data if is_final else Action.data,
            id=session_id, file_id=file_id, data=remaining[:chunk_size],
        )
        remaining = remaining[chunk_size:]
def iter_file_metadata(file_specs: Iterable[tuple[str, str]]) -> Iterator[Union['FileTransmissionCommand', 'TransmissionError']]:
    '''
    Walk the requested file specs and produce ``file`` metadata commands.

    :param file_specs: ``(spec_id, path)`` pairs. Relative paths are resolved
        with expand_home() and then abspath(use_home=True).
    :return: a generator that first yields a TransmissionError for every spec
        that cannot be stat'ed, is not readable or has an unsupported file type.
        After all specs are processed it yields one FileTransmissionCommand
        (action=file) per collected entry. Directories are listed recursively
        without following symlinks. Every entry gets a sequential ``status``
        that its children reference through ``parent``.

    Entries are grouped by (st_dev, st_ino). When the first entry of a group is
    a regular file, the other regular-file entries of that group are emitted as
    FileType.link with ``data`` set to the first entry's status. A symlink
    whose resolved target was also collected carries the target's status in
    ``data``.
    '''
    file_map: DefaultDict[tuple[int, int], list[FileTransmissionCommand]] = defaultdict(list)
    # Source of the sequential ids stored in each entry's ``status`` field.
    counter = count()

    def skey(sr: os.stat_result) -> tuple[int, int]:
        # Inode identity, used to group hard links / duplicate paths.
        return sr.st_dev, sr.st_ino

    def make_ftc(path: str, spec_id: str, sr: os.stat_result | None = None, parent: str = '') -> FileTransmissionCommand:
        # Build the metadata command for *path* and register it in file_map.
        # Raises ValueError for unsupported types; OSError from os.stat propagates.
        if sr is None:
            sr = os.stat(path, follow_symlinks=False)
        if stat.S_ISLNK(sr.st_mode):
            ftype = FileType.symlink
        elif stat.S_ISDIR(sr.st_mode):
            ftype = FileType.directory
        elif stat.S_ISREG(sr.st_mode):
            ftype = FileType.regular
        else:
            raise ValueError('Not an appropriate file type')
        ans = FileTransmissionCommand(
            action=Action.file, file_id=spec_id, mtime=sr.st_mtime_ns, permissions=stat.S_IMODE(sr.st_mode),
            name=path, status=str(next(counter)), size=sr.st_size, ftype=ftype, parent=parent
        )
        file_map[skey(sr)].append(ans)
        return ans

    def add_dir(ftc: FileTransmissionCommand) -> None:
        # Recursively register the children of directory *ftc*. Directories
        # that cannot be listed and children that cannot be stat'ed or have
        # an unsupported type are skipped silently.
        # NOTE: spec_id is read through the closure from the spec loop below.
        try:
            lr = os.listdir(ftc.name)
        except OSError:
            return
        for entry in lr:
            try:
                child_ftc = make_ftc(os.path.join(ftc.name, entry), spec_id, parent=ftc.status)
            except (ValueError, OSError):
                continue
            if child_ftc.ftype is FileType.directory:
                add_dir(child_ftc)

    for spec_id, spec in file_specs:
        path = spec
        if not os.path.isabs(path):
            path = expand_home(path)
        if not os.path.isabs(path):
            path = abspath(path, use_home=True)
        try:
            sr = os.stat(path, follow_symlinks=False)
            read_ok = os.access(path, os.R_OK, follow_symlinks=False)
        except OSError as err:
            # Report the errno name (e.g. ENOENT) when one is available.
            errname = errno.errorcode.get(err.errno, 'EFAIL') if err.errno is not None else 'EFAIL'
            yield TransmissionError(file_id=spec_id, code=errname, msg='Failed to read spec')
            continue
        if not read_ok:
            yield TransmissionError(file_id=spec_id, code='EPERM', msg='No permission to read spec')
            continue
        try:
            ftc = make_ftc(path, spec_id, sr)
        except ValueError:
            yield TransmissionError(file_id=spec_id, code='EINVAL', msg='Not a valid filetype')
            continue
        if ftc.ftype is FileType.directory:
            add_dir(ftc)

    def resolve_symlink(ftc: FileTransmissionCommand) -> FileTransmissionCommand:
        # For a symlink whose fully resolved target is one of the collected
        # entries, store that entry's status in ``data``; otherwise leave it.
        if ftc.ftype is FileType.symlink:
            try:
                dest = os.path.realpath(ftc.name)
            except OSError:
                pass
            else:
                try:
                    s = os.stat(dest, follow_symlinks=False)
                except OSError:
                    pass
                else:
                    tgt = file_map.get(skey(s))
                    if tgt is not None:
                        ftc.data = tgt[0].status.encode('utf-8')
        return ftc

    for fkey, cmds in file_map.items():
        base = cmds[0]
        yield resolve_symlink(base)
        # NOTE(review): when the first entry of an inode group is not a regular
        # file, the other entries of that group are never yielded.
        if len(cmds) > 1 and base.ftype is FileType.regular:
            for q in cmds:
                if q is not base and q.ftype is FileType.regular:
                    q.ftype = FileType.link
                    q.data = base.status.encode('utf-8', 'replace')
                    yield q
class NameReprEnum(Enum):
    '''Enum base class whose repr shows only ``<ClassName.member>``, omitting the value.'''

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return '<%s.%s>' % (cls_name, self.name)
class Action(NameReprEnum):
    '''
    Verb of a file transmission command.

    Serialized on the wire by member name (the ``ac`` key of
    FileTransmissionCommand).
    '''
    send = auto()  # remote sends files to kitty; routed to handle_receive_cmd()
    file = auto()  # per-file metadata / file request within a transfer
    data = auto()  # a chunk of payload data
    end_data = auto()  # the final chunk of payload data
    receive = auto()  # remote receives files from kitty; routed to handle_send_cmd()
    invalid = auto()  # default value; deserialize() rejects commands still using it
    cancel = auto()  # abort a transfer
    status = auto()  # status/error report (see TransmissionError.as_ftc)
    finish = auto()  # end of a transfer
class Compression(NameReprEnum):
    '''Compression applied to payload data; selects the (de)compressor used.'''
    zlib = auto()
    none = auto()
class FileType(NameReprEnum):
    '''Kind of filesystem entry described by a ``file`` command.'''
    regular = auto()
    directory = auto()
    symlink = auto()
    # a hard link to another entry of the same transfer
    link = auto()

    @property
    def short_text(self) -> str:
        '''Three letter abbreviation of this type.'''
        return {'regular': 'fil', 'directory': 'dir', 'symlink': 'sym', 'link': 'lnk'}[self.name]

    @property
    def color(self) -> str:
        '''Color name associated with this type (presumably for display).'''
        return {'regular': 'yellow', 'directory': 'magenta', 'symlink': 'blue', 'link': 'green'}[self.name]
class TransmissionType(NameReprEnum):
    '''How file contents are transferred: verbatim, or as an rsync-style delta.'''
    simple = auto()
    rsync = auto()
class ErrorCode(Enum):
    '''Status codes reported in ``status`` commands (values 1..8 in declaration order).'''
    OK = auto()
    STARTED = auto()
    CANCELED = auto()
    PROGRESS = auto()
    EINVAL = auto()
    EPERM = auto()
    EISDIR = auto()
    ENOENT = auto()
class TransmissionError(Exception):
    '''
    A protocol-level failure that can be reported back to the remote side.

    :param code: an ErrorCode member or a raw code string (e.g. an errno name)
    :param msg: human readable description, also the exception message
    :param transmit: stored on the instance as-is
    :param file_id: id of the file the error relates to, if any
    :param name: file name to report, if any
    :param size: size to report (defaults to -1)
    :param ttype: transmission type to report
    '''

    def __init__(
        self, code: ErrorCode | str = ErrorCode.EINVAL,
        msg: str = 'Generic error',
        transmit: bool = True,
        file_id: str = '',
        name: str = '',
        size: int = -1,
        ttype: TransmissionType = TransmissionType.simple,
    ) -> None:
        super().__init__(msg)
        self.code = code
        self.human_msg = msg
        self.transmit = transmit
        self.file_id, self.name, self.size, self.ttype = file_id, name, size, ttype

    def as_ftc(self, request_id: str) -> 'FileTransmissionCommand':
        '''Convert to a ``status`` command whose status text is ``CODE`` or ``CODE:message``.'''
        code_name = self.code.name if not isinstance(self.code, str) else self.code
        status_text = f'{code_name}:{self.human_msg}' if self.human_msg else code_name
        return FileTransmissionCommand(
            action=Action.status, id=request_id, file_id=self.file_id, status=status_text,
            name=self.name, size=self.size, ttype=self.ttype,
        )
@run_once
def name_to_serialized_map() -> dict[str, str]:
    '''Map each FileTransmissionCommand field name to its short wire key (``sname``).'''
    return {f.name: f.metadata.get('sname', f.name) for f in fields(FileTransmissionCommand)}
@run_once
def serialized_to_field_map() -> dict[bytes | memoryview, 'Field[Any]']:
    '''Map each ASCII-encoded wire key (``sname``) back to its dataclass Field.'''
    return {f.metadata.get('sname', f.name).encode('ascii'): f for f in fields(FileTransmissionCommand)}
@dataclass
class FileTransmissionCommand:
    '''
    A single command of the file transmission protocol.

    Each non-default field is serialized as ``key=value`` with ``;`` between
    fields, where ``key`` is the field's ``sname`` metadata (falling back to the
    field name). Enum fields are sent by member name. ``data`` and fields
    marked ``base64`` are base64 encoded. Other strings are filtered through
    safe_string().
    '''

    action: Action = field(default=Action.invalid, metadata={'sname': 'ac'})
    compression: Compression = field(default=Compression.none, metadata={'sname': 'zip'})
    ftype: FileType = field(default=FileType.regular, metadata={'sname': 'ft'})
    ttype: TransmissionType = field(default=TransmissionType.simple, metadata={'sname': 'tt'})
    id: str = ''
    file_id: str = field(default='', metadata={'sname': 'fid'})
    bypass: str = field(default='', metadata={'base64': True, 'sname': 'pw'})
    quiet: int = field(default=0, metadata={'sname': 'q'})
    mtime: int = field(default=-1, metadata={'sname': 'mod'})
    permissions: int = field(default=-1, metadata={'sname': 'prm'})
    size: int = field(default=-1, metadata={'sname': 'sz'})
    name: str = field(default='', metadata={'base64': True, 'sname': 'n'})
    status: str = field(default='', metadata={'base64': True, 'sname': 'st'})
    parent: str = field(default='', metadata={'sname': 'pr'})
    data: bytes | memoryview = field(default=b'', repr=False, metadata={'sname': 'd'})

    def __repr__(self) -> str:
        # Only non-default fields are shown; data is summarized by its length.
        ans = []
        for k in fields(self):
            if not k.repr:
                continue
            val = getattr(self, k.name)
            if val != k.default:
                ans.append(f'{k.name}={val!r}')
        if self.data:
            ans.append(f'data={len(self.data)} bytes')
        return 'FTC(' + ', '.join(ans) + ')'

    def asdict(self, keep_defaults: bool = False) -> dict[str, str | int | bytes]:
        '''Return the fields as a dict (enums by name), omitting defaults unless *keep_defaults*.'''
        ans = {}
        for k in fields(self):
            val = getattr(self, k.name)
            if not keep_defaults and val == k.default:
                continue
            if inspect.isclass(k.type) and issubclass(k.type, Enum):
                val = val.name
            ans[k.name] = val
        return ans

    def get_serialized_fields(self, prefix_with_osc_code: bool = False) -> Iterator[str | bytes]:
        '''
        Yield the pieces of the wire representation, to be concatenated.

        :param prefix_with_osc_code: emit the OSC code (ftc_prefix) first
        :raises KeyError: for a non-default field of an unsupported type
        '''
        nts = name_to_serialized_map()
        found = False
        if prefix_with_osc_code:
            yield ftc_prefix
            found = True
        for k in fields(self):
            name = k.name
            val = getattr(self, name)
            if val == k.default:
                continue
            if found:
                yield ';'
            else:
                found = True
            yield nts[name]
            yield '='
            if inspect.isclass(k.type) and issubclass(k.type, Enum):
                yield val.name
            elif k.type == bytes | memoryview:
                yield base64_encode(val)
            elif k.type is str:
                if k.metadata.get('base64'):
                    yield base64_encode(val.encode('utf-8'))
                else:
                    yield safe_string(val)
            elif k.type is int:
                yield str(val)
            else:
                raise KeyError(f'Field of unknown type: {k.name}')

    def serialize(self, prefix_with_osc_code: bool = False) -> str:
        '''Return the complete wire representation as a string.'''
        return ''.join(map(as_unicode, self.get_serialized_fields(prefix_with_osc_code)))

    @classmethod
    def deserialize(cls, data: str | bytes | memoryview) -> 'FileTransmissionCommand':
        '''
        Parse a wire representation into a new instance of *cls*.

        Unknown keys are ignored.

        :raises ValueError: when no valid action is present
        '''
        # Instantiate cls (not the hard-coded base class) so subclasses round-trip.
        ans = cls()
        fmap = serialized_to_field_map()
        from kittens.transfer.rsync import parse_ftc

        def handle_item(key: memoryview, val: memoryview) -> None:
            # Named fld to avoid shadowing dataclasses.field.
            fld = fmap.get(key)
            if fld is None:
                return
            if inspect.isclass(fld.type) and issubclass(fld.type, Enum):
                setattr(ans, fld.name, fld.type[str(val, "utf-8")])
            elif fld.type == bytes | memoryview:
                setattr(ans, fld.name, base64_decode(val))
            elif fld.type is int:
                setattr(ans, fld.name, int(val))
            elif fld.type is str:
                if fld.metadata.get('base64'):
                    sval = base64_decode(val).decode('utf-8')
                else:
                    sval = safe_string(str(val, "utf-8"))
                setattr(ans, fld.name, sval)

        parse_ftc(data, handle_item)
        if ans.action is Action.invalid:
            raise ValueError('No valid action specified in file transmission command')
        return ans
class IdentityDecompressor:
def __call__(self, data: bytes | memoryview, is_last: bool = False) -> bytes:
return bytes(data)
class ZlibDecompressor:
def __init__(self) -> None:
import zlib
self.d = zlib.decompressobj(wbits=0)
def __call__(self, data: bytes | memoryview, is_last: bool = False) -> bytes:
ans = self.d.decompress(data)
if is_last:
ans += self.d.flush()
return ans
class PatchFile:
    '''
    Receiver-side helper for an rsync-style (delta) transfer of one file.

    It first produces the signature of the existing file at *path* through
    :meth:`next_signature_block`. It then applies incoming delta data with
    :meth:`write` into a temporary file created next to *path*. On a
    successful :meth:`close` that temporary file replaces *path*.
    '''

    def __init__(self, path: str, expected_size: int):
        from kittens.transfer.rsync import Patcher
        self.patcher = Patcher(expected_size)
        self.block_buffer = memoryview(bytearray(self.patcher.block_size))
        self.path = path
        self.signature_done = False
        self.src_file: io.BufferedReader | None = None
        self._dest_file: IO[bytes] | None = None
        self.closed = False

    @property
    def dest_file(self) -> IO[bytes]:
        '''
        The temporary output file, created lazily in the directory of the
        (symlink-resolved) destination, which keeps the final os.replace() on
        one filesystem.
        '''
        if self._dest_file is None:
            self._dest_file = tempfile.NamedTemporaryFile(mode='wb', dir=os.path.dirname(os.path.abspath(os.path.realpath(self.path))), delete=False)
        return self._dest_file

    def close(self) -> None:
        '''
        Finish the patch and atomically replace the source file with the output.

        The replace happens only when a source file was opened. Any temporary
        output file that does not end up replacing the source is deleted. This
        covers the case where finish_delta_data() raises, so a failed patch
        leaves the original file untouched. The source file is always closed.
        '''
        if self.closed:
            return
        self.closed = True
        p = self.patcher
        del self.block_buffer, self.patcher
        try:
            if self._dest_file is not None and not self._dest_file.closed:
                self._dest_file.close()
            p.finish_delta_data()
            if self.src_file is not None:
                # Accessing dest_file creates an (empty) output file if the delta
                # produced no data; close it so its descriptor is not leaked.
                df = self.dest_file
                if not df.closed:
                    df.close()
                os.replace(df.name, self.src_file.name)
            elif self._dest_file is not None:
                # No source was ever opened, so nothing gets replaced: do not
                # leave the temporary file behind.
                with suppress(OSError):
                    os.unlink(self._dest_file.name)
        except BaseException:
            if self._dest_file is not None:
                with suppress(OSError):
                    os.unlink(self._dest_file.name)
            raise
        finally:
            if self.src_file is not None and not self.src_file.closed:
                self.src_file.close()

    def tell(self) -> int:
        '''Bytes written to the output so far, or the size of *path* once closed.'''
        df = self.dest_file
        if df.closed:
            return os.path.getsize(self.path)
        return df.tell()

    def read_from_src(self, pos: int, b: WriteableBuffer) -> int:
        '''Read from the source file at *pos* into *b*; return the byte count.'''
        assert self.src_file is not None
        self.src_file.seek(pos, os.SEEK_SET)
        return self.src_file.readinto(b)

    def write_to_dest(self, b: ReadableBuffer) -> None:
        self.dest_file.write(b)

    def write(self, b: bytes) -> None:
        '''Apply a chunk of delta data, reading from the source and writing the output.'''
        self.patcher.apply_delta_data(b, self.read_from_src, self.write_to_dest)

    def next_signature_block(self, buf: memoryview) -> int:
        '''
        Write the next piece of the source file's signature into *buf*.

        The first call opens the source and emits the signature header. Later
        calls sign one block each. Returns 0 once the signature is complete.
        '''
        if self.signature_done:
            return 0
        if self.src_file is None:
            self.src_file = open(self.path, 'rb')
            return self.patcher.signature_header(buf)
        n = self.src_file.readinto(self.block_buffer)
        if n > 0:
            n = self.patcher.sign_block(self.block_buffer[:n], buf)
        else:
            # rewind so the source can be read again while applying the delta
            self.src_file.seek(0, os.SEEK_SET)
            self.signature_done = True
        return n
class DestFile:
    '''
    One entry (regular file, directory, symlink or hard link) being created on
    this machine during a receive.

    Built from the remote's ``file`` command. File contents, or link targets,
    arrive through :meth:`write_data`. Relative names are resolved against the
    home directory.
    '''

    def __init__(self, ftc: FileTransmissionCommand) -> None:
        self.name = ftc.name
        if not os.path.isabs(self.name):
            self.name = expand_home(self.name)
        if not os.path.isabs(self.name):
            self.name = abspath(self.name, use_home=True)
        try:
            self.existing_stat: os.stat_result | None = os.stat(self.name, follow_symlinks=False)
        except OSError:
            self.existing_stat = None
        # An existing symlink or multiply hard-linked file is unlinked before
        # writing, so the new contents go to a fresh inode rather than through it.
        self.needs_unlink = self.existing_stat is not None and (self.existing_stat.st_nlink > 1 or stat.S_ISLNK(self.existing_stat.st_mode))
        self.mtime = ftc.mtime
        self.file_id = ftc.file_id
        self.permissions = ftc.permissions
        # -1 (the dataclass default) means the sender did not specify permissions
        if self.permissions != FileTransmissionCommand.permissions:
            self.permissions = stat.S_IMODE(self.permissions)
        self.ftype = ftc.ftype
        self.ttype = ftc.ttype
        self.link_target = b''
        self.needs_data_sent = self.ttype is not TransmissionType.simple
        self.decompressor: ZlibDecompressor | IdentityDecompressor = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor()
        # directories carry no data, so they start out closed
        self.closed = self.ftype is FileType.directory
        self.actual_file: PatchFile | IO[bytes] | None = None
        self.failed = False
        self.bytes_written = 0

    def signature_iterator(self) -> PatchFile:
        '''Switch this entry to an rsync transfer and return the PatchFile producing the signature.'''
        self.actual_file = PatchFile(self.name, self.existing_stat.st_size if self.existing_stat is not None else 0)
        return self.actual_file

    def __repr__(self) -> str:
        return f'DestFile(name={self.name}, file_id={self.file_id}, actual_file={self.actual_file})'

    def close(self) -> None:
        '''Close the underlying output (if any); further writes are rejected.'''
        if not self.closed:
            self.closed = True
            if self.actual_file is not None:
                self.actual_file.close()
                self.actual_file = None

    def make_parent_dirs(self) -> str:
        '''Create the parent directories of this entry and return the parent path.'''
        d = os.path.dirname(self.name)
        if d:
            os.makedirs(d, exist_ok=True)
        return d

    def apply_metadata(self, is_symlink: bool = False) -> None:
        '''
        Apply the sender's permissions and mtime, when they were given.

        With *is_symlink* the link itself is modified (follow_symlinks=False).
        Platforms that do not support that raise NotImplementedError, which is
        ignored here.
        '''
        if self.permissions != FileTransmissionCommand.permissions:
            if is_symlink:
                with suppress(NotImplementedError):
                    os.chmod(self.name, self.permissions, follow_symlinks=False)
            else:
                os.chmod(self.name, self.permissions)
        if self.mtime != FileTransmissionCommand.mtime:
            if is_symlink:
                with suppress(NotImplementedError):
                    os.utime(self.name, ns=(self.mtime, self.mtime), follow_symlinks=False)
            else:
                os.utime(self.name, ns=(self.mtime, self.mtime))

    def unlink_existing_if_needed(self, force: bool = False) -> None:
        '''Remove an existing entry at the destination if required (or *force*d).'''
        if force or self.needs_unlink:
            with suppress(FileNotFoundError):
                os.unlink(self.name)
            self.existing_stat = None
            self.needs_unlink = False

    def write_data(self, all_files: dict[str, 'DestFile'], data: bytes | memoryview, is_last: bool) -> None:
        '''
        Feed the next chunk of data for this entry.

        For symlinks and hard links the data accumulates into the link target,
        which is resolved and created when *is_last* is true. Target forms:

        - ``fid:<file_id>``: another entry of this transfer (made relative for symlinks)
        - ``fid_abs:<file_id>``: the absolute path of another entry
        - ``path:<path>``: a literal path (relative hard-link paths are joined
          to the parent directory)

        For regular files the data is decompressed and written, to the
        PatchFile set up by signature_iterator() if there is one. On the last
        chunk the file is closed and its metadata applied.

        :raises TransmissionError: for directories, closed entries or unknown link target forms
        '''
        if self.ftype is FileType.directory:
            raise TransmissionError(code=ErrorCode.EISDIR, file_id=self.file_id, msg='Cannot write data to a directory entry')
        if self.closed:
            raise TransmissionError(file_id=self.file_id, msg='Cannot write to a closed file')
        if self.ftype in (FileType.symlink, FileType.link):
            self.link_target += data
            self.bytes_written += len(data)
            if is_last:
                lt = self.link_target.decode('utf-8', 'replace')
                base = self.make_parent_dirs()
                self.unlink_existing_if_needed(force=True)
                if lt.startswith('fid:'):
                    lt = all_files[lt[4:]].name
                    if self.ftype is FileType.symlink:
                        lt = os.path.relpath(lt, os.path.dirname(self.name))
                elif lt.startswith('fid_abs:'):
                    lt = all_files[lt[8:]].name
                elif lt.startswith('path:'):
                    lt = lt[5:]
                    if not os.path.isabs(lt) and self.ftype is FileType.link:
                        lt = os.path.join(base, lt)
                    lt = lt.replace('/', os.sep)
                else:
                    raise TransmissionError(msg='Unknown link target type', file_id=self.file_id)
                if self.ftype is FileType.symlink:
                    os.symlink(lt, self.name)
                else:
                    os.link(lt, self.name)
                self.close()
                self.apply_metadata(is_symlink=True)
        elif self.ftype is FileType.regular:
            decompressed = self.decompressor(data, is_last=is_last)
            if self.actual_file is None:
                self.make_parent_dirs()
                self.unlink_existing_if_needed()
                flags = os.O_RDWR | os.O_CREAT | os.O_TRUNC | getattr(os, 'O_CLOEXEC', 0) | getattr(os, 'O_BINARY', 0)
                # When the sender gave no permissions, self.permissions is -1.
                # That is not a valid creation mode: as an unsigned mode_t it
                # has every bit set, including setuid/setgid/sticky. Fall back
                # to the conventional 0o666 (reduced by the umask).
                create_mode = 0o666 if self.permissions == FileTransmissionCommand.permissions else self.permissions
                self.actual_file = open(os.open(self.name, flags, create_mode), mode='r+b', closefd=True)
            af = self.actual_file
            if decompressed or is_last:
                af.write(decompressed)
            self.bytes_written = af.tell()
            if is_last:
                self.close()
                self.apply_metadata()
def check_bypass(password: str, request_id: str, bypass_data: str) -> bool:
    '''
    Return True if *bypass_data* proves knowledge of the configured bypass *password*.

    *bypass_data* has the form ``protocol:payload``:

    - ``kitty-1``: JSON carrying an AES-256-GCM encrypted
      ``"<timestamp_ns>:<request_id>;<password>"``, decrypted with a secret
      derived from the boss's encryption key. The timestamp must be within
      5 minutes of now.
    - ``sha256``: the value produced by encode_bypass(request_id, password).

    An empty *password* never authorizes anything. Secrets are compared in
    constant time. Malformed data is logged and rejected.
    '''
    import hmac
    protocol, _, bypass_data = bypass_data.partition(':')
    if protocol == 'kitty-1':
        if not password:
            return False
        try:
            pcmd = json.loads(bypass_data)
            pubkey = pcmd.get('pubkey', '')
            if not pubkey:
                return False
            ekey = get_boss().encryption_key
            d = AES256GCMDecrypt(ekey.derive_secret(b85decode(pubkey)), b85decode(pcmd['iv']), b85decode(pcmd['tag']))
            data = d.add_data_to_be_decrypted(b85decode(pcmd['encrypted']), True)
            timestamp, _, payload = data.decode('utf-8').partition(':')
            delta = time_ns() - int(timestamp)
            if abs(delta) > 5 * 60 * 1e9:
                return False
            expected = f'{request_id};{password}'
            return hmac.compare_digest(payload.encode('utf-8'), expected.encode('utf-8'))
        except Exception as err:
            log_error(f'Invalid file transmission bypass data received: {err}')
            return False
    elif protocol == 'sha256':
        if not password:
            return False
        expected_digest = encode_bypass(request_id, password)
        return hmac.compare_digest(expected_digest.encode('utf-8', 'replace'), bypass_data.encode('utf-8', 'replace'))
    else:
        log_error(f'Invalid file transmission bypass data received with protocol: {protocol}')
        return False
class ActiveReceive:
    '''
    State for one transfer in which the remote side sends files to kitty.

    Holds the DestFile objects keyed by file_id, acceptance/bypass state,
    reporting flags derived from the quiet level and queued rsync signature
    work.
    '''
    id: str
    files: dict[str, DestFile]
    accepted: bool = False

    def __init__(self, request_id: str, quiet: int, bypass: str) -> None:
        self.id = request_id
        # None when no bypass was supplied, else the result of check_bypass()
        self.bypass_ok: bool | None = None
        if bypass:
            self.bypass_ok = check_bypass(get_options().file_transfer_confirmation_bypass, request_id, bypass)
        self.files = {}
        self.last_activity_at = monotonic()
        # quiet 0: acknowledgements and errors; 1: errors only; 2+: neither
        self.send_acknowledgements, self.send_errors = quiet < 1, quiet < 2
        self.pending_files_to_transmit_signature_of: Deque[tuple[PatchFile, str]] = deque()
        self.signature_pending_chunks: Deque[FileTransmissionCommand] = deque()

    @property
    def is_expired(self) -> bool:
        '''True once there has been no activity for more than EXPIRE_TIME minutes.'''
        idle = monotonic() - self.last_activity_at
        return idle > 60 * EXPIRE_TIME

    def close(self) -> None:
        '''Close every destination file and forget them.'''
        for dest in self.files.values():
            dest.close()
        self.files = {}

    def cancel(self) -> None:
        self.close()

    def start_file(self, ftc: FileTransmissionCommand) -> DestFile:
        '''
        Register a new destination entry for a ``file`` command.

        :raises TransmissionError: if the file_id is already in use
        '''
        self.last_activity_at = monotonic()
        fid = ftc.file_id
        if fid in self.files:
            raise TransmissionError(
                msg=f'The file_id {fid} already exists',
                file_id=fid,
            )
        dest = DestFile(ftc)
        self.files[fid] = dest
        return dest

    def add_data(self, ftc: FileTransmissionCommand) -> DestFile:
        '''
        Write a data/end_data chunk to its destination file.

        Once writing has failed for a file, further chunks for it are silently
        dropped. The first failure marks the file failed, closes it and
        re-raises.

        :raises TransmissionError: for an unknown file_id
        '''
        self.last_activity_at = monotonic()
        dest = self.files.get(ftc.file_id)
        if dest is None:
            raise TransmissionError(file_id=ftc.file_id, msg='Cannot write to a file without first starting it')
        if not dest.failed:
            try:
                dest.write_data(self.files, ftc.data, ftc.action is Action.end_data)
            except Exception:
                dest.failed = True
                with suppress(Exception):
                    dest.close()
                raise
        return dest

    def commit(self, send_os_error: Callable[[OSError, str, 'ActiveReceive', str], None]) -> None:
        '''Apply directory metadata, deepest paths (longest names) first.'''
        dirs = [d for d in self.files.values() if d.ftype is FileType.directory]
        dirs.sort(key=lambda d: len(d.name), reverse=True)
        for d in dirs:
            # Failures are ignored: an OK was already sent when the directory was created.
            with suppress(OSError):
                d.apply_metadata()
class SourceFile:
    '''
    A local file being sent to the remote side.

    A symlink is sent as its target path. A regular file is read in chunks:
    either verbatim or, for rsync transfers, as a delta against the signature
    the receiver sends (see ActiveSend.add_signature_data). Directories and
    special files (FIFOs, sockets, devices) are rejected.
    '''

    def __init__(self, ftc: FileTransmissionCommand):
        self.file_id = ftc.file_id
        self.path = ftc.name
        self.ttype = ftc.ttype
        # rsync transfers cannot start until the receiver's signature has arrived
        self.waiting_for_signature = self.ttype is TransmissionType.rsync
        self.transmitted = False
        self.stat = os.stat(self.path, follow_symlinks=False)
        if stat.S_ISDIR(self.stat.st_mode):
            raise TransmissionError(ErrorCode.EINVAL, msg='Cannot send a directory', file_id=self.file_id)
        is_symlink = stat.S_ISLNK(self.stat.st_mode)
        if not is_symlink and not stat.S_ISREG(self.stat.st_mode):
            # The path comes from the remote. Opening a FIFO would block until a
            # writer appears, and devices have no meaningful size. Only the
            # types listed by iter_file_metadata() are sendable.
            raise TransmissionError(ErrorCode.EINVAL, msg='Cannot send a special file', file_id=self.file_id)
        self.compressor: ZlibCompressor | IdentityCompressor = IdentityCompressor()
        self.target = b''
        self.open_file: io.BufferedReader | None = None
        if is_symlink:
            self.target = os.readlink(self.path).encode('utf-8')
        else:
            self.open_file = open(self.path, 'rb')
        try:
            if ftc.compression is Compression.zlib:
                self.compressor = ZlibCompressor()
            from kittens.transfer import rsync
            self.differ = rsync.Differ() if self.waiting_for_signature else None
        except BaseException:
            # do not leak the descriptor if the remaining setup fails
            if self.open_file is not None:
                self.open_file.close()
                self.open_file = None
            raise
        self.buf = bytearray()
        self.write_pos = 0

    def write(self, b: ReadableBuffer) -> None:
        '''Sink for the rsync differ: append *b* to the output buffer at write_pos.'''
        self.buf[self.write_pos:self.write_pos+len(b)] = b
        self.write_pos += len(b)

    @property
    def ready_to_transmit(self) -> bool:
        return not self.transmitted and not self.waiting_for_signature

    def close(self) -> None:
        if self.open_file is not None:
            self.open_file.close()
            self.open_file = None
        self.differ = None

    def next_chunk(self, sz: int = 1024 * 1024) -> tuple[bytes, int]:
        '''
        Produce the next (compressed chunk, uncompressed size) pair.

        Sets ``transmitted`` and closes the file once everything has been
        produced. At that point the compressor is also flushed into the
        returned chunk.
        '''
        data: bytes | memoryview = b''
        if self.target:
            self.transmitted = True
            data = self.target
        else:
            if self.open_file is None:
                self.transmitted = True
                data = b''
            else:
                if self.differ is None:
                    data = self.open_file.read(sz)
                    if not data or self.open_file.tell() >= self.stat.st_size:
                        self.transmitted = True
                else:
                    self.write_pos = 0
                    has_more = self.differ.next_op(self.open_file.readinto, self.write)
                    data = memoryview(self.buf)[:self.write_pos]
                    if not has_more:
                        self.transmitted = True
        uncompressed_sz = len(data)
        cchunk = self.compressor.compress(data)
        if self.transmitted and not isinstance(self.compressor, IdentityCompressor):
            cchunk += self.compressor.flush()
        if self.transmitted:
            self.close()
        return cchunk, uncompressed_sz
class ActiveSend:
    '''
    State for one transfer in which kitty sends files to the remote side.

    The remote first sends file specs (up to *num_of_args*). Once the metadata
    for them has been sent, it requests individual files, which are queued as
    SourceFile objects and streamed out chunk by chunk.
    '''

    def __init__(self, request_id: str, quiet: int, bypass: str, num_of_args: int) -> None:
        self.id = request_id
        self.expected_num_of_args = num_of_args
        # None when no bypass was supplied, else the result of check_bypass()
        self.bypass_ok: bool | None = None
        if bypass:
            byp = get_options().file_transfer_confirmation_bypass
            self.bypass_ok = check_bypass(byp, request_id, bypass)
        self.accepted = False
        # quiet 0: acknowledgements and errors; 1: errors only; 2+: neither
        self.send_acknowledgements = quiet < 1
        self.send_errors = quiet < 2
        self.last_activity_at = monotonic()
        self.file_specs: list[tuple[str, str]] = []
        self.queued_files_map: dict[str, SourceFile] = {}
        self.active_file: SourceFile | None = None
        self.pending_chunks: Deque[FileTransmissionCommand] = deque()
        self.metadata_sent = False

    @property
    def spec_complete(self) -> bool:
        return self.expected_num_of_args <= len(self.file_specs)

    def add_file_spec(self, cmd: FileTransmissionCommand) -> None:
        '''
        Record a requested path as a ``(file_id, name)`` spec.

        :raises TransmissionError: when all expected specs were already received or too many were sent
        '''
        self.last_activity_at = monotonic()
        if len(self.file_specs) > 8192 or self.spec_complete:
            raise TransmissionError(ErrorCode.EINVAL, 'Too many file specs')
        self.file_specs.append((cmd.file_id, cmd.name))

    def add_send_file(self, cmd: FileTransmissionCommand) -> None:
        '''
        Queue a file requested by the remote for sending.

        A queued file with the same file_id is replaced, and its open file is
        closed.

        :raises TransmissionError: when too many files are queued, or from SourceFile
        :raises OSError: from SourceFile when the file cannot be opened
        '''
        self.last_activity_at = monotonic()
        if len(self.queued_files_map) > 32768:
            raise TransmissionError(ErrorCode.EINVAL, 'Too many queued files')
        sf = SourceFile(cmd)
        previous = self.queued_files_map.get(cmd.file_id)
        if previous is not None:
            previous.close()
        self.queued_files_map[cmd.file_id] = sf

    def add_signature_data(self, cmd: FileTransmissionCommand) -> None:
        '''
        Feed rsync signature data for a queued file.

        ``end_data`` completes the signature and makes the file ready to send.

        :raises TransmissionError: for unknown files or files not using rsync
        '''
        self.last_activity_at = monotonic()
        af = self.queued_files_map.get(cmd.file_id)
        if af is None:
            raise TransmissionError(ErrorCode.EINVAL, f'Signature data for unknown file_id: {cmd.file_id}')
        sl = af.differ
        if sl is None:
            raise TransmissionError(ErrorCode.EINVAL, f'Signature data for file that is not using rsync: {cmd.file_id}')
        sl.add_signature_data(cmd.data)
        if cmd.action is Action.end_data:
            sl.finish_signature_data()
            af.waiting_for_signature = False

    @property
    def is_expired(self) -> bool:
        '''True once there has been no activity for more than EXPIRE_TIME minutes.'''
        return monotonic() - self.last_activity_at > (60 * EXPIRE_TIME)

    def close(self) -> None:
        '''Close the file being sent and every file still queued.'''
        if self.active_file is not None:
            self.active_file.close()
            self.active_file = None
        for sf in self.queued_files_map.values():
            sf.close()
        self.queued_files_map = {}

    def next_chunk(self) -> FileTransmissionCommand | None:
        '''
        Return the next ``data``/``end_data`` command to send, or None.

        Pending (already split) chunks are returned first. Otherwise the first
        queued file that is ready becomes active, and its next non-empty chunk
        is split into commands. When a file finishes with no data left, a bare
        ``end_data`` is returned.
        '''
        self.last_activity_at = monotonic()
        if self.pending_chunks:
            return self.pending_chunks.popleft()
        af = self.active_file
        if af is None:
            for f in self.queued_files_map.values():
                if f.ready_to_transmit:
                    self.active_file = af = f
                    break
        if af is None:
            return None
        self.queued_files_map.pop(af.file_id, None)
        while True:
            chunk, _ = af.next_chunk()
            if af.transmitted:
                self.active_file = None
                break
            if chunk:
                break
        if chunk:
            self.pending_chunks.extend(split_for_transfer(chunk, file_id=af.file_id, mark_last=af.transmitted))
            return self.pending_chunks.popleft()
        elif af.transmitted:
            return FileTransmissionCommand(action=Action.end_data, file_id=af.file_id)
        return None

    def return_chunk(self, ftc: FileTransmissionCommand) -> None:
        '''Put back a chunk that could not be written, so it is sent next.'''
        self.pending_chunks.appendleft(ftc)
class FileTransmission:
def __init__(self, window_id: int):
self.window_id = window_id
self.active_receives: dict[str, ActiveReceive] = {}
self.active_sends: dict[str, ActiveSend] = {}
self.pending_receive_responses: Deque[FileTransmissionCommand] = deque()
self.pending_timer: int | None = None
def callback_after(self, callback: Callable[[int | None], None], timeout: float = 0) -> int | None:
    '''Schedule a one-shot (non-repeating) timer that calls *callback* after *timeout*.'''
    repeats = False
    return add_timer(callback, timeout, repeats)
def start_pending_timer(self) -> None:
if self.pending_timer is None:
self.pending_timer = self.callback_after(self.try_pending, 0.2)
def try_pending(self, timer_id: int | None) -> None:
self.pending_timer = None
while self.pending_receive_responses:
payload = self.pending_receive_responses.popleft()
ar = self.active_receives.get(payload.id)
if ar is None:
continue
if not self.write_ftc_to_child(payload, appendleft=True):
break
ar.last_activity_at = monotonic()
self.prune_expired()
def __del__(self) -> None:
for ar in self.active_receives.values():
ar.close()
self.active_receives = {}
for a in self.active_sends.values():
a.close()
self.active_receives = {}
self.active_sends = {}
def drop_receive(self, receive_id: str) -> None:
ar = self.active_receives.pop(receive_id, None)
if ar is not None:
ar.close()
def drop_send(self, send_id: str) -> None:
a = self.active_sends.pop(send_id, None)
if a is not None:
a.close()
def prune_expired(self) -> None:
for k in tuple(self.active_receives):
if self.active_receives[k].is_expired:
self.drop_receive(k)
for a in tuple(self.active_sends):
if self.active_sends[a].is_expired:
self.drop_send(a)
def handle_serialized_command(self, data: memoryview) -> None:
    '''
    Parse one raw command from the child and route it to the right handler.

    Commands that cannot be parsed, or that have no id, are logged and
    dropped. A ``cancel`` for a known id goes straight to the transfer that
    owns it. Otherwise expired transfers are pruned first. The command then
    goes to the receive handler (known receive id or a ``send`` action) and
    to the send handler (known send id or a ``receive`` action).
    '''
    try:
        cmd = FileTransmissionCommand.deserialize(data)
    except Exception as e:
        log_error(f'Failed to parse file transmission command with error: {e}')
        return
    if not cmd.id:
        log_error('File transmission command without id received, ignoring')
        return
    is_cancel = cmd.action is Action.cancel
    if is_cancel and cmd.id in self.active_receives:
        self.handle_receive_cmd(cmd)
        return
    if is_cancel and cmd.id in self.active_sends:
        self.handle_send_cmd(cmd)
        return
    self.prune_expired()
    if cmd.action is Action.send or cmd.id in self.active_receives:
        self.handle_receive_cmd(cmd)
    if cmd.action is Action.receive or cmd.id in self.active_sends:
        self.handle_send_cmd(cmd)
def handle_send_cmd(self, cmd: FileTransmissionCommand) -> None:
    '''
    Handle a command for a transfer in which kitty sends files to the remote.

    A ``receive`` action for an unknown id creates a new ActiveSend, subject
    to MAX_ACTIVE_SENDS, and passes it to start_send(), which is defined
    elsewhere. For an existing id:

    - ``file`` commands add file specs before metadata is sent, and request
      individual files after.
    - ``data``/``end_data`` carry rsync signature data.
    - ``status``/``finish`` end the transfer.
    - ``cancel`` aborts it.
    '''
    if cmd.id in self.active_sends:
        asd = self.active_sends[cmd.id]
        if cmd.action is Action.receive:
            log_error('File transmission receive received for already active id, aborting')
            self.drop_send(cmd.id)
            return
        if cmd.action is Action.file:
            try:
                if asd.metadata_sent:
                    asd.add_send_file(cmd)
                else:
                    asd.add_file_spec(cmd)
            except OSError as err:
                self.send_fail_on_os_error(err, 'Failed to add send file', asd, cmd.file_id)
                self.drop_send(asd.id)
                return
            except TransmissionError as err:
                self.drop_send(asd.id)
                if asd.send_errors:
                    self.send_transmission_error(asd.id, err)
                return
            if asd.metadata_sent:
                self.pump_send_chunks(asd)
            elif asd.spec_complete and asd.accepted:
                self.send_metadata_for_send_transfer(asd)
            return
        if cmd.action in (Action.data, Action.end_data):
            # rsync signature data for a requested file
            try:
                asd.add_signature_data(cmd)
            except TransmissionError as err:
                self.drop_send(asd.id)
                if asd.send_errors:
                    self.send_transmission_error(asd.id, err)
            else:
                self.pump_send_chunks(asd)
        elif cmd.action in (Action.status, Action.finish):
            self.drop_send(asd.id)
            return
        if not asd.accepted:
            log_error(f'File transmission command {cmd.action} received for pending id: {cmd.id}, aborting')
            self.drop_send(cmd.id)
            return
        asd.last_activity_at = monotonic()
    else:
        if cmd.action is not Action.receive:
            log_error(f'File transmission command {cmd.action} received for unknown or rejected id: {cmd.id}, ignoring')
            return
        if len(self.active_sends) >= MAX_ACTIVE_SENDS:
            log_error('New File transmission receive with too many active sends, ignoring')
            return
        # cmd.size carries the number of file specs the remote will send
        asd = self.active_sends[cmd.id] = ActiveSend(cmd.id, cmd.quiet, cmd.bypass, cmd.size)
        self.start_send(asd.id)
        return
    if cmd.action is Action.cancel:
        self.drop_send(asd.id)
        if asd.send_acknowledgements:
            self.send_status_response(ErrorCode.CANCELED, request_id=asd.id)
def send_metadata_for_send_transfer(self, asd: ActiveSend) -> None:
    '''
    Send the ``file`` metadata for every spec of *asd*, then a final status.

    Per-spec errors are reported (when errors are not silenced) and also count
    as output. If anything was produced, an OK status carrying home_path() is
    sent and the send moves to the file-request phase. Otherwise ENOENT is
    reported and the send is dropped.
    '''
    produced_output = False
    for item in iter_file_metadata(asd.file_specs):
        produced_output = True
        if isinstance(item, TransmissionError):
            if asd.send_errors:
                self.send_transmission_error(asd.id, item)
            continue
        item.id = asd.id
        self.write_ftc_to_child(item)
    if not produced_output:
        self.send_status_response(code=ErrorCode.ENOENT, request_id=asd.id, msg='No files found')
        self.drop_send(asd.id)
        return
    self.send_status_response(code=ErrorCode.OK, request_id=asd.id, name=home_path())
    asd.metadata_sent = True
def pump_send_chunks(self, asd: ActiveSend) -> None:
while True:
try:
ftc = asd.next_chunk()
except OSError as err:
fid = asd.active_file.file_id if asd.active_file else ''
self.send_fail_on_os_error(err, 'Failed to read data from file', asd, file_id=fid)
self.drop_send(asd.id)
break
if ftc is None:
break
ftc.id = asd.id
if not self.write_ftc_to_child(ftc, use_pending=False):
asd.return_chunk(ftc)
self.callback_after(self.pump_sends, 0.05)
break
def pump_sends(self, timer_id: int | None) -> None:
for asd in self.active_sends.values():
if asd.metadata_sent:
self.pump_send_chunks(asd)
def handle_receive_cmd(self, cmd: FileTransmissionCommand) -> None:
    '''
    Handle a command for a transfer in which the remote side sends files to kitty.

    A ``send`` action for an unknown id creates a new ActiveReceive, subject
    to MAX_ACTIVE_RECEIVES, and passes it to start_receive(), which is defined
    elsewhere. Commands for an existing id are handled only after the receive
    has been accepted:

    - ``file`` creates a file or directory.
    - ``data``/``end_data`` write data.
    - ``finish`` commits the receive and drops it.
    - ``cancel`` aborts the receive.
    '''
    if cmd.id in self.active_receives:
        ar = self.active_receives[cmd.id]
        if cmd.action is Action.send:
            log_error('File transmission send received for already active id, aborting')
            self.drop_receive(cmd.id)
            return
        if not ar.accepted:
            # NOTE(review): any command (cancel included) that arrives before
            # acceptance aborts the receive without an acknowledgement.
            log_error(f'File transmission command {cmd.action} received for pending id: {cmd.id}, aborting')
            self.drop_receive(cmd.id)
            return
        ar.last_activity_at = monotonic()
    else:
        if cmd.action is not Action.send:
            log_error(f'File transmission command {cmd.action} received for unknown or rejected id: {cmd.id}, ignoring')
            return
        if len(self.active_receives) >= MAX_ACTIVE_RECEIVES:
            log_error('New File transmission send with too many active receives, ignoring')
            return
        ar = self.active_receives[cmd.id] = ActiveReceive(cmd.id, cmd.quiet, cmd.bypass)
        self.start_receive(ar.id)
        return
    if cmd.action is Action.cancel:
        self.drop_receive(ar.id)
        if ar.send_acknowledgements:
            self.send_status_response(ErrorCode.CANCELED, request_id=ar.id)
    elif cmd.action is Action.file:
        try:
            df = ar.start_file(cmd)
        except TransmissionError as err:
            if ar.send_errors:
                self.send_transmission_error(ar.id, err)
        except Exception as err:
            log_error(f'Transmission protocol failed to start file with error: {err}')
            if ar.send_errors:
                te = TransmissionError(file_id=cmd.file_id, msg=str(err))
                self.send_transmission_error(ar.id, te)
        else:
            if df.ftype is FileType.directory:
                # Directories are created right away and acknowledged with OK
                # (whatever send_acknowledgements says). Their metadata is only
                # applied on finish, in ActiveReceive.commit().
                try:
                    os.makedirs(df.name, exist_ok=True)
                except OSError as err:
                    self.send_fail_on_os_error(err, 'Failed to create directory', ar, df.file_id)
                else:
                    self.send_status_response(ErrorCode.OK, ar.id, df.file_id, name=df.name)
            else:
                # NOTE(review): when acknowledgements are disabled, no STARTED
                # response is sent and df.ttype is left as requested (no
                # signature is prepared), so data is written as-is.
                if ar.send_acknowledgements:
                    # Report the size of any existing target (-1 if none). rsync
                    # is only used when requested, the target exists and the
                    # entry is a regular file.
                    sz = df.existing_stat.st_size if df.existing_stat is not None else -1
                    ttype = TransmissionType.rsync \
                        if sz > -1 and df.ttype is TransmissionType.rsync and df.ftype is FileType.regular else TransmissionType.simple
                    self.send_status_response(code=ErrorCode.STARTED, request_id=ar.id, file_id=df.file_id, name=df.name, size=sz, ttype=ttype)
                    df.ttype = ttype
                    if ttype is TransmissionType.rsync:
                        # Queue the existing file's signature; it is sent from
                        # a timer callback by transmit_rsync_signature().
                        try:
                            fs = df.signature_iterator()
                        except OSError as err:
                            self.send_fail_on_os_error(err, 'Failed to open file to read signature', ar, df.file_id)
                        else:
                            ar.pending_files_to_transmit_signature_of.append((fs, df.file_id))
                            self.callback_after(partial(self.transmit_rsync_signature, ar.id))
    elif cmd.action in (Action.data, Action.end_data):
        try:
            # Remember the byte count before this chunk so that PROGRESS is
            # only reported when it actually grew.
            before = 0
            bf = ar.files.get(cmd.file_id)
            if bf is not None:
                before = bf.bytes_written
            df = ar.add_data(cmd)
            if df.failed:
                # The failure was reported when it first happened; later
                # chunks for the same file are dropped silently.
                return
            if ar.send_acknowledgements:
                if df.closed:
                    self.send_status_response(
                        code=ErrorCode.OK, request_id=ar.id, file_id=df.file_id, name=df.name, size=df.bytes_written)
                elif df.bytes_written > before:
                    self.send_status_response(
                        code=ErrorCode.PROGRESS, request_id=ar.id, file_id=df.file_id, size=df.bytes_written)
        except TransmissionError as err:
            if ar.send_errors:
                self.send_transmission_error(ar.id, err)
        except Exception as err:
            import traceback
            st = traceback.format_exc()
            log_error(f'Transmission protocol failed to write data to file with error: {st}')
            if ar.send_errors:
                te = TransmissionError(file_id=cmd.file_id, msg=str(err))
                self.send_transmission_error(ar.id, te)
    elif cmd.action is Action.finish:
        # The receive is dropped whether or not the commit succeeds.
        try:
            ar.commit(self.send_fail_on_os_error)
        except TransmissionError as err:
            if ar.send_errors:
                self.send_transmission_error(ar.id, err)
        except Exception as err:
            log_error(f'Transmission protocol failed to commit receive with error: {err}')
            if ar.send_errors:
                te = TransmissionError(msg=str(err))
                self.send_transmission_error(ar.id, te)
        finally:
            self.drop_receive(ar.id)
    else:
        log_error(f'Transmission receive command with unknown action: {cmd.action}, ignoring')
def transmit_rsync_signature(self, receive_id: str, timer_id: int | None = None) -> None:
    '''
    Send the rsync signature of the next queued file of an active receive to the child.

    Each invocation first flushes commands that could not be written earlier
    because the child did not accept them; if that still fails, it reschedules
    itself after 0.1 seconds and returns. Otherwise it reads up to one 4096 byte
    buffer of signature blocks for the file at the head of
    ``pending_files_to_transmit_signature_of``, writes them as data commands
    (plus an end_data command once that file's signature is exhausted) and
    reschedules itself to continue. Commands the child does not accept are
    kept in ``signature_pending_chunks`` for the next invocation.

    If reading the signature fails, the file is removed from the queue, the
    failure is reported via send_fail_on_os_error() and processing moves on
    to the next queued file, if any.

    :param receive_id: id of the active receive whose signatures are sent
    :param timer_id: not used; accepted so this can serve as a timer callback
    '''
    q = self.active_receives.get(receive_id)
    if q is None:
        return
    ar = q  # separate, never-None name so mypy keeps the narrowed type inside write_ftc below
    # Previously unsent commands must go out, in order, before anything new.
    while ar.signature_pending_chunks:
        if self.write_ftc_to_child(ar.signature_pending_chunks[0], use_pending=False):
            ar.signature_pending_chunks.popleft()
        else:
            self.callback_after(partial(self.transmit_rsync_signature, receive_id), timeout=0.1)
            return
    if not ar.pending_files_to_transmit_signature_of:
        return
    fs, file_id = ar.pending_files_to_transmit_signature_of[0]
    pos = 0
    buf = memoryview(bytearray(4096))
    is_finished = False
    # Only ask for another block while at least 32 bytes of the buffer remain free.
    while len(buf) >= pos + 32:
        try:
            n = fs.next_signature_block(buf[pos:])
        except OSError as err:
            # Remove the unreadable file from the queue: left at the head it
            # would fail again on every later call and block the signatures
            # of all files queued behind it. send_fail_on_os_error() itself
            # checks ar.send_errors.
            ar.pending_files_to_transmit_signature_of.popleft()
            self.send_fail_on_os_error(err, 'Failed to read signature', ar, file_id)
            if ar.pending_files_to_transmit_signature_of:
                self.callback_after(partial(self.transmit_rsync_signature, receive_id))
            return
        if not n:
            # Signature exhausted: this file is done.
            is_finished = True
            ar.pending_files_to_transmit_signature_of.popleft()
            break
        pos += n
    chunk = buf[:pos]
    has_capacity = True

    def write_ftc(data: FileTransmissionCommand) -> None:
        # Once the child refuses one command, queue it and every later one so
        # that they are sent in order on the next invocation.
        nonlocal has_capacity
        if has_capacity:
            if not self.write_ftc_to_child(data, use_pending=False):
                has_capacity = False
                ar.signature_pending_chunks.append(data)
        else:
            ar.signature_pending_chunks.append(data)

    if len(chunk):
        for data in split_for_transfer(chunk, session_id=receive_id, file_id=file_id):
            write_ftc(data)
    if is_finished:
        endftc = FileTransmissionCommand(id=receive_id, action=Action.end_data, file_id=file_id)
        write_ftc(endftc)
    # Continue with queued commands, the rest of this signature or the next file.
    self.callback_after(partial(self.transmit_rsync_signature, receive_id))
def send_status_response(
    self, code: ErrorCode | str = ErrorCode.EINVAL,
    request_id: str = '', file_id: str = '', msg: str = '',
    name: str = '', size: int = -1,
    ttype: TransmissionType = TransmissionType.simple,
) -> bool:
    '''
    Build a status response from the given fields and write it to the child.

    The response is a TransmissionError converted to a command for
    ``request_id``. Returns the result of write_ftc_to_child(), which is
    called with its defaults, so an unaccepted response is queued for retry.
    '''
    status = TransmissionError(
        code=code, msg=msg, file_id=file_id,
        name=name, size=size, ttype=ttype,
    )
    command = status.as_ftc(request_id)
    return self.write_ftc_to_child(command)
def send_transmission_error(self, request_id: str, err: TransmissionError) -> bool:
    '''
    Write ``err`` to the child as a response to ``request_id``.

    Errors whose ``transmit`` flag is false are not sent at all; True is
    returned for them. Otherwise the result of write_ftc_to_child() is returned.
    '''
    if not err.transmit:
        return True
    command = err.as_ftc(request_id)
    return self.write_ftc_to_child(command)
def write_ftc_to_child(self, payload: FileTransmissionCommand, appendleft: bool = False, use_pending: bool = True) -> bool:
    '''
    Serialize ``payload`` as an OSC escape code and send it to this object's window.

    Returns True if the window's screen accepted the data. When it did not
    and ``use_pending`` is true, the payload is added to
    ``pending_receive_responses`` (at the front when ``appendleft`` is true)
    and the pending timer is started. If the window no longer exists,
    nothing is sent or queued and False is returned.
    '''
    window = get_boss().window_id_map.get(self.window_id)
    if window is None:
        return False
    serialized = tuple(payload.get_serialized_fields(prefix_with_osc_code=True))
    queued = window.screen.send_escape_code_to_child(ESC_OSC, serialized)
    if queued or not use_pending:
        return queued
    # The child is not accepting data right now: keep the payload for a later retry.
    if appendleft:
        self.pending_receive_responses.appendleft(payload)
    else:
        self.pending_receive_responses.append(payload)
    self.start_pending_timer()
    return queued
def start_send(self, asd_id: str) -> None:
    '''
    Begin the approval step for the active send ``asd_id``.

    A send whose ``bypass_ok`` is already decided is passed straight to
    handle_receive_confirmation(). Otherwise the user is asked to confirm in
    this object's window; if that window no longer exists, nothing happens.
    '''
    asd = self.active_sends[asd_id]
    if asd.bypass_ok is None:
        boss = get_boss()
        window = boss.window_id_map.get(self.window_id)
        if window is not None:
            question = _('The remote machine wants to read some files from this computer. Do you want to allow the transfer?')
            boss.confirm(question, self.handle_receive_confirmation, asd_id, window=window)
    else:
        self.handle_receive_confirmation(asd.bypass_ok, asd_id)
def handle_receive_confirmation(self, confirmed: bool, cmd_id: str) -> None:
    '''
    Apply the decision on the active send ``cmd_id`` (remote reading files from here).

    When confirmed, the send is marked accepted, acknowledged with OK if the
    transfer asked for acknowledgements, and its metadata is sent once the
    spec is complete. When refused, the send is dropped and, for a transfer
    that is not accepted and asked for errors, an EPERM response is sent.
    Unknown ids are ignored.
    '''
    asd = self.active_sends.get(cmd_id)
    if asd is None:
        return
    if not confirmed:
        self.drop_send(asd.id)
    else:
        asd.accepted = True
    if not asd.accepted:
        if asd.send_errors:
            self.send_status_response(code=ErrorCode.EPERM, request_id=asd.id, msg='User refused the transfer')
        return
    if asd.send_acknowledgements:
        self.send_status_response(code=ErrorCode.OK, request_id=asd.id)
    if asd.spec_complete:
        self.send_metadata_for_send_transfer(asd)
def start_receive(self, ar_id: str) -> None:
    '''
    Begin the approval step for the active receive ``ar_id``.

    A receive whose ``bypass_ok`` is already decided is passed straight to
    handle_send_confirmation(). Otherwise the user is asked to confirm in
    this object's window; if that window no longer exists, nothing happens.
    '''
    ar = self.active_receives[ar_id]
    if ar.bypass_ok is None:
        boss = get_boss()
        window = boss.window_id_map.get(self.window_id)
        if window is not None:
            question = _('The remote machine wants to send some files to this computer. Do you want to allow the transfer?')
            boss.confirm(question, self.handle_send_confirmation, ar_id, window=window)
    else:
        self.handle_send_confirmation(ar.bypass_ok, ar_id)
def handle_send_confirmation(self, confirmed: bool, cmd_id: str) -> None:
    '''
    Apply the decision on the active receive ``cmd_id`` (remote sending files here).

    When confirmed, the receive is marked accepted and acknowledged with OK
    if the transfer asked for acknowledgements. When refused, the receive is
    dropped and, for a transfer that is not accepted and asked for errors,
    an EPERM response is sent. Unknown ids are ignored.
    '''
    ar = self.active_receives.get(cmd_id)
    if ar is None:
        return
    if not confirmed:
        self.drop_receive(ar.id)
    else:
        ar.accepted = True
    if not ar.accepted:
        if ar.send_errors:
            self.send_status_response(code=ErrorCode.EPERM, request_id=ar.id, msg='User refused the transfer')
        return
    if ar.send_acknowledgements:
        self.send_status_response(code=ErrorCode.OK, request_id=ar.id)
def send_fail_on_os_error(self, err: OSError, msg: str, ar: ActiveSend | ActiveReceive, file_id: str = '') -> None:
    '''
    Report an OSError to the child as a status response for transfer ``ar``.

    The response code is the symbolic errno name (for example ENOENT), or
    EFAIL when the error has no errno or it is not a known code. Nothing is
    sent unless the transfer asked for errors.
    '''
    if ar.send_errors:
        code = 'EFAIL'
        if err.errno is not None:
            code = errno.errorcode.get(err.errno, code)
        self.send_status_response(code=code, msg=msg, request_id=ar.id, file_id=file_id)
def active_file(self, rid: str = '', file_id: str = '') -> DestFile:
    '''
    Return the destination file ``file_id`` of the active receive ``rid``.

    Both lookups use indexing, so an unknown id raises (presumably KeyError,
    if these containers are dicts).
    '''
    receive = self.active_receives[rid]
    return receive.files[file_id]
class TestFileTransmission(FileTransmission):
    '''
    A FileTransmission that needs no real window or timers.

    Every command written to the child is recorded (as a dict) in
    ``test_responses`` and reported as accepted. Confirmation prompts are
    answered with the fixed ``allow`` value, and scheduled callbacks run
    immediately.
    '''
    def __init__(self, allow: bool = True) -> None:
        super().__init__(0)
        self.allow = allow
        self.test_responses: list[dict[str, str | int | bytes]] = []
    def write_ftc_to_child(self, payload: FileTransmissionCommand, appendleft: bool = False, use_pending: bool = True) -> bool:
        # Record instead of sending; the "child" always accepts.
        record = payload.asdict()
        self.test_responses.append(record)
        return True
    def start_receive(self, aid: str) -> None:
        # Skip the confirmation prompt and answer with the configured value.
        decision = self.allow
        self.handle_send_confirmation(decision, aid)
    def start_send(self, aid: str) -> None:
        decision = self.allow
        self.handle_receive_confirmation(decision, aid)
    def callback_after(self, callback: Callable[[int | None], None], timeout: float = 0) -> int | None:
        # Run synchronously with no timer id; the timeout is ignored.
        callback(None)
        return None