#include <type_traits>

#include "ynnpack/base/arithmetic.h"
#include "ynnpack/base/simd/multi_vec.h"
#include "ynnpack/kernels/reduce/generic.h"
#include "ynnpack/kernels/reduce/min_max_accumulator.h"
#include "ynnpack/kernels/reduce/reduce.h"
#include "ynnpack/kernels/reduce/sum_accumulator.h"

namespace ynn {
namespace simd {

using bf16x8x8 = multi_vec<bf16x8, 8>;
using f32x4x16 = multi_vec<f32x4, 16>;

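// bf16x8x8 and f32x4x16 both span 64 lanes (8x8 and 16x4), so one widened
// bf16 input tile fills one f32 accumulator tile exactly; the reduce_add
// overloads below rely on that correspondence.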
static f32x4x16 reduce_add(
    f32x4x16 a, bf16x8x8 b, Identity /*map_fn*/,
    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
  YNN_UNROLL
  for (int i = 0; i < 8; ++i) {
    // bf16 is the top 16 bits of an f32, so a widening 16-bit left shift
    // produces the exact f32 bit pattern of each lane.
    float32x4_t lo =
        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
    float32x4_t hi =
        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));

    a.v[2 * i] += f32x4{lo};
    a.v[2 * i + 1] += f32x4{hi};
  }

  return a;
}
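
// For reference, a scalar analogue of the widening above (an illustrative
// sketch only; bf16_to_f32 is not part of this file and would also need
// <cstring> for std::memcpy):
//
//   float bf16_to_f32(uint16_t bits) {
//     uint32_t u = uint32_t(bits) << 16;  // bf16 occupies the high half
//     float f;
//     std::memcpy(&f, &u, sizeof(f));
//     return f;
//   }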

static f32x4 reduce_add(
    f32x4 a, bf16x8 b, Identity /*map_fn*/,
    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
  // Each u32 lane of `pairs` holds two adjacent bf16 values: shifting left
  // turns the even-indexed element into an f32, masking keeps the odd one,
  // which already sits in the high half.
  uint32x4_t pairs = vreinterpretq_u32_u16(b.v);
  float32x4_t even = vreinterpretq_f32_u32(vshlq_n_u32(pairs, 16));
  float32x4_t odd =
      vreinterpretq_f32_u32(vandq_u32(pairs, vdupq_n_u32(0xFFFF0000)));

  a += f32x4{odd};
  a += f32x4{even};
  return a;
}
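
// Example (little-endian lanes): for input lanes {b0, ..., b7}, `even` is
// {b0, b2, b4, b6} and `odd` is {b1, b3, b5, b7} as f32, so the two adds
// fold each adjacent pair into a single accumulator lane, matching
// horizontal_factor == 2.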

static f32x4x16 reduce_add(
    f32x4x16 a, bf16x8x8 b, Square /*map_fn*/,
    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
  YNN_UNROLL
  for (int i = 0; i < 8; ++i) {
    float32x4_t lo =
        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
    float32x4_t hi =
        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
    // vmlaq_f32(acc, x, x) fuses the squaring into the accumulate.
    a.v[2 * i].v = vmlaq_f32(a.v[2 * i].v, lo, lo);
    a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
  }

  return a;
}

static f32x4 reduce_add(
    f32x4 a, bf16x8 b, Square /*map_fn*/,
    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
  uint32x4_t pairs = vreinterpretq_u32_u16(b.v);
  float32x4_t even = vreinterpretq_f32_u32(vshlq_n_u32(pairs, 16));
  float32x4_t odd =
      vreinterpretq_f32_u32(vandq_u32(pairs, vdupq_n_u32(0xFFFF0000)));

  a.v = vmlaq_f32(a.v, odd, odd);
  a.v = vmlaq_f32(a.v, even, even);
  return a;
}
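
// These four overloads are selected by ordinary overload resolution on the
// map_fn tag (Identity vs. Square) and the horizontal_factor integral
// constant, letting the generic reduction templates pick a specialized
// bf16 path at compile time.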

} // namespace simd

using simd::f32x4;
using simd::f32x4x16;
using simd::bf16x8;
using simd::bf16x8x8;
using simd::f16x8;
using simd::s16x8;
using simd::s8x16;
using simd::u8x16;

// ...

MIN_MAX_KERNEL(max_fp16_4x8_neon, dummy_t, f16x8_rvar, half, 8);
MIN_MAX_KERNEL(max_uint8_4x16_neon, dummy_t, u8x16, uint8_t, 16);
MIN_MAX_KERNEL(max_int8_4x16_neon, dummy_t, s8x16, int8_t, 16);

void sum_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
                        size_t a_stride_n, size_t a_stride_k3,
                        size_t a_stride_k2, const void* a, size_t,
                        void* c) {
  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
    // Fast path: the reduced data is contiguous, so whole bf16x8x8 tiles
    // can be loaded and widened in registers.
    tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16>, bfloat16, float>(
        n, k3, k2, a_stride_k3, a_stride_k2,
        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
        reinterpret_cast<float*>(c));
  } else {
    tiled_reduce<sum_accumulator_x32<f32x4, 8>, bfloat16, float>(
        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
        reinterpret_cast<float*>(c));
  }
}
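
// The guard above matters: sum_accumulator_k1_1 assumes it can load
// contiguous bf16 tiles, so any strided layout (k1 > 1 or a non-unit
// a_stride_n) must take the generic f32x4 accumulator path instead.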

void sum_squared_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
                                size_t a_stride_n, size_t a_stride_k3,
                                size_t a_stride_k2, const void* a, size_t,
                                void* c) {
  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
    tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16, Square>, bfloat16,
                 float>(
        n, k3, k2, a_stride_k3, a_stride_k2,
        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
        reinterpret_cast<float*>(c));
  } else {
    tiled_reduce<sum_accumulator_x32<f32x4, 8, Square>, bfloat16, float>(
        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
        reinterpret_cast<float*>(c));
  }
}
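
// Relative to sum_bf16_fp32_neon, only the Square map_fn tag changes; on
// the tiled fast path it selects the vmlaq_f32 squaring overloads above.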

void sum_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
                   size_t a_stride_n, size_t a_stride_k3, size_t a_stride_k2,
                   const void* a, size_t, void* c) {