mirror of
https://github.com/kovidgoyal/kitty.git
synced 2025-12-13 20:36:22 +01:00
Go SHM API to read simple data with size from SHM name
This commit is contained in:
30
kitty_tests/shm.py
Normal file
30
kitty_tests/shm.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2023, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from kitty.constants import kitten_exe
|
||||
from kitty.fast_data_types import shm_unlink
|
||||
from kitty.shm import SharedMemory
|
||||
|
||||
from . import BaseTest
|
||||
|
||||
|
||||
class SHMTest(BaseTest):
|
||||
|
||||
def test_shm_with_kitten(self):
|
||||
data = os.urandom(333)
|
||||
with SharedMemory(size=363) as shm:
|
||||
shm.write_data_with_size(data)
|
||||
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'read', shm.name], stdout=subprocess.PIPE)
|
||||
self.assertEqual(cp.returncode, 0)
|
||||
self.assertEqual(cp.stdout, data)
|
||||
self.assertRaises(FileNotFoundError, shm_unlink, shm.name)
|
||||
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'write'], input=data, stdout=subprocess.PIPE)
|
||||
self.assertEqual(cp.returncode, 0)
|
||||
name = cp.stdout.decode().strip()
|
||||
with SharedMemory(name=name, unlink_on_exit=True) as shm:
|
||||
q = shm.read_data_with_size()
|
||||
self.assertEqual(data, q)
|
||||
20
tools/cmd/pytest/main.go
Normal file
20
tools/cmd/pytest/main.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
|
||||
|
||||
package pytest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"kitty/tools/cli"
|
||||
"kitty/tools/utils/shm"
|
||||
)
|
||||
|
||||
var _ = fmt.Print
|
||||
|
||||
func EntryPoint(root *cli.Command) {
|
||||
root = root.AddSubCommand(&cli.Command{
|
||||
Name: "__pytest__",
|
||||
Hidden: true,
|
||||
})
|
||||
shm.TestEntryPoint(root)
|
||||
}
|
||||
@@ -5,10 +5,13 @@ package ssh
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"kitty/tools/cli"
|
||||
"kitty/tools/tty"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -16,6 +19,74 @@ import (
|
||||
|
||||
var _ = fmt.Print
|
||||
|
||||
func get_destination(hostname string) (username, hostname_for_match string) {
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
username = u.Username
|
||||
}
|
||||
hostname_for_match = hostname
|
||||
if strings.HasPrefix(hostname, "ssh://") {
|
||||
p, err := url.Parse(hostname)
|
||||
if err == nil {
|
||||
hostname_for_match = p.Hostname()
|
||||
if p.User.Username() != "" {
|
||||
username = p.User.Username()
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(hostname, "@") && hostname[0] != '@' {
|
||||
username, hostname_for_match, _ = strings.Cut(hostname, "@")
|
||||
}
|
||||
if strings.Contains(hostname, "@") && hostname[0] != '@' {
|
||||
_, hostname_for_match, _ = strings.Cut(hostname_for_match, "@")
|
||||
}
|
||||
hostname_for_match, _, _ = strings.Cut(hostname_for_match, ":")
|
||||
return
|
||||
}
|
||||
|
||||
func add_cloned_env(val string) map[string]string {
|
||||
return nil // TODO: Implement me
|
||||
}
|
||||
|
||||
func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string) {
|
||||
literal_env = make(map[string]string)
|
||||
overrides = make([]string, 0, 4)
|
||||
for i, a := range found_extra_args {
|
||||
if i%2 == 0 {
|
||||
continue
|
||||
}
|
||||
if key, val, found := strings.Cut(a, "="); found {
|
||||
if key == "clone_env" {
|
||||
le := add_cloned_env(val)
|
||||
if le != nil {
|
||||
literal_env = le
|
||||
}
|
||||
} else if key != "hostname" {
|
||||
overrides = append(overrides, key+" "+val)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(overrides) > 0 {
|
||||
overrides = append([]string{"hostname " + username + "@" + hostname_for_match}, overrides...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) {
|
||||
cmd := append([]string{ssh_exe()}, ssh_args...)
|
||||
hostname, remote_args := server_args[0], server_args[1:]
|
||||
if len(remote_args) == 0 {
|
||||
cmd = append(cmd, "-t")
|
||||
}
|
||||
insertion_point := len(cmd)
|
||||
cmd = append(cmd, "--", hostname)
|
||||
uname, hostname_for_match := get_destination(hostname)
|
||||
overrides, literal_env := parse_kitten_args(found_extra_args, uname, hostname_for_match)
|
||||
if insertion_point > 0 && overrides != nil && literal_env != nil {
|
||||
}
|
||||
// TODO: Implement me
|
||||
return
|
||||
}
|
||||
|
||||
func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
|
||||
if len(args) > 0 {
|
||||
switch args[0] {
|
||||
@@ -44,10 +115,13 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
|
||||
}
|
||||
return 1, unix.Exec(ssh_exe(), append([]string{"ssh"}, args...), os.Environ())
|
||||
}
|
||||
if false {
|
||||
return len(ssh_args) + len(server_args), nil
|
||||
if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" {
|
||||
return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window")
|
||||
}
|
||||
return
|
||||
if !tty.IsTerminal(os.Stdin.Fd()) {
|
||||
return 1, fmt.Errorf("The SSH kitten is meant for interactive use only, STDIN must be a terminal")
|
||||
}
|
||||
return run_ssh(ssh_args, server_args, found_extra_args)
|
||||
}
|
||||
|
||||
func EntryPoint(parent *cli.Command) {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"kitty/tools/cmd/clipboard"
|
||||
"kitty/tools/cmd/edit_in_kitty"
|
||||
"kitty/tools/cmd/icat"
|
||||
"kitty/tools/cmd/pytest"
|
||||
"kitty/tools/cmd/ssh"
|
||||
"kitty/tools/cmd/unicode_input"
|
||||
"kitty/tools/cmd/update_self"
|
||||
@@ -35,6 +36,8 @@ func KittyToolEntryPoints(root *cli.Command) {
|
||||
ssh.EntryPoint(root)
|
||||
// unicode_input
|
||||
unicode_input.EntryPoint(root)
|
||||
// __pytest__
|
||||
pytest.EntryPoint(root)
|
||||
// __hold_till_enter__
|
||||
root.AddSubCommand(&cli.Command{
|
||||
Name: "__hold_till_enter__",
|
||||
|
||||
@@ -5,13 +5,17 @@ package shm
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
not_rand "math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"kitty/tools/cli"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@@ -109,3 +113,69 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) {
|
||||
p := buf
|
||||
for len(p) > 0 {
|
||||
n, err := f.Read(p)
|
||||
p = p[n:]
|
||||
if err != nil {
|
||||
if len(p) == 0 && errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
return buf[:len(buf)-len(p)], err
|
||||
}
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func read_with_size(f *os.File) ([]byte, error) {
|
||||
szbuf := []byte{0, 0, 0, 0}
|
||||
szbuf, err := read_till_buf_full(f, szbuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size := int(binary.BigEndian.Uint32(szbuf))
|
||||
return read_till_buf_full(f, make([]byte, size))
|
||||
}
|
||||
|
||||
func test_integration_with_python(args []string) (rc int, err error) {
|
||||
switch args[0] {
|
||||
default:
|
||||
return 1, fmt.Errorf("Unknown test type: %s", args[0])
|
||||
case "read":
|
||||
data, err := ReadWithSizeAndUnlink(args[1])
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
_, err = os.Stdout.Write(data)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
case "write":
|
||||
data, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
mmap, err := CreateTemp("shmtest-", uint64(len(data)+4))
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
defer mmap.Close()
|
||||
binary.BigEndian.PutUint32(mmap.Slice(), uint32(len(data)))
|
||||
copy(mmap.Slice()[4:], data)
|
||||
fmt.Println(mmap.Name())
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestEntryPoint(root *cli.Command) {
|
||||
root.AddSubCommand(&cli.Command{
|
||||
Name: "shm",
|
||||
OnlyArgsAllowed: true,
|
||||
Run: func(cmd *cli.Command, args []string) (rc int, err error) {
|
||||
return test_integration_with_python(args)
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
|
||||
return file_mmap(f, size, WRITE, true, special_name)
|
||||
}
|
||||
|
||||
func Open(name string, size uint64) (MMap, error) {
|
||||
func open(name string) (*os.File, error) {
|
||||
ans, err := os.OpenFile(file_path_from_name(name), os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
@@ -123,5 +123,23 @@ func Open(name string, size uint64) (MMap, error) {
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ans, nil
|
||||
}
|
||||
|
||||
func Open(name string, size uint64) (MMap, error) {
|
||||
ans, err := open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return file_mmap(ans, size, READ, false, name)
|
||||
}
|
||||
|
||||
func ReadWithSizeAndUnlink(name string) ([]byte, error) {
|
||||
f, err := open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
return read_with_size(f)
|
||||
}
|
||||
|
||||
@@ -151,3 +151,13 @@ func Open(name string, size uint64) (MMap, error) {
|
||||
}
|
||||
return syscall_mmap(ans, size, READ, false)
|
||||
}
|
||||
|
||||
func ReadWithSizeAndUnlink(name string) ([]byte, error) {
|
||||
f, err := shm_open(name, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
defer shm_unlink(f.Name())
|
||||
return read_with_size(f)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user