WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / YOU CANNOT LOG IN OR REGISTER ACCOUNTS HERE / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FIND SOMETHING THAT MAY NOT BE SUITABLE FOR EVERYONE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 995241f

Browse files
Aelphy authored and xnnpack-bot committed
Added bf16 sum and sum_squared to arm neon.
PiperOrigin-RevId: 840924381
1 parent 9bab77f commit 995241f

File tree

3 files changed

+116
-11
lines changed

3 files changed

+116
-11
lines changed

ynnpack/kernels/reduce/arm_neon.cc

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,84 @@
1515
#include <type_traits>
1616

1717
#include "ynnpack/base/arithmetic.h"
18+
#include "ynnpack/base/simd/multi_vec.h"
1819
#include "ynnpack/kernels/reduce/generic.h"
1920
#include "ynnpack/kernels/reduce/min_max_accumulator.h"
2021
#include "ynnpack/kernels/reduce/reduce.h"
2122
#include "ynnpack/kernels/reduce/sum_accumulator.h"
2223

2324
namespace ynn {
2425

26+
namespace simd {
27+
28+
using bf16x8x8 = multi_vec<bf16x8, 8>;
29+
using f32x4x16 = multi_vec<f32x4, 16>;
30+
31+
// Adds the eight bf16x8 vectors in `b` into the sixteen f32x4 accumulators in
// `a` (Identity map, horizontal_factor == 1) and returns the updated
// accumulators. For each i in [0, 8): the low four 16-bit lanes of b.v[i] are
// added to a.v[2 * i] and the high four lanes to a.v[2 * i + 1].
static f32x4x16 reduce_add(
32+
f32x4x16 a, bf16x8x8 b, Identity /*map_fn*/,
33+
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
34+
YNN_UNROLL
35+
for (int i = 0; i < 8; ++i) {
36+
// Widen each 16-bit lane to 32 bits with a left shift of 16, then
// reinterpret as f32: a bfloat16 value is the upper 16 bits of the f32 with
// the same value, so this converts bf16 -> f32 without arithmetic.
float32x4_t lo =
37+
vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
38+
float32x4_t hi =
39+
vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
40+
41+
a.v[2 * i] += f32x4{lo};
42+
a.v[2 * i + 1] += f32x4{hi};
43+
}
44+
45+
return a;
46+
}
47+
48+
// Adds all eight bf16 lanes of `b` into the four f32 lanes of `a` (Identity
// map, horizontal_factor == 2) and returns the result. Each 32-bit lane of `b`
// holds two 16-bit values; both are converted to f32 and added to the
// corresponding lane of `a`.
static f32x4 reduce_add(
49+
f32x4 a, bf16x8 b, Identity /*map_fn*/,
50+
std::integral_constant<size_t, 2> /*horizontal_factor*/) {
51+
// View the eight 16-bit lanes as four 32-bit lanes.
uint32x4_t pairs = vreinterpretq_u32_u16(b.v);
52+
// Low 16 bits of each pair moved into the high half (low half becomes zero).
// NOTE(review): "even" matches the element index only for a little-endian
// lane layout — presumably what is assumed here; confirm.
float32x4_t even = vreinterpretq_f32_u32(vshlq_n_u32(pairs, 16));
53+
// High 16 bits of each pair kept in place, low half masked to zero.
float32x4_t odd =
54+
vreinterpretq_f32_u32(vandq_u32(pairs, vdupq_n_u32(0xFFFF0000)));
55+
56+
a += f32x4{odd};
57+
return a += f32x4{even};
58+
}
59+
60+
// Adds the squares of the eight bf16x8 vectors in `b` into the sixteen f32x4
// accumulators in `a` (Square map, horizontal_factor == 1) and returns the
// updated accumulators. For each i in [0, 8): the low four 16-bit lanes of
// b.v[i] feed a.v[2 * i] and the high four lanes feed a.v[2 * i + 1].
static f32x4x16 reduce_add(
61+
f32x4x16 a, bf16x8x8 b, Square /*map_fn*/,
62+
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
63+
YNN_UNROLL
64+
for (int i = 0; i < 8; ++i) {
65+
// bf16 -> f32: widen each 16-bit lane with a left shift of 16 and
// reinterpret the bits as f32.
float32x4_t lo =
66+
vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
67+
float32x4_t hi =
68+
vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
69+
// vmlaq_f32(acc, x, y) computes acc + x * y per lane, so this is acc += x^2.
a.v[2 * i].v = vmlaq_f32(a.v[2 * i].v, lo, lo);
70+
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
71+
}
72+
73+
return a;
74+
}
75+
76+
// Adds the squares of all eight bf16 lanes of `b` into the four f32 lanes of
// `a` (Square map, horizontal_factor == 2) and returns the result. Each 32-bit
// lane of `b` holds two 16-bit values; both are converted to f32, squared, and
// accumulated into the corresponding lane of `a`.
static f32x4 reduce_add(
77+
f32x4 a, bf16x8 b, Square /*map_fn*/,
78+
std::integral_constant<size_t, 2> /*horizontal_factor*/) {
79+
// View the eight 16-bit lanes as four 32-bit lanes.
uint32x4_t pairs = vreinterpretq_u32_u16(b.v);
80+
// Low 16 bits of each pair shifted into the high half, as f32.
float32x4_t even = vreinterpretq_f32_u32(vshlq_n_u32(pairs, 16));
81+
// High 16 bits of each pair kept in place (low half zeroed), as f32.
float32x4_t odd =
82+
vreinterpretq_f32_u32(vandq_u32(pairs, vdupq_n_u32(0xFFFF0000)));
83+
84+
// vmlaq_f32(acc, x, y) computes acc + x * y per lane.
a.v = vmlaq_f32(a.v, odd, odd);
85+
a.v = vmlaq_f32(a.v, even, even);
86+
return a;
87+
}
88+
89+
} // namespace simd
90+
91+
using simd::f32x4;
92+
using simd::f32x4x16;
2593
using simd::bf16x8;
94+
using simd::bf16x8x8;
2695
using simd::f16x8;
27-
using simd::f32x4;
2896
using simd::s16x8;
2997
using simd::s8x16;
3098
using simd::u8x16;
@@ -50,6 +118,41 @@ MIN_MAX_KERNEL(max_fp16_4x8_neon, dummy_t, f16x8_rvar, half, 8);
50118
MIN_MAX_KERNEL(max_uint8_4x16_neon, dummy_t, u8x16, uint8_t, 16);
51119
MIN_MAX_KERNEL(max_int8_4x16_neon, dummy_t, s8x16, int8_t, 16);
52120

121+
// Sum-reduction kernel entry point: reads bfloat16 input from `a` and writes
// float output to `c` by dispatching to one of two tiled_reduce instantiations.
// The unnamed size_t parameter is ignored; C_stride_m is always passed as 0.
void sum_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
122+
size_t a_stride_n, size_t a_stride_k3,
123+
size_t a_stride_k2, const void* a, size_t,
124+
void* c) {
125+
// Specialized path when k1 == 1 and the n stride is exactly one bfloat16:
// uses the bf16x8x8 -> f32x4x16 accumulator and does not pass k1/a_stride_n.
if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
126+
tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16>, bfloat16, float>(
127+
n, k3, k2, a_stride_k3, a_stride_k2,
128+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
129+
reinterpret_cast<float*>(c));
130+
} else {
131+
// General path: forwards k1 and a_stride_n to the x32 accumulator.
tiled_reduce<sum_accumulator_x32<f32x4, 8>, bfloat16, float>(
132+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
133+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
134+
reinterpret_cast<float*>(c));
135+
}
136+
}
137+
138+
// Sum-of-squares reduction kernel entry point: reads bfloat16 input from `a`
// and writes float output to `c`. Same dispatch as the plain sum kernel, but
// both accumulators are instantiated with the Square map function.
// The unnamed size_t parameter is ignored; C_stride_m is always passed as 0.
void sum_squared_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
139+
size_t a_stride_n, size_t a_stride_k3,
140+
size_t a_stride_k2, const void* a, size_t,
141+
void* c) {
142+
// Specialized path when k1 == 1 and the n stride is exactly one bfloat16:
// uses the bf16x8x8 -> f32x4x16 accumulator and does not pass k1/a_stride_n.
if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
143+
tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16, Square>, bfloat16,
144+
float>(
145+
n, k3, k2, a_stride_k3, a_stride_k2,
146+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
147+
reinterpret_cast<float*>(c));
148+
} else {
149+
// General path: forwards k1 and a_stride_n to the x32 accumulator.
tiled_reduce<sum_accumulator_x32<f32x4, 8, Square>, bfloat16, float>(
150+
n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
151+
reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
152+
reinterpret_cast<float*>(c));
153+
}
154+
}
155+
53156
void sum_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
54157
size_t a_stride_n, size_t a_stride_k3, size_t a_stride_k2,
55158
const void* a, size_t, void* c) {

ynnpack/kernels/reduce/sum.inc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
// clang-format off
22

3-
#ifdef YNN_ARCH_ARM_NEON
4-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
5-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
6-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
7-
#endif // YNN_ARCH_ARM_NEON
83
#ifdef YNN_ARCH_ARM_NEONBF16
94
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_bf16_fp32_neonbf16, bfloat16, float)
105
#endif // YNN_ARCH_ARM_NEONBF16
116
#ifdef YNN_ARCH_ARM_NEONFP16ARITH
127
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_fp16_fp32_neonfp16arith, half, float)
138
#endif // YNN_ARCH_ARM_NEONFP16ARITH
9+
#ifdef YNN_ARCH_ARM_NEON
10+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
11+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_bf16_fp32_neon, bfloat16, float)
12+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
13+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
14+
#endif // YNN_ARCH_ARM_NEON
1415

1516
#ifdef YNN_ARCH_X86_AVX512BF16
1617
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_bf16_fp32_avx512bf16, bfloat16, float)

ynnpack/kernels/reduce/sum_squared.inc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
// clang-format off
22

3-
#ifdef YNN_ARCH_ARM_NEON
4-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
5-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
6-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
7-
#endif // YNN_ARCH_ARM_NEON
83
#ifdef YNN_ARCH_ARM_NEONBF16
94
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_squared_bf16_fp32_neonbf16, bfloat16, float)
105
#endif // YNN_ARCH_ARM_NEONBF16
116
#ifdef YNN_ARCH_ARM_NEONFP16ARITH
127
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_squared_fp16_fp32_neonfp16arith, half, float)
138
#endif // YNN_ARCH_ARM_NEONFP16ARITH
9+
#ifdef YNN_ARCH_ARM_NEON
10+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
11+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_bf16_fp32_neon, bfloat16, float)
12+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
13+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
14+
#endif // YNN_ARCH_ARM_NEON
1415

1516
#ifdef YNN_ARCH_X86_AVX512BF16
1617
YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_squared_bf16_fp32_avx512bf16, bfloat16, float)

0 commit comments

Comments
 (0)