Implement > and < intrinsics for vector registers

This commit is contained in:
Kovid Goyal
2024-02-01 08:19:34 +05:30
parent 82b7b4fcce
commit d60dacbd09
3 changed files with 86 additions and 23 deletions

View File

@@ -360,14 +360,12 @@ func (f *Function) Comment(x ...any) {
f.Instructions = append(f.Instructions, space_join("\t//", x...))
}
func shrn8b(r Register) int {
n, err := strconv.Atoi(r.Name[1:])
if err != nil {
panic(err)
}
n = 0x8400 + (n * 0x21)
return 0x0f0c0000 + n
// shrn8b_immediate4 builds the raw 32-bit instruction word that the caller
// emits via WORD for "shrn.8b ... #4" (the Go assembler has no mnemonic for
// it). The id of a is placed at bit 5 and the id of b in the lowest bits of
// the fixed pattern 0x0f0c8400.
// NOTE(review): presumably a is the source and b the destination register;
// the only visible caller passes the same register for both.
func shrn8b_immediate4(a, b Register) uint32 {
	const base uint32 = 0x0f0c84 << 8
	registers := a.ARMId()<<5 | b.ARMId()
	return base | registers
}
// encode_cmgt16b builds the raw 32-bit instruction word emitted via WORD for
// a cmgt on full-width (16 byte) vector registers.
//
// Layout produced: the fixed bits 0x271<<21 and 0xd<<10, the id of a at
// bit 16, the id of b at bit 5 and the id of dest in the lowest bits.
//
// NOTE(review): per the A64 encoding of CMGT (register, vector) bits 16-20
// are Rm and bits 5-9 are Rn, with the result being Rn > Rm. With this
// placement the word would therefore compute b > a, not a > b - confirm
// against the ARM reference and the expectations of the callers.
func encode_cmgt16b(a, b, dest Register) (ans uint32) {
	const fixed_bits uint32 = 0x271<<21 | 0xd<<10
	ans = fixed_bits
	ans |= a.ARMId() << 16
	ans |= b.ARMId() << 5
	ans |= dest.ARMId()
	return ans
}
func (f *Function) CountBytesToFirstMatchDestructive(vec, ans Register) {
@@ -377,7 +375,7 @@ func (f *Function) CountBytesToFirstMatchDestructive(vec, ans Register) {
// See https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
f.Comment("Go assembler doesn't support the shrn instruction, below we have: shrn.8b", vec, vec, "#4")
f.Comment("It is shifting right by four bits in every 16 bit word and truncating to 8 bits storing the result in the lower 64 bits of", vec)
f.instr("WORD", fmt.Sprintf("$0x%x", shrn8b(vec)))
f.instr("WORD", fmt.Sprintf("$0x%x", shrn8b_immediate4(vec, vec)))
f.instr("FMOVD", "F"+vec.Name[1:], ans)
f.AddTrailingComment("Extract the lower 64 bits from", vec, "and put them into", ans)
f.instr("RBIT", ans, ans)
@@ -647,28 +645,53 @@ func (f *Function) SetRegsiterTo(self Register, val any) {
}
}
func (f *Function) CmpEqEpi8(a, b, ans Register) {
// ARMId returns the numeric part of the register name, i.e. everything after
// the first character of r.Name parsed as a decimal integer. It panics with
// the strconv error when that suffix is not a valid integer.
func (r Register) ARMId() uint32 {
	id, err := strconv.Atoi(r.Name[1:])
	if err == nil {
		return uint32(id)
	}
	panic(err)
}
// cmp emits a byte-wise vector comparison of a and b, storing the result
// mask in ans, and attaches a trailing comment describing it.
//
// op selects the comparison: on x86 it is spliced into the mnemonic as
// "PCMP"+op+"B" (prefixed with "V" when the ISA is not 128 bits wide); on
// ARM64 "EQ" emits VCMEQ and every other value emits the raw word built by
// encode_cmgt16b. c_rep is the operator text placed between the two operand
// names in the trailing comment. Exactly one comparison instruction is
// emitted per call (plus a register copy in the 128 bit x86 case where ans
// aliases neither input).
//
// Panics when a, b and ans do not all have the same Size.
//
// NOTE(review): operand order only matters for the non-commutative "GT".
// In the 128 bit x86 path the `case b.Name` branch makes b the destination
// operand while the other two branches make a the destination, so the two
// inputs swap roles depending on which one ans aliases. The ARM64 word and
// the three operand x86 form should likewise be checked against the
// architecture references to confirm they compute a > b rather than b > a.
func (f *Function) cmp(a, b, ans Register, op, c_rep string) {
	if a.Size != b.Size || a.Size != ans.Size {
		panic("Can only compare registers of equal sizes")
	}
	switch {
	case f.ISA.Goarch == ARM64:
		if op == "EQ" {
			f.instr("VCMEQ", a.ARMFullWidth(), b.ARMFullWidth(), ans.ARMFullWidth())
		} else {
			// Emitted as a raw instruction word rather than a mnemonic.
			f.instr("WORD", fmt.Sprintf("$0x%x", encode_cmgt16b(a, b, ans)))
		}
	case f.ISA.Bits == 128:
		mnemonic := `PCMP` + op + "B"
		// Two operand form: the second operand is also the destination, so
		// ans has to already hold one of the inputs.
		switch ans.Name {
		case a.Name:
			f.instr(mnemonic, b, ans)
		case b.Name:
			f.instr(mnemonic, a, ans)
		default:
			f.CopyRegister(a, ans)
			f.instr(mnemonic, b, ans)
		}
	default:
		f.instr("V"+`PCMP`+op+"B", a, b, ans)
	}
	f.AddTrailingComment(ans, "= 0xff on every byte where", a.Name+"[n]", c_rep, b.Name+"[n] and zero elsewhere")
}
// CmpGtEpi8 emits a byte-wise "greater than" comparison of a and b into ans
// by delegating to cmp with op "GT". The ">" argument is only the operator
// text used in the trailing comment that cmp attaches.
func (f *Function) CmpGtEpi8(a, b, ans Register) {
f.cmp(a, b, ans, "GT", ">")
}
// CmpLtEpi8 emits a byte-wise "less than" comparison into ans. It is
// expressed as a "GT" comparison with the two inputs swapped, delegating to
// cmp(b, a, ans, ...).
// NOTE(review): because the inputs are swapped before reaching cmp, the
// trailing comment that cmp builds names this method's b first and a second
// around the "<" text - confirm that wording is what is wanted.
func (f *Function) CmpLtEpi8(a, b, ans Register) {
f.cmp(b, a, ans, "GT", "<")
}
// CmpEqEpi8 emits a byte-wise equality comparison of a and b into ans by
// delegating to cmp with op "EQ". The "==" argument is only the operator
// text used in the trailing comment that cmp attaches.
func (f *Function) CmpEqEpi8(a, b, ans Register) {
f.cmp(a, b, ans, "EQ", "==")
}
func (f *Function) Set1Epi8FromParam(function_parameter string, vec Register) {
@@ -1030,6 +1053,17 @@ func (s *State) test_cmpeq_epi8() {
f.store_vec_in_param(a, "ans")
}
// test_cmplt_epi8 generates the assembly test function "test_cmplt_epi8_asm",
// which loads the "a" and "b" parameters into vector registers, applies
// CmpLtEpi8 with the result written over the register holding "a", and
// stores that register into the "ans" parameter.
// The function is created before the SIMD check; when the ISA has no SIMD
// nothing further is emitted into it.
func (s *State) test_cmplt_epi8() {
	params := []FunctionParam{{"a", ByteSlice}, {"b", ByteSlice}, {"ans", ByteSlice}}
	f := s.NewFunction("test_cmplt_epi8_asm", "Test byte comparison of two vectors", params, nil)
	if s.ISA.HasSIMD {
		lhs := f.load_vec_from_param("a")
		rhs := f.load_vec_from_param("b")
		f.CmpLtEpi8(lhs, rhs, lhs)
		f.store_vec_in_param(lhs, "ans")
	}
}
func (s *State) test_or() {
f := s.NewFunction("test_or_asm", "Test OR of two vectors", []FunctionParam{{"a", ByteSlice}, {"b", ByteSlice}, {"ans", ByteSlice}}, nil)
if !s.ISA.HasSIMD {
@@ -1175,6 +1209,7 @@ func (s *State) Generate() {
s.test_load()
s.test_set1_epi8()
s.test_cmpeq_epi8()
s.test_cmplt_epi8()
s.test_or()
s.test_jump_if_zero()
s.test_count_to_match()

View File

@@ -92,7 +92,7 @@ func init() {
}
case "arm64":
Have128bit = HasSIMD128Code
Have128bit = HasSIMD256Code
Have256bit = HasSIMD256Code
}
if Have256bit {
UnsafeIndexByte2 = index_byte2_asm_256

View File

@@ -44,6 +44,16 @@ func test_cmpeq_epi8(a, b []byte) []byte {
return ans
}
func test_cmplt_epi8(a, b []byte) []byte {
ans := make([]byte, len(a))
if len(ans) == 16 {
test_cmplt_epi8_asm_128(a, b, ans)
} else {
test_cmplt_epi8_asm_256(a, b, ans)
}
return ans
}
func test_or(a, b []byte) []byte {
ans := make([]byte, len(a))
if len(ans) == 16 {
@@ -81,14 +91,18 @@ func broadcast_byte(b byte, size int) []byte {
}
func TestSIMDStringOps(t *testing.T) {
sizes := []int{}
if Have128bit {
sizes = append(sizes, 16)
}
if Have256bit {
sizes = append(sizes, 32)
}
if len(sizes) == 0 {
t.Skip("skipping as no SIMD available at runtime")
}
test := func(haystack []byte, a, b byte) {
sizes := []int{}
if Have128bit {
sizes = append(sizes, 16)
}
if Have256bit {
sizes = append(sizes, 32)
}
var actual int
safe_haystack := append(bytes.Repeat([]byte{'<'}, 64), haystack...)
safe_haystack = append(safe_haystack, bytes.Repeat([]byte{'>'}, 64)...)
@@ -170,6 +184,9 @@ func TestIntrinsics(t *testing.T) {
if !HasSIMD128Code {
t.Fatal("SIMD 128bit code not built")
}
if !Have128bit {
t.Fatal("SIMD 128bit support not available at runtime")
}
}
ae := func(sz int, func_name string, a, b any) {
if s := cmp.Diff(a, b); s != "" {
@@ -192,6 +209,14 @@ func TestIntrinsics(t *testing.T) {
b := ordered_bytes(sz)
ans := test_cmpeq_epi8(a, b)
ae(sz, `cmpeq_epi8_test`, broadcast_byte(0xff, sz), ans)
threshold := -1
a[1] = byte(threshold)
a[2] = byte(threshold - 1)
ans = test_cmplt_epi8(a, broadcast_byte(byte(threshold), sz))
expected := broadcast_byte(0xff, sz)
expected[1] = 0
expected[2] = 0
ae(sz, `cmplt_epi8_test`, expected, ans)
})
tests = append(tests, func(sz int) {
a := make([]byte, sz)
@@ -247,6 +272,9 @@ func TestIntrinsics(t *testing.T) {
if Have256bit {
sizes = append(sizes, 32)
}
if len(sizes) == 0 {
t.Skip("skipping as no SIMD available at runtime")
}
for _, sz := range sizes {
for _, test := range tests {
test(sz)