vsps[3] = _mm256_unpackhi_pd(vsp23_02, vsp23_13);
for (int k = 0; k < 4; k++) {
+#if USE_X86_EXT_INTRIN >= 9
__m256d vmask = _mm256_castsi256_pd(_mm256_cvtepi32_epi64(_mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts)));
+#else
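+ /* AVX fallback: _mm256_cvtepi32_epi64 needs AVX2, so widen each
+    32-bit compare result to a 64-bit double-lane mask by hand,
+    one 128-bit half at a time */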
+ __m128i vmask32 = _mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts);
+ __m256d vmask = _mm256_insertf128_pd(
+ _mm256_castpd128_pd256(_mm_castsi128_pd(_mm_unpacklo_epi32(vmask32, vmask32))),
+ _mm_castsi128_pd(_mm_unpackhi_epi32(vmask32, vmask32)),
+ 1
+ );
+#endif
vdb0 = _mm256_blendv_pd(vdb0, vsps[k], vmask);
vdb2 = _mm256_blendv_pd(vdb2, MM256_FMA_PD(vdc0, vdb0, MM256_FMA4_PD(vdc1, vdb1, vdc2, vdb2, vdc3, vdb3, vdc4, vdb4)), vmask);
if (i >= batch_size)
break;
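+ /* hoist the two lane counts into scalars; reused for the scalar
+    count_max below */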
+ int32 count0 = counts[i];
+ int32 count1 = i + 1 < batch_size ? counts[i + 1] : 0;
+
__m128i vcounts = _mm_set_epi32(
0,
0,
- i + 1 < batch_size ? counts[i + 1] : 0,
- counts[i]
+ count1,
+ count0
);
__m128d vdb01_0 = _mm_loadu_pd(&dbs[i][0]);
dcs[i][4]
);
- int32 count_max = _mm_cvtsi128_si32(_mm_max_epi32(vcounts, _mm_shuffle_epi32(vcounts, 1)));
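+ /* scalar max; _mm_max_epi32 requires SSE4.1 */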
+ int32 count_max = count0 < count1 ? count1 : count0;
for (int32 j = 0; j < count_max; j += 2) {
__m128d vsp01_0 = j < counts[i] ? _mm_loadu_pd(&sps[i][j]) : _mm_setzero_pd();
vsps[1] = _mm_unpackhi_pd(vsp01_0, vsp01_1);
for (int k = 0; k < 2; k++) {
- __m128d vmask = _mm_castsi128_pd(_mm_cvtepi32_epi64(_mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts)));
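+ /* _mm_cvtepi32_epi64 is SSE4.1-only; unpacking each 32-bit compare
+    result against itself yields the same 64-bit all-ones/all-zeros mask */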
+ __m128i vmask32 = _mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts);
+ __m128d vmask = _mm_castsi128_pd(_mm_unpacklo_epi32(vmask32, vmask32));
vdb0 = MM_BLENDV_PD(vdb0, vsps[k], vmask);
vdb2 = MM_BLENDV_PD(vdb2, MM_FMA5_PD(vdc0, vdb0, vdc1, vdb1, vdc2, vdb2, vdc3, vdb3, vdc4, vdb4), vmask);
vsps[3] = _mm256_unpackhi_pd(vsp23_02, vsp23_13);
for (int k = 0; k < 4; k++) {
+#if USE_X86_EXT_INTRIN >= 9
__m256d vmask = _mm256_castsi256_pd(_mm256_cvtepi32_epi64(_mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts)));
+#else
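+ /* same AVX fallback as above: build the 64-bit lane masks from the
+    32-bit compares, one 128-bit half at a time */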
+ __m128i vmask32 = _mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts);
+ __m256d vmask = _mm256_insertf128_pd(
+ _mm256_castpd128_pd256(_mm_castsi128_pd(_mm_unpacklo_epi32(vmask32, vmask32))),
+ _mm_castsi128_pd(_mm_unpackhi_epi32(vmask32, vmask32)),
+ 1
+ );
+#endif
vdb1 = _mm256_blendv_pd(vdb1, MM256_FMA_PD(_mm256_sub_pd(vsps[k], vdb0), vdc1, vdb1), vmask);
vdb0 = _mm256_blendv_pd(vdb0, _mm256_add_pd(vdb0, vdb1), vmask);
if (i >= batch_size)
break;
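+ /* scalar copies of the lane counts, reused for count_max below */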
+ int32 count0 = counts[i];
+ int32 count1 = i + 1 < batch_size ? counts[i + 1] : 0;
+
__m128i vcounts = _mm_set_epi32(
0,
0,
- i + 1 < batch_size ? counts[i + 1] : 0,
- counts[i]
+ count1,
+ count0
);
__m128d vdb01_0 = _mm_loadu_pd(dbs[i]);
__m128d vdc0 = _mm_unpacklo_pd(vdc01_0, vdc01_1);
__m128d vdc1 = _mm_unpackhi_pd(vdc01_0, vdc01_1);
- int32 count_max = _mm_cvtsi128_si32(_mm_max_epi32(vcounts, _mm_shuffle_epi32(vcounts, 1)));
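+ /* scalar max avoids the SSE4.1-only _mm_max_epi32 */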
+ int32 count_max = count0 < count1 ? count1 : count0;
for (int32 j = 0; j < count_max; j += 2) {
__m128d vsp01_0 = j < counts[i] ? _mm_loadu_pd(&sps[i][j]) : _mm_setzero_pd();
vsps[1] = _mm_unpackhi_pd(vsp01_0, vsp01_1);
for (int k = 0; k < 2; k++) {
- __m128d vmask = _mm_castsi128_pd(_mm_cvtepi32_epi64(_mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts)));
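+ /* duplicate each 32-bit compare result to fill a 64-bit lane (SSE2) */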
+ __m128i vmask32 = _mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts);
+ __m128d vmask = _mm_castsi128_pd(_mm_unpacklo_epi32(vmask32, vmask32));
vdb1 = MM_BLENDV_PD(vdb1, MM_FMA_PD(_mm_sub_pd(vsps[k], vdb0), vdc1, vdb1), vmask);
vdb0 = MM_BLENDV_PD(vdb0, _mm_add_pd(vdb0, vdb1), vmask);
vsps[3] = _mm256_unpackhi_pd(vsp23_02, vsp23_13);
for (int k = 0; k < 4; k++) {
+#if USE_X86_EXT_INTRIN >= 9
__m256d vmask = _mm256_castsi256_pd(_mm256_cvtepi32_epi64(_mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts)));
+#else
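+ /* AVX fallback for the AVX2-only _mm256_cvtepi32_epi64 */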
+ __m128i vmask32 = _mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts);
+ __m256d vmask = _mm256_insertf128_pd(
+ _mm256_castpd128_pd256(_mm_castsi128_pd(_mm_unpacklo_epi32(vmask32, vmask32))),
+ _mm_castsi128_pd(_mm_unpackhi_epi32(vmask32, vmask32)),
+ 1
+ );
+#endif
vdb1 = _mm256_blendv_pd(vdb1, MM256_FMA_PD(_mm256_sub_pd(vsps[k], vdb0), vdc1, vdb1), vmask);
vdb0 = _mm256_blendv_pd(vdb0, _mm256_add_pd(vdb0, vdb1), vmask);
if (i >= batch_size)
break;
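+ /* lane counts as scalars, for the count_max computation below */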
+ int32 count0 = counts[i];
+ int32 count1 = i + 1 < batch_size ? counts[i + 1] : 0;
+
__m128i vcounts = _mm_set_epi32(
0,
0,
- i + 1 < batch_size ? counts[i + 1] : 0,
- counts[i]
+ count1,
+ count0
);
__m128d vdb01_0 = _mm_loadu_pd(dbs[i]);
__m128d vdc0 = _mm_unpacklo_pd(vdc01_0, vdc01_1);
__m128d vdc1 = _mm_unpackhi_pd(vdc01_0, vdc01_1);
- int32 count_max = _mm_cvtsi128_si32(_mm_max_epi32(vcounts, _mm_shuffle_epi32(vcounts, 1)));
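+ /* scalar max instead of the SSE4.1 _mm_max_epi32 */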
+ int32 count_max = count0 < count1 ? count1 : count0;
for (int32 j = 0; j < count_max; j += 2) {
__m128d vsp01_0 = j < counts[i] ? _mm_loadu_pd(&sps[i][j]) : _mm_setzero_pd();
vsps[1] = _mm_unpackhi_pd(vsp01_0, vsp01_1);
for (int k = 0; k < 2; k++) {
- __m128d vmask = _mm_castsi128_pd(_mm_cvtepi32_epi64(_mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts)));
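+ /* widen the 32-bit compare mask to 64-bit lanes with unpacklo (SSE2) */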
+ __m128i vmask32 = _mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts);
+ __m128d vmask = _mm_castsi128_pd(_mm_unpacklo_epi32(vmask32, vmask32));
vdb1 = MM_BLENDV_PD(vdb1, MM_FMA_PD(_mm_sub_pd(vsps[k], vdb0), vdc1, vdb1), vmask);
vdb0 = MM_BLENDV_PD(vdb0, _mm_add_pd(vdb0, vdb1), vmask);