summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Joe Ramsay <Joe.Ramsay@arm.com> 2025-11-06 15:36:03 +0000
committer: Wilco Dijkstra <wilco.dijkstra@arm.com> 2025-11-18 16:12:52 +0000
commit: 360f60fb63901f1755a30030107f8e8c6f78e6e3 (patch)
tree: 97c1c7ad7cd3947903363119475ed1da6f8b59d5
parent: 215e9155ea06064342151d05446ae51da16e0f65 (diff)
AArch64: Optimise SVE scalar callbacks
Instead of using SVE instructions to marshall special results into the correct lane, just write the entire vector (and the predicate) to memory, then use cheaper scalar operations. Geomean speedup of 16% in special intervals on Neoverse with GCC 14.

Reviewed-by: Wilco Dijkstra <Wilco.Dijkstra@arm.com>
(cherry picked from commit 5b82fb18827e962af9f080fdf3c1a69802783f67)
-rw-r--r-- sysdeps/aarch64/fpu/sv_math.h 97
1 files changed, 62 insertions, 35 deletions
diff --git a/sysdeps/aarch64/fpu/sv_math.h b/sysdeps/aarch64/fpu/sv_math.h
index 3d576df4cc..65d7f0ff20 100644
--- a/sysdeps/aarch64/fpu/sv_math.h
+++ b/sysdeps/aarch64/fpu/sv_math.h
@@ -24,11 +24,29 @@
#include "vecmath_config.h"
+#if !defined(__ARM_FEATURE_SVE_BITS) || __ARM_FEATURE_SVE_BITS == 0
+/* If not specified by -msve-vector-bits, assume maximum vector length. */
+# define SVE_VECTOR_BYTES 256
+#else
+# define SVE_VECTOR_BYTES (__ARM_FEATURE_SVE_BITS / 8)
+#endif
+#define SVE_NUM_FLTS (SVE_VECTOR_BYTES / sizeof (float))
+#define SVE_NUM_DBLS (SVE_VECTOR_BYTES / sizeof (double))
+/* Predicate is stored as one bit per byte of VL so requires VL / 64 bytes. */
+#define SVE_NUM_PG_BYTES (SVE_VECTOR_BYTES / sizeof (uint64_t))
+
#define SV_NAME_F1(fun) _ZGVsMxv_##fun##f
#define SV_NAME_D1(fun) _ZGVsMxv_##fun
#define SV_NAME_F2(fun) _ZGVsMxvv_##fun##f
#define SV_NAME_D2(fun) _ZGVsMxvv_##fun
+static inline void
+svstr_p (uint8_t *dst, svbool_t p)
+{
+ /* Predicate STR does not currently have an intrinsic. */
+ __asm__("str %0, [%x1]\n" : : "Upa"(p), "r"(dst) : "memory");
+}
+
/* Double precision. */
static inline svint64_t
sv_s64 (int64_t x)
@@ -51,33 +69,35 @@ sv_f64 (double x)
static inline svfloat64_t
sv_call_f64 (double (*f) (double), svfloat64_t x, svfloat64_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ double tmp[SVE_NUM_DBLS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b64 (), tmp, svsel (cmp, x, y));
+
+ for (int i = 0; i < svcntd (); i++)
{
- double elem = svclastb_n_f64 (p, 0, x);
- elem = (*f) (elem);
- svfloat64_t y2 = svdup_n_f64 (elem);
- y = svsel_f64 (p, y2, y);
- p = svpnext_b64 (cmp, p);
+ if (pg_bits[i] & 1)
+ tmp[i] = f (tmp[i]);
}
- return y;
+ return svld1 (svptrue_b64 (), tmp);
}
static inline svfloat64_t
sv_call2_f64 (double (*f) (double, double), svfloat64_t x1, svfloat64_t x2,
svfloat64_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ double tmp1[SVE_NUM_DBLS], tmp2[SVE_NUM_DBLS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b64 (), tmp1, svsel (cmp, x1, y));
+ svst1 (cmp, tmp2, x2);
+
+ for (int i = 0; i < svcntd (); i++)
{
- double elem1 = svclastb_n_f64 (p, 0, x1);
- double elem2 = svclastb_n_f64 (p, 0, x2);
- double ret = (*f) (elem1, elem2);
- svfloat64_t y2 = svdup_n_f64 (ret);
- y = svsel_f64 (p, y2, y);
- p = svpnext_b64 (cmp, p);
+ if (pg_bits[i] & 1)
+ tmp1[i] = f (tmp1[i], tmp2[i]);
}
- return y;
+ return svld1 (svptrue_b64 (), tmp1);
}
static inline svuint64_t
@@ -109,33 +129,40 @@ sv_f32 (float x)
static inline svfloat32_t
sv_call_f32 (float (*f) (float), svfloat32_t x, svfloat32_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ float tmp[SVE_NUM_FLTS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b32 (), tmp, svsel (cmp, x, y));
+
+ for (int i = 0; i < svcntd (); i++)
{
- float elem = svclastb_n_f32 (p, 0, x);
- elem = f (elem);
- svfloat32_t y2 = svdup_n_f32 (elem);
- y = svsel_f32 (p, y2, y);
- p = svpnext_b32 (cmp, p);
+ uint8_t p = pg_bits[i];
+ if (p & 1)
+ tmp[i * 2] = f (tmp[i * 2]);
+ if (p & (1 << 4))
+ tmp[i * 2 + 1] = f (tmp[i * 2 + 1]);
}
- return y;
+ return svld1 (svptrue_b32 (), tmp);
}
static inline svfloat32_t
sv_call2_f32 (float (*f) (float, float), svfloat32_t x1, svfloat32_t x2,
svfloat32_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ float tmp1[SVE_NUM_FLTS], tmp2[SVE_NUM_FLTS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b32 (), tmp1, svsel (cmp, x1, y));
+ svst1 (cmp, tmp2, x2);
+
+ for (int i = 0; i < svcntd (); i++)
{
- float elem1 = svclastb_n_f32 (p, 0, x1);
- float elem2 = svclastb_n_f32 (p, 0, x2);
- float ret = f (elem1, elem2);
- svfloat32_t y2 = svdup_n_f32 (ret);
- y = svsel_f32 (p, y2, y);
- p = svpnext_b32 (cmp, p);
+ uint8_t p = pg_bits[i];
+ if (p & 1)
+ tmp1[i * 2] = f (tmp1[i * 2], tmp2[i * 2]);
+ if (p & (1 << 4))
+ tmp1[i * 2 + 1] = f (tmp1[i * 2 + 1], tmp2[i * 2 + 1]);
}
- return y;
+ return svld1 (svptrue_b32 (), tmp1);
}
-
#endif