WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / YOU CANNOT LOG IN OR REGISTER ACCOUNTS HERE / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FIND SOMETHING THAT MAY NOT BE SUITABLE FOR EVERYONE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit a342cf6

Browse files
dsharletg authored and xnnpack-bot committed
Minor reduction cleanups
- Move neondot kernels to the correct #if guard
- Minor reformatting to make things a little easier to read

PiperOrigin-RevId: 841025925
1 parent 995241f commit a342cf6

File tree

5 files changed

+16
-11
lines changed

5 files changed

+16
-11
lines changed

ynnpack/kernels/reduce/arm_neon.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ static f32x4x16 reduce_add(
3838
float32x4_t hi =
3939
vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
4040

41-
a.v[2 * i] += f32x4{lo};
41+
a.v[2 * i + 0] += f32x4{lo};
4242
a.v[2 * i + 1] += f32x4{hi};
4343
}
4444

@@ -54,7 +54,8 @@ static f32x4 reduce_add(
5454
vreinterpretq_f32_u32(vandq_u32(pairs, vdupq_n_u32(0xFFFF0000)));
5555

5656
a += f32x4{odd};
57-
return a += f32x4{even};
57+
a += f32x4{even};
58+
return a;
5859
}
5960

6061
static f32x4x16 reduce_add(
@@ -66,7 +67,7 @@ static f32x4x16 reduce_add(
6667
vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
6768
float32x4_t hi =
6869
vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
69-
a.v[2 * i].v = vmlaq_f32(a.v[2 * i].v, lo, lo);
70+
a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo, lo);
7071
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
7172
}
7273

ynnpack/kernels/reduce/arm_neonbf16.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static f32x4x16& operator+=(f32x4x16& a, bf16x8x8 b) {
3636
YNN_UNROLL
3737
for (size_t i = 0; i < 8; ++i) {
3838
uint16x8x2_t zipped = vzipq_u16(b.v[i].v, zero);
39-
a.v[2 * i].v = vbfdotq_f32(a.v[2 * i].v,
39+
a.v[2 * i + 0].v = vbfdotq_f32(a.v[2 * i + 0].v,
4040
vreinterpretq_bf16_u16(zipped.val[0]), one);
4141
a.v[2 * i + 1].v = vbfdotq_f32(a.v[2 * i + 1].v,
4242
vreinterpretq_bf16_u16(zipped.val[1]), one);
@@ -62,7 +62,7 @@ static f32x4x16 reduce_add(
6262
YNN_UNROLL
6363
for (size_t i = 0; i < 8; ++i) {
6464
uint16x8x2_t zipped = vzipq_u16(b.v[i].v, zero);
65-
a.v[2 * i].v = vbfdotq_f32(a.v[2 * i].v,
65+
a.v[2 * i + 0].v = vbfdotq_f32(a.v[2 * i + 0].v,
6666
vreinterpretq_bf16_u16(zipped.val[0]),
6767
vreinterpretq_bf16_u16(zipped.val[0]));
6868
a.v[2 * i + 1].v = vbfdotq_f32(a.v[2 * i + 1].v,

ynnpack/kernels/reduce/arm_neonfp16arith.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ static f32x4x16& operator+=(f32x4x16& a, f16x8x8 b) {
4646
f32x4 b_1(vcvt_f32_f16(vget_high_f16(
4747
reinterpret_cast<float16x8_t>(b.v[i].v))));
4848

49-
a.v[2 * i] += b_0;
49+
a.v[2 * i + 0] += b_0;
5050
a.v[2 * i + 1] += b_1;
5151
}
5252

@@ -75,7 +75,7 @@ static f32x4x16 reduce_add(
7575
f32x4 b_1(vcvt_f32_f16(vget_high_f16(
7676
reinterpret_cast<float16x8_t>(b.v[i].v))));
7777

78-
a.v[2 * i] += b_0 * b_0;
78+
a.v[2 * i + 0] += b_0 * b_0;
7979
a.v[2 * i + 1] += b_1 * b_1;
8080
}
8181

ynnpack/kernels/reduce/sum.inc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_bf16_fp32_neonbf16, bfloat16, f
66
#ifdef YNN_ARCH_ARM_NEONFP16ARITH
77
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_fp16_fp32_neonfp16arith, half, float)
88
#endif // YNN_ARCH_ARM_NEONFP16ARITH
9+
#ifdef YNN_ARCH_ARM_NEONDOT
10+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
11+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
12+
#endif // YNN_ARCH_ARM_NEONDOT
913
#ifdef YNN_ARCH_ARM_NEON
1014
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_fp32_neon, float, float)
1115
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_bf16_fp32_neon, bfloat16, float)
12-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_int8_int32_neondot, int8_t, int32_t)
13-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_uint8_int32_neondot, uint8_t, int32_t)
1416
#endif // YNN_ARCH_ARM_NEON
1517

1618
#ifdef YNN_ARCH_X86_AVX512BF16

ynnpack/kernels/reduce/sum_squared.inc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ YNN_UNARY_REDUCE_KERNEL(arch_flag::neonbf16, sum_squared_bf16_fp32_neonbf16, bfl
66
#ifdef YNN_ARCH_ARM_NEONFP16ARITH
77
YNN_UNARY_REDUCE_KERNEL(arch_flag::neonfp16arith, sum_squared_fp16_fp32_neonfp16arith, half, float)
88
#endif // YNN_ARCH_ARM_NEONFP16ARITH
9+
#ifdef YNN_ARCH_ARM_NEONDOT
10+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
11+
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
12+
#endif // YNN_ARCH_ARM_NEONDOT
913
#ifdef YNN_ARCH_ARM_NEON
1014
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_fp32_neon, float, float)
1115
YNN_UNARY_REDUCE_KERNEL(arch_flag::neon, sum_squared_bf16_fp32_neon, bfloat16, float)
12-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_int8_int32_neondot, int8_t, int32_t)
13-
YNN_UNARY_REDUCE_KERNEL(arch_flag::neondot, sum_squared_uint8_int32_neondot, uint8_t, int32_t)
1416
#endif // YNN_ARCH_ARM_NEON
1517

1618
#ifdef YNN_ARCH_X86_AVX512BF16

0 commit comments

Comments
 (0)