 #include <cstring>
 #include <type_traits>

+#include "ynnpack/base/base.h"
 #include "ynnpack/base/bfloat16.h"
 #include "ynnpack/base/half.h"
 #include "ynnpack/base/simd/multi_vec.h"
@@ -26,8 +27,10 @@ namespace ynn {
 namespace simd {

 using f32x8x8 = simd::multi_vec<f32x8, 8>;
+using f32x8x16 = simd::multi_vec<f32x8, 16>;
 using s32x8x2 = multi_vec<s32x8, 2>;
 using s32x8x4 = multi_vec<s32x8, 4>;
+using bf16x16x8 = multi_vec<bf16x16, 8>;

 static s32x8x4& operator+=(s32x8x4& a, s8x32 b) {
   s8x16 b_lo = extract<0>(b, s8x16{});
@@ -115,14 +118,72 @@ static s32x8 reduce_add(
   return a += s32x8(_mm256_madd_epi16(b_16, b_16));
 }

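+// bfloat16 is the upper 16 bits of an IEEE binary32 value, so zero-extending
+// each 16-bit lane to 32 bits and shifting left by 16 reconstructs the exact
+// f32 bit pattern. This overload converts 8 vectors of 16 bf16 values and
+// accumulates them into 128 independent f32 lanes (no horizontal reduction).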
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+
+    a.v[2 * i + 0] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(lo, 16))};
+    a.v[2 * i + 1] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(hi, 16))};
+  }
+
+  return a;
+}
+
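+// Horizontal factor 2: adjacent bf16 pairs fold into a single f32 lane. The
+// even elements sit in the low half of each 32-bit lane and are converted by
+// the left shift; the odd elements already occupy the high half and only
+// need the low bits masked off.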
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+
+  a += f32x8{odds};
+  a += f32x8{evens};
+  return a;
+}
+
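+// Same bf16 -> f32 conversion as the Identity overload above, but each
+// converted value is squared before accumulation (sum-of-squares reduction).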
+static f32x8x16 reduce_add(
+    f32x8x16 a, bf16x16x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    __m256i lo_u32 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
+    __m256i hi_u32 =
+        _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
+    __m256 lo_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(lo_u32, 16));
+    __m256 hi_f32 = _mm256_castsi256_ps(_mm256_slli_epi32(hi_u32, 16));
+
+    a.v[2 * i + 0] += f32x8{_mm256_mul_ps(lo_f32, lo_f32)};
+    a.v[2 * i + 1] += f32x8{_mm256_mul_ps(hi_f32, hi_f32)};
+  }
+
+  return a;
+}
+
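+// Horizontal-factor-2 variant of the squared reduction: both converted
+// halves are squared and added into the same accumulator lane.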
+static f32x8 reduce_add(
+    f32x8 a, bf16x16 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  __m256 mask = _mm256_castsi256_ps(_mm256_set1_epi32(0xFFFF0000));
+  __m256 evens = _mm256_castsi256_ps(_mm256_slli_epi32(b.v, 16));
+  __m256 odds = _mm256_and_ps(_mm256_castsi256_ps(b.v), mask);
+  a += f32x8{_mm256_mul_ps(odds, odds)};
+  a += f32x8{_mm256_mul_ps(evens, evens)};
+  return a;
+}
+
 }  // namespace simd

 using simd::s32x8;
 using simd::s32x8x2;
 using simd::s32x8x4;
 using simd::f32x8;
 using simd::f32x8x8;
+using simd::f32x8x16;
 using simd::bf16x16;
+using simd::bf16x16x8;
 using simd::f16x16;
 using simd::s16x16;
 using simd::s8x16;
@@ -233,6 +294,40 @@ void sum_squared_uint8_int32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
   }
 }

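+// Dispatch for the bf16 -> f32 sum reduction: when the reduced data is
+// contiguous (k1 == 1 and the n stride is one bfloat16), the k1 == 1
+// specialization can consume whole bf16 vectors via the reduce_add overloads
+// above; otherwise the generic strided accumulator is used.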
+void sum_bf16_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
+                        size_t a_stride_n, size_t a_stride_k3,
+                        size_t a_stride_k2, const void* a, size_t, void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16>, bfloat16, float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
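+// Same dispatch as sum_bf16_fp32_avx2, with Square as the map function so
+// every element is squared before it is accumulated.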
+void sum_squared_bf16_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
+                                size_t a_stride_n, size_t a_stride_k3,
+                                size_t a_stride_k2, const void* a, size_t,
+                                void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x16x8, f32x8x16, Square>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x8, 16, Square>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
 void sum_squared_fp32_avx2(size_t n, size_t k3, size_t k2, size_t k1,
                            size_t a_stride_n, size_t a_stride_k3,
                            size_t a_stride_k2, const void* a, size_t, void* c) {