
Commit 70a97fc

Aelphy authored and xnnpack-bot committed

Added bf16 sum and sum_squared to avx2.

Fixed avx512bf16 version to be the same as for neonbf16.

PiperOrigin-RevId: 840891405

1 parent 9bab77f commit 70a97fc
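
Both the NEON and AVX2 paths in this commit rest on the same observation: a bfloat16 is exactly the upper 16 bits of an IEEE float32, so widening to f32 is just a 16-bit left shift into a 32-bit lane. The sketch below is a hypothetical scalar reference for that conversion (bf16_to_f32 and sum_bf16 are illustrative names, not ynnpack API); it mirrors what vshll_n_u16(..., 16) does per lane on NEON and what _mm256_slli_epi32(..., 16) does on AVX2.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Widen bf16 -> f32 by placing the 16 payload bits into the upper half
// of a 32-bit word; the low 16 mantissa bits of the f32 are zero.
static float bf16_to_f32(std::uint16_t bits) {
  std::uint32_t widened = static_cast<std::uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

// Scalar analogue of the sum_bf16_fp32_* kernels: accumulate in f32.
static float sum_bf16(const std::uint16_t* a, std::size_t n) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < n; ++i) acc += bf16_to_f32(a[i]);
  return acc;
}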

File tree

6 files changed: +351, -11 lines


ynnpack/kernels/reduce/BUILD

Lines changed: 1 addition & 0 deletions

@@ -55,6 +55,7 @@ ynn_cc_library(
         "x86_f16c": ["x86_f16c.cc"],
         "x86_avx512fp16": ["x86_avx512fp16.cc"],
         "x86_avx2": ["x86_avx2.cc"],
+        "x86_avx2_fma3": ["x86_avx2_fma3.cc"],
     },
     visibility = ["//ynnpack:__subpackages__"],
     deps = [

ynnpack/kernels/reduce/arm_neon.cc

Lines changed: 110 additions & 1 deletion

@@ -15,16 +15,89 @@
 #include <type_traits>
 
 #include "ynnpack/base/arithmetic.h"
+#include "ynnpack/base/simd/multi_vec.h"
 #include "ynnpack/kernels/reduce/generic.h"
 #include "ynnpack/kernels/reduce/min_max_accumulator.h"
 #include "ynnpack/kernels/reduce/reduce.h"
 #include "ynnpack/kernels/reduce/sum_accumulator.h"
 
 namespace ynn {
 
+namespace simd {
+
+using bf16x8x8 = multi_vec<bf16x8, 8>;
+using f32x4x16 = multi_vec<f32x4, 16>;
+
+static f32x4x16 reduce_add(
+    f32x4x16 a, bf16x8x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    float32x4_t lo =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
+    float32x4_t hi =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
+
+    a.v[2 * i] += f32x4{lo};
+    a.v[2 * i + 1] += f32x4{hi};
+  }
+
+  return a;
+}
+
+static f32x4 reduce_add(
+    f32x4 a, bf16x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  float32x4_t lo = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v), 16));
+  float32x4_t hi = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v), 16));
+#ifndef __aarch64__
+  float32x2_t pair_lo = vpadd_f32(vget_low_f32(lo), vget_high_f32(lo));
+  float32x2_t pair_hi = vpadd_f32(vget_low_f32(hi), vget_high_f32(hi));
+  return a += f32x4{vcombine_f32(pair_lo, pair_hi)};
+#else
+  return a += f32x4{vpaddq_f32(lo, hi)};
+#endif
+}
+
+static f32x4x16 reduce_add(
+    f32x4x16 a, bf16x8x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    float32x4_t lo =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
+    float32x4_t hi =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
+    a.v[2 * i].v = vmlaq_f32(a.v[2 * i].v, lo, lo);
+    a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
+  }
+
+  return a;
+}
+
+static f32x4 reduce_add(
+    f32x4 a, bf16x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  float32x4_t lo = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v), 16));
+  float32x4_t hi = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v), 16));
+#ifndef __aarch64__
+  float32x4_t sq_lo = vmulq_f32(lo, lo);
+  float32x4_t sq_hi = vmulq_f32(hi, hi);
+  float32x2_t pair_lo = vpadd_f32(vget_low_f32(sq_lo), vget_high_f32(sq_lo));
+  float32x2_t pair_hi = vpadd_f32(vget_low_f32(sq_hi), vget_high_f32(sq_hi));
+  return a += f32x4{vcombine_f32(pair_lo, pair_hi)};
+#else
+  return a += f32x4{vpaddq_f32(vmulq_f32(lo, lo), vmulq_f32(hi, hi))};
+#endif
+}
+
+}  // namespace simd
+
+using simd::f32x4;
+using simd::f32x4x16;
 using simd::bf16x8;
+using simd::bf16x8x8;
 using simd::f16x8;
-using simd::f32x4;
 using simd::s16x8;
 using simd::s8x16;
 using simd::u8x16;
@@ -50,6 +123,42 @@ MIN_MAX_KERNEL(max_fp16_4x8_neon, dummy_t, f16x8_rvar, half, 8);
 MIN_MAX_KERNEL(max_uint8_4x16_neon, dummy_t, u8x16, uint8_t, 16);
 MIN_MAX_KERNEL(max_int8_4x16_neon, dummy_t, s8x16, int8_t, 16);
 
+void sum_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
+                        size_t a_stride_n, size_t a_stride_k3,
+                        size_t a_stride_k2, const void* a, size_t,
+                        void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x4, 8>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
+void sum_squared_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
+                                size_t a_stride_n, size_t a_stride_k3,
+                                size_t a_stride_k2, const void* a, size_t,
+                                void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16, Square>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x4, 8, Square>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
 void sum_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
                    size_t a_stride_n, size_t a_stride_k3, size_t a_stride_k2,
                    const void* a, size_t, void* c) {
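
The horizontal_factor == 2 overloads above reduce eight bf16 inputs to four f32 partial sums in one step: on AArch64, vpaddq_f32(lo, hi) adds adjacent lanes across its concatenated operands. A hypothetical scalar model of what that tail path computes (reduce_add_hf2 is an illustrative name, not ynnpack API):

#include <cstdint>
#include <cstring>

// Scalar model of the horizontal_factor == 2 reduce_add: widen eight
// bf16 values to f32, then add adjacent pairs into four accumulators.
static void reduce_add_hf2(float acc[4], const std::uint16_t bf16_in[8]) {
  float widened[8];
  for (int i = 0; i < 8; ++i) {
    std::uint32_t bits = static_cast<std::uint32_t>(bf16_in[i]) << 16;
    std::memcpy(&widened[i], &bits, sizeof(float));
  }
  // vpaddq_f32(lo, hi): result lane j equals widened[2*j] + widened[2*j+1].
  for (int j = 0; j < 4; ++j) {
    acc[j] += widened[2 * j] + widened[2 * j + 1];
  }
}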

ynnpack/kernels/reduce/sum.inc

Lines changed: 10 additions & 5 deletions

@@ -1,16 +1,17 @@
 // clang-format off
 
-#ifdef YNN_ARCH_ARM_NEON
-YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
-YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
-YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
-#endif // YNN_ARCH_ARM_NEON
 #ifdef YNN_ARCH_ARM_NEONBF16
 YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_bf16_fp32_neonbf16, bfloat16, float)
 #endif // YNN_ARCH_ARM_NEONBF16
 #ifdef YNN_ARCH_ARM_NEONFP16ARITH
 YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_fp16_fp32_neonfp16arith, half, float)
 #endif // YNN_ARCH_ARM_NEONFP16ARITH
+#ifdef YNN_ARCH_ARM_NEON
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_bf16_fp32_neon, bfloat16, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
+#endif // YNN_ARCH_ARM_NEON
 
 #ifdef YNN_ARCH_X86_AVX512BF16
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_bf16_fp32_avx512bf16, bfloat16, float)
@@ -25,8 +26,12 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_int8_int32_avx512bw, int8_t, in
 #ifdef YNN_ARCH_X86_AVX512F
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512f, sum_fp32_avx512f, float, float)
 #endif // YNN_ARCH_X86_AVX512F
+#ifdef YNN_ARCH_X86_AVX2_FMA3
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2 | arch_flag::fma3, sum_bf16_fp32_avx2_fma3, bfloat16, float)
+#endif // YNN_ARCH_X86_AVX2_FMA3
 #ifdef YNN_ARCH_X86_AVX2
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_fp32_avx2, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_bf16_fp32_avx2, bfloat16, float)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_uint8_int32_avx2, uint8_t, int32_t)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_int8_int32_avx2, int8_t, int32_t)
 #endif // YNN_ARCH_X86_AVX2
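
The reordering in this file (and the identical one in sum_squared.inc below) looks deliberate: assuming the kernel table is scanned top to bottom and the first entry whose arch flags are supported wins (an assumption about YNN_UNARY_REDUCE_KERNEL's selection semantics that this diff does not confirm), the specialized neonbf16 and avx2_fma3 entries must precede the generic neon and avx2 fallbacks. A hypothetical sketch of such first-match selection (all names illustrative):

#include <cstddef>
#include <cstdint>

using kernel_fn = void (*)();

struct kernel_entry {
  std::uint32_t required_flags;  // e.g. avx2 | fma3
  kernel_fn fn;
};

// Return the first kernel whose requirements the CPU satisfies, so more
// specialized entries must be listed before generic fallbacks.
static kernel_fn pick_kernel(const kernel_entry* table, std::size_t count,
                             std::uint32_t cpu_flags) {
  for (std::size_t i = 0; i < count; ++i) {
    if ((table[i].required_flags & ~cpu_flags) == 0) return table[i].fn;
  }
  return nullptr;  // no usable kernel for this CPU
}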

ynnpack/kernels/reduce/sum_squared.inc

Lines changed: 10 additions & 5 deletions

@@ -1,16 +1,17 @@
 // clang-format off
 
-#ifdef YNN_ARCH_ARM_NEON
-YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
-YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
-YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
-#endif // YNN_ARCH_ARM_NEON
 #ifdef YNN_ARCH_ARM_NEONBF16
 YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_squared_bf16_fp32_neonbf16, bfloat16, float)
 #endif // YNN_ARCH_ARM_NEONBF16
 #ifdef YNN_ARCH_ARM_NEONFP16ARITH
 YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_squared_fp16_fp32_neonfp16arith, half, float)
 #endif // YNN_ARCH_ARM_NEONFP16ARITH
+#ifdef YNN_ARCH_ARM_NEON
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_bf16_fp32_neon, bfloat16, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
+#endif // YNN_ARCH_ARM_NEON
 
 #ifdef YNN_ARCH_X86_AVX512BF16
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_squared_bf16_fp32_avx512bf16, bfloat16, float)
@@ -25,8 +26,12 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_squared_int8_int32_avx512bw, in
 #ifdef YNN_ARCH_X86_AVX512F
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512f, sum_squared_fp32_avx512f, float, float)
 #endif // YNN_ARCH_X86_AVX512F
+#ifdef YNN_ARCH_X86_AVX2_FMA3
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2 | arch_flag::fma3, sum_squared_bf16_fp32_avx2_fma3, bfloat16, float)
+#endif // YNN_ARCH_X86_AVX2_FMA3
 #ifdef YNN_ARCH_X86_AVX2
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_fp32_avx2, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_bf16_fp32_avx2, bfloat16, float)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_int8_int32_avx2, int8_t, int32_t)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_uint8_int32_avx2, uint8_t, int32_t)
 #endif // YNN_ARCH_X86_AVX2

ynnpack/kernels/reduce/x86_avx2.cc

Lines changed: 94 additions & 0 deletions

@@ -13,6 +13,7 @@
 #include <cstring>
 #include <type_traits>
 
+#include "ynnpack/base/base.h"
 #include "ynnpack/base/bfloat16.h"
 #include "ynnpack/base/half.h"
 #include "ynnpack/base/simd/multi_vec.h"
@@ -26,8 +27,10 @@ namespace ynn {
 namespace simd {
 
 using f32x8x8 = simd::multi_vec<f32x8, 8>;
+using f32x8x16 = simd::multi_vec<f32x8, 16>;
 using s32x8x2 = multi_vec<s32x8, 2>;
 using s32x8x4 = multi_vec<s32x8, 4>;
+using bf16x16x8 = multi_vec<bf16x16, 8>;
 
 static s32x8x4& operator+=(s32x8x4& a, s8x32 b) {
   s8x16 b_lo = extract<0>(b, s8x16{});
@@ -115,14 +118,71 @@ static s32x8 reduce_add(
   return a += s32x8(_mm256_madd_epi16(b_16, b_16));
 }
 
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo_u32 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi_u32 =
+        _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+
+    a.v[2 * i] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(lo_u32, 16))};
+    a.v[2 * i + 1] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(hi_u32, 16))};
+  }
+
+  return a;
+}
+
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+
+  return a += f32x8{_mm256_add_ps(evens, odds)};
+}
+
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo_u32 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi_u32 =
+        _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+    __m256 lo_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(lo_u32, 16));
+    __m256 hi_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(hi_u32, 16));
+
+    a.v[2 * i] += f32x8{_mm256_mul_ps(lo_f32, lo_f32)};
+    a.v[2 * i + 1] += f32x8{_mm256_mul_ps(hi_f32, hi_f32)};
+  }
+
+  return a;
+}
+
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+  __m256 sq_evens = _mm256_mul_ps(evens, evens);
+  __m256 sq_odds = _mm256_mul_ps(odds, odds);
+  return a += f32x8{_mm256_add_ps(sq_evens, sq_odds)};
+}
+
 }  // namespace simd
 
 using simd::s32x8;
 using simd::s32x8x2;
 using simd::s32x8x4;
 using simd::f32x8;
 using simd::f32x8x8;
+using simd::f32x8x16;
 using simd::bf16x16;
+using simd::bf16x16x8;
 using simd::f16x16;
 using simd::s16x16;
 using simd::s8x16;
@@ -233,6 +293,40 @@ void sum_squared_uint8_int32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
   }
 }
 
+void sum_bf16_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
+                        size_t a_stride_n, size_t a_stride_k3,
+                        size_t a_stride_k2, const void* a, size_t, void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16>, bfloat16, float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
+void sum_squared_bf16_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
+                                size_t a_stride_n, size_t a_stride_k3,
+                                size_t a_stride_k2, const void* a, size_t,
+                                void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16, Square>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16, Square>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
 void sum_squared_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
                            size_t a_stride_n, size_t a_stride_k3,
                            size_t a_stride_k2, const void* a, size_t, void* c) {