WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOG IN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FIND SOMETHING THAT MAY NOT BE SUITABLE FOR EVERYONE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit bd22b0d

Browse files
Aelphyxnnpack-bot
authored and committed
Added bf16 sum and sum_squared to arm neon.
PiperOrigin-RevId: 840865133
1 parent 9bab77f commit bd22b0d

File tree

3 files changed

+122
-11
lines changed

3 files changed

+122
-11
lines changed

ynnpack/kernels/reduce/arm_neon.cc

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,89 @@
1515
#include <type_traits>
1616

1717
#include "ynnpack/base/arithmetic.h"
18+
#include "ynnpack/base/simd/multi_vec.h"
1819
#include "ynnpack/kernels/reduce/generic.h"
1920
#include "ynnpack/kernels/reduce/min_max_accumulator.h"
2021
#include "ynnpack/kernels/reduce/reduce.h"
2122
#include "ynnpack/kernels/reduce/sum_accumulator.h"
2223

2324
namespace ynn {
2425

26+
namespace simd {
27+
28+
using bf16x8x8 = multi_vec<bf16x8, 8>;
29+
using f32x4x16 = multi_vec<f32x4, 16>;
30+
31+
// Element-wise accumulation of eight bf16x8 vectors into sixteen f32x4
// accumulators (Identity map, horizontal factor 1: one accumulator lane per
// input element).
//
// Each 16-bit input lane is widened by shifting it into the upper half of a
// 32-bit lane and reinterpreting the result as f32. The low four lanes of
// input vector j are added to accumulator 2*j, the high four lanes to
// accumulator 2*j + 1.
static f32x4x16 reduce_add(
    f32x4x16 a, bf16x8x8 b, Identity /*map_fn*/,
    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
  YNN_UNROLL
  for (int j = 0; j < 8; ++j) {
    const auto bits = b.v[j].v;
    const uint32x4_t lo_bits = vshll_n_u16(vget_low_u16(bits), 16);
    const uint32x4_t hi_bits = vshll_n_u16(vget_high_u16(bits), 16);

    a.v[2 * j] += f32x4{vreinterpretq_f32_u32(lo_bits)};
    a.v[2 * j + 1] += f32x4{vreinterpretq_f32_u32(hi_bits)};
  }

  return a;
}
47+
48+
// Accumulates one bf16x8 vector into a single f32x4 accumulator (Identity
// map, horizontal factor 2: adjacent input pairs collapse into one lane).
//
// The eight inputs are widened to two f32x4 vectors (shift each 16-bit lane
// into the upper half of a 32-bit lane, reinterpret as f32), then pairwise
// added so the value added to `a` is
//   {lo0 + lo1, lo2 + lo3, hi0 + hi1, hi2 + hi3}.
static f32x4 reduce_add(
    f32x4 a, bf16x8 b, Identity /*map_fn*/,
    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
  const float32x4_t lo =
      vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v), 16));
  const float32x4_t hi =
      vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v), 16));
#ifdef __aarch64__
  // AArch64 has a full-width pairwise add.
  const float32x4_t pairs = vpaddq_f32(lo, hi);
#else
  // 32-bit Arm: pairwise-add each 64-bit half separately and recombine.
  const float32x4_t pairs =
      vcombine_f32(vpadd_f32(vget_low_f32(lo), vget_high_f32(lo)),
                   vpadd_f32(vget_low_f32(hi), vget_high_f32(hi)));
#endif
  return a += f32x4{pairs};
}
61+
62+
// Element-wise sum-of-squares accumulation of eight bf16x8 vectors into
// sixteen f32x4 accumulators (Square map, horizontal factor 1).
//
// Each 16-bit input lane is widened by shifting it into the upper half of a
// 32-bit lane and reinterpreting the result as f32. The widened value is then
// multiply-accumulated with itself: low four lanes of input vector j go to
// accumulator 2*j, high four lanes to accumulator 2*j + 1.
static f32x4x16 reduce_add(
    f32x4x16 a, bf16x8x8 b, Square /*map_fn*/,
    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
  YNN_UNROLL
  for (int j = 0; j < 8; ++j) {
    const auto bits = b.v[j].v;
    const float32x4_t lo_f32 =
        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bits), 16));
    const float32x4_t hi_f32 =
        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(bits), 16));

    auto& acc_lo = a.v[2 * j].v;
    auto& acc_hi = a.v[2 * j + 1].v;
    acc_lo = vmlaq_f32(acc_lo, lo_f32, lo_f32);
    acc_hi = vmlaq_f32(acc_hi, hi_f32, hi_f32);
  }

  return a;
}
77+
78+
// Accumulates the squares of one bf16x8 vector into a single f32x4
// accumulator (Square map, horizontal factor 2: adjacent input pairs collapse
// into one lane).
//
// The eight inputs are widened to two f32x4 vectors, squared lane-wise, and
// pairwise added so the value added to `a` is
//   {lo0^2 + lo1^2, lo2^2 + lo3^2, hi0^2 + hi1^2, hi2^2 + hi3^2}.
static f32x4 reduce_add(
    f32x4 a, bf16x8 b, Square /*map_fn*/,
    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
  const float32x4_t lo =
      vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v), 16));
  const float32x4_t hi =
      vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v), 16));
  const float32x4_t lo_squared = vmulq_f32(lo, lo);
  const float32x4_t hi_squared = vmulq_f32(hi, hi);
