
Commit bbc8781

Aelphy authored and xnnpack-bot committed

Added bf16 sum and sum_squared to sse2.

PiperOrigin-RevId: 841168337

1 parent a342cf6 commit bbc8781
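The SSE2 kernels named in the title are among the files not shown below (the page renders seven of the eleven changed files), but every bf16 path in this commit leans on the same fact: bfloat16 is the top 16 bits of an IEEE-754 fp32, so widening to fp32 is just a 16-bit shift into the high half of a 32-bit lane. A minimal standalone sketch of that widening and reduction with SSE2 intrinsics (the helper name and eight-element contract are illustrative, not code from this commit):

#include <immintrin.h>
#include <cstdint>

// Sum eight bf16 values (raw uint16_t bits) into a float.  bf16 is the
// high half of an fp32, so interleaving each value above a zero word
// widens it to fp32 exactly, with no table or conversion instruction.
float sum_bf16_sse2_example(const uint16_t* src /* 8 elements */) {
  __m128i raw = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src));
  __m128i zero = _mm_setzero_si128();
  __m128 lo = _mm_castsi128_ps(_mm_unpacklo_epi16(zero, raw));
  __m128 hi = _mm_castsi128_ps(_mm_unpackhi_epi16(zero, raw));
  __m128 acc = _mm_add_ps(lo, hi);
  // Horizontal add of the four partial sums.
  acc = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
  acc = _mm_add_ss(acc, _mm_shuffle_ps(acc, acc, 1));
  return _mm_cvtss_f32(acc);
}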

File tree

11 files changed: +600 −24


ynnpack/BUILD

Lines changed: 8 additions & 0 deletions

@@ -213,6 +213,14 @@ define_build_option(
     default_all = [":ynn_enable_x86_avx512"],
 )
 
+define_build_option(
+    name = "ynn_enable_x86_avx512bw_fma3",
+    default_all = [
+        ":ynn_enable_x86_avx512bw",
+        ":ynn_enable_x86_fma3",
+    ],
+)
+
 define_build_option(
     name = "ynn_enable_x86_avx512bf16",
     default_all = [

ynnpack/build_defs.bzl

Lines changed: 4 additions & 0 deletions

@@ -174,6 +174,10 @@ _YNN_PARAMS_FOR_ARCH = {
         "cond": "//ynnpack:ynn_enable_x86_avx512bw",
         "copts": ["-mavx512bw"],
     },
+    "x86_avx512bw_fma3": {
+        "cond": "//ynnpack:ynn_enable_x86_avx512bw_fma3",
+        "copts": ["-mavx512bw", "-mfma"],
+    },
     "x86_avx512bf16": {
         "cond": "//ynnpack:ynn_enable_x86_avx512bf16",
         "copts": ["-mavx512bf16", "-mavx512dq"],

ynnpack/kernels/reduce/BUILD

Lines changed: 2 additions & 0 deletions

@@ -51,10 +51,12 @@ ynn_cc_library(
         "x86_sse41": ["x86_sse41.cc"],
         "x86_avx512bf16": ["x86_avx512bf16.cc"],
         "x86_avx512bw": ["x86_avx512bw.cc"],
+        "x86_avx512bw_fma3": ["x86_avx512bw_fma3.cc"],
         "x86_avx512f": ["x86_avx512f.cc"],
         "x86_f16c": ["x86_f16c.cc"],
         "x86_avx512fp16": ["x86_avx512fp16.cc"],
         "x86_avx2": ["x86_avx2.cc"],
+        "x86_avx2_fma3": ["x86_avx2_fma3.cc"],
     },
     visibility = ["//ynnpack:__subpackages__"],
     deps = [

ynnpack/kernels/reduce/sum.inc

Lines changed: 9 additions & 0 deletions

@@ -21,15 +21,23 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_bf16_fp32_avx512bf16, bfloat1
 #ifdef YNN_ARCH_X86_AVX512FP16
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512fp16, sum_fp16_fp32_avx512fp16, half, float)
 #endif  // YNN_ARCH_X86_AVX512FP16
+#ifdef YNN_ARCH_X86_AVX512BW_FMA3
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw | arch_flag::fma3, sum_bf16_fp32_avx512bw_fma3, bfloat16, float)
+#endif  // YNN_ARCH_X86_AVX512BW_FMA3
 #ifdef YNN_ARCH_X86_AVX512BW
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_bf16_fp32_avx512bw, bfloat16, float)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_uint8_int32_avx512bw, uint8_t, int32_t)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_int8_int32_avx512bw, int8_t, int32_t)
 #endif  // YNN_ARCH_X86_AVX512BW
 #ifdef YNN_ARCH_X86_AVX512F
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512f, sum_fp32_avx512f, float, float)
 #endif  // YNN_ARCH_X86_AVX512F
+#ifdef YNN_ARCH_X86_AVX2_FMA3
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2 | arch_flag::fma3, sum_bf16_fp32_avx2_fma3, bfloat16, float)
+#endif  // YNN_ARCH_X86_AVX2_FMA3
 #ifdef YNN_ARCH_X86_AVX2
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_fp32_avx2, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_bf16_fp32_avx2, bfloat16, float)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_uint8_int32_avx2, uint8_t, int32_t)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_int8_int32_avx2, int8_t, int32_t)
 #endif  // YNN_ARCH_X86_AVX2
@@ -42,6 +50,7 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::sse41, sum_int8_int32_sse41, int8_t, int32_t)
 #endif  // YNN_ARCH_X86_SSE41
 #ifdef YNN_ARCH_X86_SSE2
 YNN_UNARY_REDUCE_KERNEL(arch_flag::sse2, sum_fp32_sse2, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::sse2, sum_bf16_fp32_sse2, bfloat16, float)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::sse2, sum_uint8_int32_sse2, uint8_t, int32_t)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::sse2, sum_int8_int32_sse2, int8_t, int32_t)
 #endif  // YNN_ARCH_X86_SSE2
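Both .inc files register each kernel together with the arch_flag bits it needs, and the more specialized variants (avx512bw | fma3, avx2 | fma3) are listed ahead of their plain counterparts. A dispatcher can then take the first entry whose required flags the running CPU satisfies; a generic sketch of that selection rule, with hypothetical types not taken from ynnpack:

#include <cstdint>
#include <vector>

using arch_flags = uint32_t;  // hypothetical bitmask, e.g. avx2 = 1u << 3

struct kernel_entry {
  arch_flags required;                    // flags the kernel needs
  void (*fn)(const void* in, void* out);  // hypothetical signature
};

// Return the first kernel whose requirements are a subset of the CPU's
// flags; entries are assumed ordered from most to least specialized.
const kernel_entry* select_kernel(const std::vector<kernel_entry>& table,
                                  arch_flags cpu) {
  for (const kernel_entry& e : table) {
    if ((e.required & ~cpu) == 0) return &e;
  }
  return nullptr;
}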

ynnpack/kernels/reduce/sum_squared.inc

Lines changed: 9 additions & 0 deletions

@@ -21,15 +21,23 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bf16, sum_squared_bf16_fp32_avx512bf16,
 #ifdef YNN_ARCH_X86_AVX512FP16
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512fp16, sum_squared_fp16_fp32_avx512fp16, half, float)
 #endif  // YNN_ARCH_X86_AVX512FP16
+#ifdef YNN_ARCH_X86_AVX512BW_FMA3
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw | arch_flag::fma3, sum_squared_bf16_fp32_avx512bw_fma3, bfloat16, float)
+#endif  // YNN_ARCH_X86_AVX512BW_FMA3
 #ifdef YNN_ARCH_X86_AVX512BW
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_squared_bf16_fp32_avx512bw, bfloat16, float)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_squared_uint8_int32_avx512bw, uint8_t, int32_t)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512bw, sum_squared_int8_int32_avx512bw, int8_t, int32_t)
 #endif  // YNN_ARCH_X86_AVX512BW
 #ifdef YNN_ARCH_X86_AVX512F
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx512f, sum_squared_fp32_avx512f, float, float)
 #endif  // YNN_ARCH_X86_AVX512F
+#ifdef YNN_ARCH_X86_AVX2_FMA3
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2 | arch_flag::fma3, sum_squared_bf16_fp32_avx2_fma3, bfloat16, float)
+#endif  // YNN_ARCH_X86_AVX2_FMA3
 #ifdef YNN_ARCH_X86_AVX2
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_fp32_avx2, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_bf16_fp32_avx2, bfloat16, float)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_int8_int32_avx2, int8_t, int32_t)
 YNN_UNARY_REDUCE_KERNEL(arch_flag::avx2, sum_squared_uint8_int32_avx2, uint8_t, int32_t)
 #endif  // YNN_ARCH_X86_AVX2
@@ -42,6 +50,7 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::sse41, sum_squared_uint8_int32_sse41, uint8_t
 #endif  // YNN_ARCH_X86_SSE41
 #ifdef YNN_ARCH_X86_SSE2
 YNN_UNARY_REDUCE_KERNEL(arch_flag::sse2, sum_squared_fp32_sse2, float, float)
+YNN_UNARY_REDUCE_KERNEL(arch_flag::sse2, sum_squared_bf16_fp32_sse2, bfloat16, float)
 #endif  // YNN_ARCH_X86_SSE2
 
 YNN_UNARY_REDUCE_KERNEL(arch_flag::none, sum_squared_fp32, float, float)

ynnpack/kernels/reduce/x86_avx2.cc

Lines changed: 95 additions & 0 deletions

@@ -13,6 +13,7 @@
 #include <cstring>
 #include <type_traits>
 
+#include "ynnpack/base/base.h"
 #include "ynnpack/base/bfloat16.h"
 #include "ynnpack/base/half.h"
 #include "ynnpack/base/simd/multi_vec.h"
@@ -26,8 +27,10 @@ namespace ynn {
 namespace simd {
 
 using f32x8x8 = simd::multi_vec<f32x8, 8>;
+using f32x8x16 = simd::multi_vec<f32x8, 16>;
 using s32x8x2 = multi_vec<s32x8, 2>;
 using s32x8x4 = multi_vec<s32x8, 4>;
+using bf16x16x8 = multi_vec<bf16x16, 8>;
 
 static s32x8x4& operator+=(s32x8x4& a, s8x32 b) {
   s8x16 b_lo = extract<0>(b, s8x16{});
@@ -115,14 +118,72 @@ static s32x8 reduce_add(
   return a += s32x8(_mm256_madd_epi16(b_16, b_16));
 }
 
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+
+    a.v[2 * i + 0] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(lo, 16))};
+    a.v[2 * i + 1] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(hi, 16))};
+  }
+
+  return a;
+}
+
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+
+  a += f32x8{odds};
+  a += f32x8{evens};
+  return a;
+}
+
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo_u32 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi_u32 =
+        _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+    __m256 lo_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(lo_u32, 16));
+    __m256 hi_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(hi_u32, 16));
+
+    a.v[2 * i + 0] += f32x8{_mm256_mul_ps(lo_f32, lo_f32)};
+    a.v[2 * i + 1] += f32x8{_mm256_mul_ps(hi_f32, hi_f32)};
+  }
+
+  return a;
+}
+
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+  a += f32x8{_mm256_mul_ps(odds, odds)};
+  a += f32x8{_mm256_mul_ps(evens, evens)};
+  return a;
+}
+
 }  // namespace simd
 
 using simd::s32x8;
 using simd::s32x8x2;
 using simd::s32x8x4;
 using simd::f32x8;
 using simd::f32x8x8;
+using simd::f32x8x16;
 using simd::bf16x16;
+using simd::bf16x16x8;
 using simd::f16x16;
 using simd::s16x16;
 using simd::s8x16;
@@ -233,6 +294,40 @@ void sum_squared_uint8_int32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
   }
 }
 
+void sum_bf16_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
+                        size_t a_stride_n, size_t a_stride_k3,
+                        size_t a_stride_k2, const void* a, size_t, void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16>, bfloat16, float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
+void sum_squared_bf16_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
+                                size_t a_stride_n, size_t a_stride_k3,
+                                size_t a_stride_k2, const void* a, size_t,
+                                void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16, Square>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16, Square>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
 void sum_squared_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
                            size_t a_stride_n, size_t a_stride_k3,
                            size_t a_stride_k2, const void* a, size_t, void* c) {
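The horizontal_factor-2 overloads above avoid any widening shuffle: in a 32-bit lane holding two packed bf16 values, the odd (upper) element already occupies the high 16 bits, so masking with 0xFFFF0000 yields its fp32 value directly, while the even (lower) element is shifted up by 16. A scalar sketch of the same decomposition (plain C++, helper name hypothetical):

#include <cstdint>
#include <cstring>

// Split one 32-bit lane holding two packed bf16 values into the two
// fp32 values they represent, mirroring the mask/shift pair above.
void widen_bf16_pair(uint32_t lane, float* even, float* odd) {
  uint32_t even_bits = lane << 16;         // low bf16 -> fp32 bit pattern
  uint32_t odd_bits = lane & 0xFFFF0000u;  // high bf16 already in place
  std::memcpy(even, &even_bits, sizeof(float));
  std::memcpy(odd, &odd_bits, sizeof(float));
}

Accumulating the odds and evens into the same register is fine here because the sum is order-insensitive.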
ynnpack/kernels/reduce/x86_avx2_fma3.cc

Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
+// Copyright 2025 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "ynnpack/base/simd/x86_avx2.h"
+
+#include <immintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstring>
+#include <type_traits>
+
+#include "ynnpack/base/base.h"
+#include "ynnpack/base/bfloat16.h"
+#include "ynnpack/base/simd/multi_vec.h"
+#include "ynnpack/kernels/reduce/generic.h"
+#include "ynnpack/kernels/reduce/sum_accumulator.h"
+
+namespace ynn {
+
+namespace simd {
+
+using f32x8x16 = simd::multi_vec<f32x8, 16>;
+using bf16x16x8 = multi_vec<bf16x16, 8>;
+
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+
+    a.v[2 * i + 0] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(lo, 16))};
+    a.v[2 * i + 1] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(hi, 16))};
+  }
+
+  return a;
+}
+
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+
+  a += f32x8{odds};
+  a += f32x8{evens};
+  return a;
+}
+
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+    __m256 lo_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(lo, 16));
+    __m256 hi_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(hi, 16));
+
+    a.v[2 * i + 0].v = _mm256_fmadd_ps(lo_f32, lo_f32, a.v[2 * i + 0].v);
+    a.v[2 * i + 1].v = _mm256_fmadd_ps(hi_f32, hi_f32, a.v[2 * i + 1].v);
+  }
+
+  return a;
+}
+
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+  a.v = _mm256_fmadd_ps(odds, odds, a.v);
+  a.v = _mm256_fmadd_ps(evens, evens, a.v);
+  return a;
+}
+
+}  // namespace simd
+
+using simd::f32x8;
+using simd::f32x8x16;
+using simd::bf16x16;
+using simd::bf16x16x8;
+using simd::s16x16;
+
+using bf16x16_rvar = float16_wrapper<bf16x16, s16x16>;
+
+void sum_bf16_fp32_avx2_fma3(size_t n, size_t k3, size_t k2, size_t k1,
+                             size_t a_stride_n, size_t a_stride_k3,
+                             size_t a_stride_k2, const void* a, size_t,
+                             void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16>, bfloat16, float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
+void sum_squared_bf16_fp32_avx2_fma3(size_t n, size_t k3, size_t k2, size_t k1,
+                                     size_t a_stride_n, size_t a_stride_k3,
+                                     size_t a_stride_k2, const void* a, size_t,
+                                     void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16, Square>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16, Square>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
+}  // namespace ynn
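Apart from the trimmed include list, the only functional difference between this file and the new AVX2 code in x86_avx2.cc is the Square accumulation step: multiply then add there, a single fused _mm256_fmadd_ps here, which costs one fewer instruction and rounds once instead of twice. A minimal side-by-side for comparison (illustrative helper names, not from this commit):

#include <immintrin.h>

// Plain AVX2: square and accumulate in two steps, rounding after the
// multiply and again after the add.
static inline __m256 square_acc_avx2(__m256 acc, __m256 x) {
  return _mm256_add_ps(acc, _mm256_mul_ps(x, x));
}

// AVX2 + FMA3: one fused instruction; x * x + acc is rounded once.
static inline __m256 square_acc_fma3(__m256 acc, __m256 x) {
  return _mm256_fmadd_ps(x, x, acc);
}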
