 #include <type_traits>
 
 #include "ynnpack/base/arithmetic.h"
+#include "ynnpack/base/simd/multi_vec.h"
 #include "ynnpack/kernels/reduce/generic.h"
 #include "ynnpack/kernels/reduce/min_max_accumulator.h"
 #include "ynnpack/kernels/reduce/reduce.h"
 #include "ynnpack/kernels/reduce/sum_accumulator.h"
 
 namespace ynn {
 
+namespace simd {
+
+using bf16x8x8 = multi_vec<bf16x8, 8>;
+using f32x4x16 = multi_vec<f32x4, 16>;
+
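+// A bfloat16 is the upper 16 bits of an IEEE float32, so widening each lane
+// with a 16-bit left shift (vshll_n_u16) reconstructs the exact float32
+// value. The reduce_add overloads below rely on this to accumulate bf16
+// inputs in f32 precision.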
+static f32x4x16 reduce_add(
+    f32x4x16 a, bf16x8x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    float32x4_t lo =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
+    float32x4_t hi =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
+
+    a.v[2 * i] += f32x4{lo};
+    a.v[2 * i + 1] += f32x4{hi};
+  }
+
+  return a;
+}
+
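+// horizontal_factor == 2: fold each widened bf16x8 down to four f32 partial
+// sums with pairwise adds. vpaddq_f32 is AArch64-only, so the AArch32 build
+// pairwise-adds each half with vpadd_f32 and recombines.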
+static f32x4 reduce_add(
+    f32x4 a, bf16x8 b, Identity /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  float32x4_t lo = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v), 16));
+  float32x4_t hi = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v), 16));
+#ifndef __aarch64__
+  float32x2_t pair_lo = vpadd_f32(vget_low_f32(lo), vget_high_f32(lo));
+  float32x2_t pair_hi = vpadd_f32(vget_low_f32(hi), vget_high_f32(hi));
+  return a += f32x4{vcombine_f32(pair_lo, pair_hi)};
+#else
+  return a += f32x4{vpaddq_f32(lo, hi)};
+#endif
+}
+
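+// Square map_fn overloads: identical widening, but each value is multiplied
+// by itself while accumulating, for sum-of-squares reductions.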
+static f32x4x16 reduce_add(
+    f32x4x16 a, bf16x8x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 1> /*horizontal_factor*/) {
+  YNN_UNROLL
+  for (int i = 0; i < 8; ++i) {
+    float32x4_t lo =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
+    float32x4_t hi =
+        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
+    a.v[2 * i].v = vmlaq_f32(a.v[2 * i].v, lo, lo);
+    a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
+  }
+
+  return a;
+}
+
+static f32x4 reduce_add(
+    f32x4 a, bf16x8 b, Square /*map_fn*/,
+    std::integral_constant<size_t, 2> /*horizontal_factor*/) {
+  float32x4_t lo = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v), 16));
+  float32x4_t hi = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v), 16));
+#ifndef __aarch64__
+  float32x4_t sq_lo = vmulq_f32(lo, lo);
+  float32x4_t sq_hi = vmulq_f32(hi, hi);
+  float32x2_t pair_lo = vpadd_f32(vget_low_f32(sq_lo), vget_high_f32(sq_lo));
+  float32x2_t pair_hi = vpadd_f32(vget_low_f32(sq_hi), vget_high_f32(sq_hi));
+  return a += f32x4{vcombine_f32(pair_lo, pair_hi)};
+#else
+  return a += f32x4{vpaddq_f32(vmulq_f32(lo, lo), vmulq_f32(hi, hi))};
+#endif
+}
+
+}  // namespace simd
+
+using simd::f32x4;
+using simd::f32x4x16;
 using simd::bf16x8;
+using simd::bf16x8x8;
 using simd::f16x8;
-using simd::f32x4;
 using simd::s16x8;
 using simd::s8x16;
 using simd::u8x16;
@@ -50,6 +123,42 @@ MIN_MAX_KERNEL(max_fp16_4x8_neon, dummy_t, f16x8_rvar, half, 8);
 MIN_MAX_KERNEL(max_uint8_4x16_neon, dummy_t, u8x16, uint8_t, 16);
 MIN_MAX_KERNEL(max_int8_4x16_neon, dummy_t, s8x16, int8_t, 16);
 
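+// bf16 sums accumulate in f32. When k1 == 1 and the input is contiguous,
+// reductions take the fast path through sum_accumulator_k1_1 and the bf16
+// reduce_add overloads above; other layouts use the generic x32 accumulator.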
+void sum_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
+                        size_t a_stride_n, size_t a_stride_k3,
+                        size_t a_stride_k2, const void* a, size_t,
+                        void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    // Plain sum: use the Identity map, not Square.
+    tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16, Identity>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x4, 8>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
+void sum_squared_bf16_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
+                                size_t a_stride_n, size_t a_stride_k3,
+                                size_t a_stride_k2, const void* a, size_t,
+                                void* c) {
+  if (k1 == 1 && a_stride_n == sizeof(bfloat16)) {
+    tiled_reduce<sum_accumulator_k1_1<bf16x8x8, f32x4x16, Square>, bfloat16,
+                 float>(
+        n, k3, k2, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  } else {
+    tiled_reduce<sum_accumulator_x32<f32x4, 8, Square>, bfloat16, float>(
+        n, k3, k2, k1, a_stride_n, a_stride_k3, a_stride_k2,
+        reinterpret_cast<const bfloat16*>(a), /*C_stride_m=*/0,
+        reinterpret_cast<float*>(c));
+  }
+}
+
 void sum_fp32_neon(size_t n, size_t k3, size_t k2, size_t k1,
                    size_t a_stride_n, size_t a_stride_k3, size_t a_stride_k2,
                    const void* a, size_t, void* c) {