#ifdef __aarch64__
  // AArch64 has a full-width pairwise add.
  const float32x4_t pairs = vpaddq_f32(lo_squared, hi_squared);
#else
  // 32-bit Arm: pairwise-add each 64-bit half separately and recombine.
  const float32x4_t pairs = vcombine_f32(
      vpadd_f32(vget_low_f32(lo_squared), vget_high_f32(lo_squared)),
      vpadd_f32(vget_low_f32(hi_squared), vget_high_f32(hi_squared)));
#endif
  return a += f32x4{pairs};
}
93+
94+
} // namespace simd
95+
96+
using simd::f32x4;
97+
using simd::f32x4x16;
2598
using simd::bf16x8;
99+
using simd::bf16x8x8;
26100
using simd::f16x8;
27-
using simd::f32x4;
28101
using simd::s16x8;
29102
using simd::s8x16;
30103
using simd::u8x16;
@@ -50,6 +123,42 @@ MIN_MAX_KERNEL(max_fp16_4x8_neon, dummy_t, f16x8_rvar, half, 8);
50123
MIN_MAX_KERNEL(max_uint8_4x16_neon, dummy_t, u8x16, uint8_t, 16);
51124
MIN_MAX_KERNEL(max_int8_4x16_neon, dummy_t, s8x16, int8_t, 16);
52125

126+
void sum_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
127+
size_t a_stride_n, size_t a_stride_k3,
128+
size_t a_stride_k2, const void* a, size_t,
129+
void* c) {
130+
if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
131+
tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16, Square>, bfloat16,
132+
float>(
133+
n, k3, k2, a_stride_k3, a_stride_k2,
134+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
135+
reinterpret_cast<float*>(c));
136+
} else {
137+
tiled_reduce<sum_accumulator_x32<f32x4, 8>, bfloat16, float>(
138+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
139+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
140+
reinterpret_cast<float*>(c));
141+
}
142+
}
143+
144+
void sum_squared_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
145+
size_t a_stride_n, size_t a_stride_k3,
146+
size_t a_stride_k2, const void* a, size_t,
147+
void* c) {
148+
if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
149+
tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16, Square>, bfloat16,
150+
float>(
151+
n, k3, k2, a_stride_k3, a_stride_k2,
152+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
153+
reinterpret_cast<float*>(c));
154+
} else {
155+
tiled_reduce<sum_accumulator_x32<f32x4, 8, Square>, bfloat16, float>(
156+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
157+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
158+
reinterpret_cast<float*>(c));
159+
}
160+
}
161+
53162
void sum_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
54163
size_t a_stride_n, size_t a_stride_k3, size_t a_stride_k2,
55164
const void* a, size_t, void* c) {

ynnpack/kernels/reduce/sum.inc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
// clang-format off
22

3-
#ifdef YNN_ARCH_ARM_NEON
4-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
5-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
6-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
7-
#endif // YNN_ARCH_ARM_NEON
83
#ifdef YNN_ARCH_ARM_NEONBF16
94
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_bf16_fp32_neonbf16, bfloat16, float)
105
#endif // YNN_ARCH_ARM_NEONBF16
116
#ifdef YNN_ARCH_ARM_NEONFP16ARITH
127
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_fp16_fp32_neonfp16arith, half, float)
138
#endif // YNN_ARCH_ARM_NEONFP16ARITH
9+
#ifdef YNN_ARCH_ARM_NEON
10+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
11+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_bf16_fp32_neon, bfloat16, float)
12+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
13+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
14+
#endif // YNN_ARCH_ARM_NEON
1415

1516
#ifdef YNN_ARCH_X86_AVX512BF16
1617
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_bf16_fp32_avx512bf16, bfloat16, float)

ynnpack/kernels/reduce/sum_squared.inc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
// clang-format off
22

3-
#ifdef YNN_ARCH_ARM_NEON
4-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
5-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
6-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
7-
#endif // YNN_ARCH_ARM_NEON
83
#ifdef YNN_ARCH_ARM_NEONBF16
94
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_squared_bf16_fp32_neonbf16, bfloat16, float)
105
#endif // YNN_ARCH_ARM_NEONBF16
116
#ifdef YNN_ARCH_ARM_NEONFP16ARITH
127
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_squared_fp16_fp32_neonfp16arith, half, float)
138
#endif // YNN_ARCH_ARM_NEONFP16ARITH
9+
#ifdef YNN_ARCH_ARM_NEON
10+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
11+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_bf16_fp32_neon, bfloat16, float)
12+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
13+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
14+
#endif // YNN_ARCH_ARM_NEON
1415

1516
#ifdef YNN_ARCH_X86_AVX512BF16
1617
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_squared_bf16_fp32_avx512bf16, bfloat16, float)

0 commit comments

Comments
 (0)