OSDN Git Service

[filter] Add AVX version of sample_filter_LPF24_batch() and recalc_filter_LPF24_batch()
authorStarg <starg@users.osdn.me>
Sun, 2 May 2021 03:27:07 +0000 (12:27 +0900)
committerStarg <starg@users.osdn.me>
Sun, 2 May 2021 03:27:07 +0000 (12:27 +0900)
timidity/filter.c

index d88098e..b58fbfb 100644 (file)
@@ -3939,7 +3939,165 @@ static inline __mmask8 generate_mask8_for_count(int32 offset, int32 count)
 
 #endif
 
-#if (USE_X86_EXT_INTRIN >= 3) && defined(DATA_T_DOUBLE) && defined(FLOAT_T_DOUBLE)
+#if (USE_X86_EXT_INTRIN >= 8) && defined(DATA_T_DOUBLE) && defined(FLOAT_T_DOUBLE)
+
+static void sample_filter_LPF24_batch(int batch_size, FILTER_T **dcs, FILTER_T **dbs, DATA_T **sps, int32 *counts)
+{
+       for (int i = 0; i < MIX_VOICE_BATCH_SIZE; i += 4) {
+               if (i >= batch_size)
+                       break;
+
+               __m128i vcounts = _mm_set_epi32(
+                       i + 3 < batch_size ? counts[i + 3] : 0,
+                       i + 2 < batch_size ? counts[i + 2] : 0,
+                       i + 1 < batch_size ? counts[i + 1] : 0,
+                       counts[i]
+               );
+
+               __m256d vdb0123_0 = _mm256_loadu_pd(&dbs[i][0]);
+               __m256d vdb0123_1 = i + 1 < batch_size ? _mm256_loadu_pd(&dbs[i + 1][0]) : _mm256_setzero_pd();
+               __m256d vdb0123_2 = i + 2 < batch_size ? _mm256_loadu_pd(&dbs[i + 2][0]) : _mm256_setzero_pd();
+               __m256d vdb0123_3 = i + 3 < batch_size ? _mm256_loadu_pd(&dbs[i + 3][0]) : _mm256_setzero_pd();
+
+               __m256d vdb01_02 = _mm256_permute2f128_pd(vdb0123_0, vdb0123_2, (2 << 4) | 0);
+               __m256d vdb01_13 = _mm256_permute2f128_pd(vdb0123_1, vdb0123_3, (2 << 4) | 0);
+               __m256d vdb23_02 = _mm256_permute2f128_pd(vdb0123_0, vdb0123_2, (3 << 4) | 1);
+               __m256d vdb23_13 = _mm256_permute2f128_pd(vdb0123_1, vdb0123_3, (3 << 4) | 1);
+
+               __m256d vdb0 = _mm256_unpacklo_pd(vdb01_02, vdb01_13);
+               __m256d vdb1 = _mm256_unpackhi_pd(vdb01_02, vdb01_13);
+               __m256d vdb2 = _mm256_unpacklo_pd(vdb23_02, vdb23_13);
+               __m256d vdb3 = _mm256_unpackhi_pd(vdb23_02, vdb23_13);
+               __m256d vdb4 = _mm256_set_pd(
+                       i + 3 < batch_size ? dbs[i + 3][4] : 0.0,
+                       i + 2 < batch_size ? dbs[i + 2][4] : 0.0,
+                       i + 1 < batch_size ? dbs[i + 1][4] : 0.0,
+                       dbs[i][4]
+               );
+
+               __m256d vdc0123_0 = _mm256_loadu_pd(&dcs[i][0]);
+               __m256d vdc0123_1 = i + 1 < batch_size ? _mm256_loadu_pd(&dcs[i + 1][0]) : _mm256_setzero_pd();
+               __m256d vdc0123_2 = i + 2 < batch_size ? _mm256_loadu_pd(&dcs[i + 2][0]) : _mm256_setzero_pd();
+               __m256d vdc0123_3 = i + 3 < batch_size ? _mm256_loadu_pd(&dcs[i + 3][0]) : _mm256_setzero_pd();
+
+               __m256d vdc01_02 = _mm256_permute2f128_pd(vdc0123_0, vdc0123_2, (2 << 4) | 0);
+               __m256d vdc01_13 = _mm256_permute2f128_pd(vdc0123_1, vdc0123_3, (2 << 4) | 0);
+               __m256d vdc23_02 = _mm256_permute2f128_pd(vdc0123_0, vdc0123_2, (3 << 4) | 1);
+               __m256d vdc23_13 = _mm256_permute2f128_pd(vdc0123_1, vdc0123_3, (3 << 4) | 1);
+
+               __m256d vdc0 = _mm256_unpacklo_pd(vdc01_02, vdc01_13);
+               __m256d vdc1 = _mm256_unpackhi_pd(vdc01_02, vdc01_13);
+               __m256d vdc2 = _mm256_unpacklo_pd(vdc23_02, vdc23_13);
+
+               __m128i vcounts_halfmax = _mm_max_epi32(vcounts, _mm_shuffle_epi32(vcounts, (3 << 2) | 2));
+               int32 count_max = _mm_cvtsi128_si32(_mm_max_epi32(vcounts_halfmax, _mm_shuffle_epi32(vcounts_halfmax, 1)));
+
+               for (int32 j = 0; j < count_max; j += 4) {
+                       __m256d vsp0123_0 = j < counts[i] ? _mm256_loadu_pd(&sps[i][j]) : _mm256_setzero_pd();
+                       __m256d vsp0123_1 = i + 1 < batch_size && j < counts[i + 1] ? _mm256_loadu_pd(&sps[i + 1][j]) : _mm256_setzero_pd();
+                       __m256d vsp0123_2 = i + 1 < batch_size && j < counts[i + 2] ? _mm256_loadu_pd(&sps[i + 2][j]) : _mm256_setzero_pd();
+                       __m256d vsp0123_3 = i + 1 < batch_size && j < counts[i + 3] ? _mm256_loadu_pd(&sps[i + 3][j]) : _mm256_setzero_pd();
+
+                       __m256d vsp01_02 = _mm256_permute2f128_pd(vsp0123_0, vsp0123_2, (2 << 4) | 0);
+                       __m256d vsp01_13 = _mm256_permute2f128_pd(vsp0123_1, vsp0123_3, (2 << 4) | 0);
+                       __m256d vsp23_02 = _mm256_permute2f128_pd(vsp0123_0, vsp0123_2, (3 << 4) | 1);
+                       __m256d vsp23_13 = _mm256_permute2f128_pd(vsp0123_1, vsp0123_3, (3 << 4) | 1);
+
+                       __m256d vsps[4];
+                       vsps[0] = _mm256_unpacklo_pd(vsp01_02, vsp01_13);
+                       vsps[1] = _mm256_unpackhi_pd(vsp01_02, vsp01_13);
+                       vsps[2] = _mm256_unpacklo_pd(vsp23_02, vsp23_13);
+                       vsps[3] = _mm256_unpackhi_pd(vsp23_02, vsp23_13);
+
+                       for (int k = 0; k < 4; k++) {
+                               __m256d vdas[4];
+
+#if USE_X86_EXT_INTRIN >= 9
+                               __m256d vmask = _mm256_castsi256_pd(_mm256_cvtepi32_epi64(_mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts)));
+                               vdas[0] = _mm256_fnmadd_pd(vdc2, vdb4, vsps[k]);
+#else
+                               __m128i vmask32 = _mm_cmplt_epi32(_mm_set1_epi32(j + k), vcounts);
+                               __m256d vmask = _mm256_insertf128_pd(
+                                       _mm256_castpd128_pd256(_mm_castsi128_pd(_mm_unpacklo_epi32(vmask32, vmask32))),
+                                       _mm_castsi128_pd(_mm_unpackhi_epi32(vmask32, vmask32)),
+                                       1
+                               );
+                               vdas[0] = _mm256_sub_pd(vsps[k], _mm256_mul_pd(vdc2, vdb4));
+#endif
+
+                               vdas[1] = vdb1;
+                               vdas[2] = vdb2;
+                               vdas[3] = vdb3;
+
+#if USE_X86_EXT_INTRIN >= 9
+                               vdb1 = _mm256_blendv_pd(vdb1, _mm256_fmsub_pd(_mm256_add_pd(vdb0, vdas[0]), vdc0, _mm256_mul_pd(vdb1, vdc1)), vmask);
+                               vdb2 = _mm256_blendv_pd(vdb2, _mm256_fmsub_pd(_mm256_add_pd(vdb1, vdas[1]), vdc0, _mm256_mul_pd(vdb2, vdc1)), vmask);
+                               vdb3 = _mm256_blendv_pd(vdb3, _mm256_fmsub_pd(_mm256_add_pd(vdb2, vdas[2]), vdc0, _mm256_mul_pd(vdb3, vdc1)), vmask);
+                               vdb4 = _mm256_blendv_pd(vdb4, _mm256_fmsub_pd(_mm256_add_pd(vdb3, vdas[3]), vdc0, _mm256_mul_pd(vdb4, vdc1)), vmask);
+#else
+                               vdb1 = _mm256_blendv_pd(vdb1, _mm256_sub_pd(_mm256_mul_pd(_mm256_add_pd(vdb0, vdas[0]), vdc0), _mm256_mul_pd(vdb1, vdc1)), vmask);
+                               vdb2 = _mm256_blendv_pd(vdb2, _mm256_sub_pd(_mm256_mul_pd(_mm256_add_pd(vdb1, vdas[1]), vdc0), _mm256_mul_pd(vdb2, vdc1)), vmask);
+                               vdb3 = _mm256_blendv_pd(vdb3, _mm256_sub_pd(_mm256_mul_pd(_mm256_add_pd(vdb2, vdas[2]), vdc0), _mm256_mul_pd(vdb3, vdc1)), vmask);
+                               vdb4 = _mm256_blendv_pd(vdb4, _mm256_sub_pd(_mm256_mul_pd(_mm256_add_pd(vdb3, vdas[3]), vdc0), _mm256_mul_pd(vdb4, vdc1)), vmask);
+#endif
+                               vdb0 = _mm256_blendv_pd(vdb0, vdas[0], vmask);
+                               vsps[k] = vdb4;
+                       }
+
+                       vsp01_02 = _mm256_unpacklo_pd(vsps[0], vsps[1]);
+                       vsp01_13 = _mm256_unpackhi_pd(vsps[0], vsps[1]);
+                       vsp23_02 = _mm256_unpacklo_pd(vsps[2], vsps[3]);
+                       vsp23_13 = _mm256_unpackhi_pd(vsps[2], vsps[3]);
+
+                       vsp0123_0 = _mm256_permute2f128_pd(vsp01_02, vsp23_02, (2 << 4) | 0);
+                       vsp0123_1 = _mm256_permute2f128_pd(vsp01_13, vsp23_13, (2 << 4) | 0);
+                       vsp0123_2 = _mm256_permute2f128_pd(vsp01_02, vsp23_02, (3 << 4) | 1);
+                       vsp0123_3 = _mm256_permute2f128_pd(vsp01_13, vsp23_13, (3 << 4) | 1);
+
+                       if (j < counts[i])
+                               _mm256_storeu_pd(&sps[i][j], vsp0123_0);
+
+                       if (i + 1 < batch_size && j < counts[i + 1])
+                               _mm256_storeu_pd(&sps[i + 1][j], vsp0123_1);
+
+                       if (i + 2 < batch_size && j < counts[i + 2])
+                               _mm256_storeu_pd(&sps[i + 2][j], vsp0123_2);
+
+                       if (i + 3 < batch_size && j < counts[i + 3])
+                               _mm256_storeu_pd(&sps[i + 3][j], vsp0123_3);
+               }
+
+               vdb01_02 = _mm256_unpacklo_pd(vdb0, vdb1);
+               vdb01_13 = _mm256_unpackhi_pd(vdb0, vdb1);
+               vdb23_02 = _mm256_unpacklo_pd(vdb2, vdb3);
+               vdb23_13 = _mm256_unpackhi_pd(vdb2, vdb3);
+
+               vdb0123_0 = _mm256_permute2f128_pd(vdb01_02, vdb23_02, (2 << 4) | 0);
+               vdb0123_1 = _mm256_permute2f128_pd(vdb01_13, vdb23_13, (2 << 4) | 0);
+               vdb0123_2 = _mm256_permute2f128_pd(vdb01_02, vdb23_02, (3 << 4) | 1);
+               vdb0123_3 = _mm256_permute2f128_pd(vdb01_13, vdb23_13, (3 << 4) | 1);
+
+               _mm256_storeu_pd(&dbs[i][0], vdb0123_0);
+               dbs[i][4] = MM256_EXTRACT_F64(vdb4, 0);
+
+               if (i + 1 < batch_size) {
+                       _mm256_storeu_pd(&dbs[i + 1][0], vdb0123_1);
+                       dbs[i + 1][4] = MM256_EXTRACT_F64(vdb4, 1);
+               }
+
+               if (i + 2 < batch_size) {
+                       _mm256_storeu_pd(&dbs[i + 2][0], vdb0123_2);
+                       dbs[i + 2][4] = MM256_EXTRACT_F64(vdb4, 2);
+               }
+
+               if (i + 3 < batch_size) {
+                       _mm256_storeu_pd(&dbs[i + 3][0], vdb0123_3);
+                       dbs[i + 3][4] = MM256_EXTRACT_F64(vdb4, 3);
+               }
+       }
+}
+
+#elif (USE_X86_EXT_INTRIN >= 3) && defined(DATA_T_DOUBLE) && defined(FLOAT_T_DOUBLE)
 
 static void sample_filter_LPF24_batch(int batch_size, FILTER_T **dcs, FILTER_T **dbs, DATA_T **sps, int32 *counts)
 {
@@ -4051,7 +4209,145 @@ static void sample_filter_LPF24_batch(int batch_size, FILTER_T **dcs, FILTER_T *
 
 #endif
 
-#if (USE_X86_EXT_INTRIN >= 3) && defined(DATA_T_DOUBLE) && defined(FLOAT_T_DOUBLE)
+#if (USE_X86_EXT_INTRIN >= 8) && defined(DATA_T_DOUBLE) && defined(FLOAT_T_DOUBLE)
+
+static void recalc_filter_LPF24_batch(int batch_size, FilterCoefficients **fcs)
+{
+       for (int i = 0; i < MIX_VOICE_BATCH_SIZE; i += 4) {
+               if (i >= batch_size)
+                       break;
+
+               __m256d vfcrange0123_0 = _mm256_loadu_pd(fcs[i]->range);
+               __m256d vfcrange0123_1 = i + 1 < batch_size ? _mm256_loadu_pd(fcs[i + 1]->range) : _mm256_setzero_pd();
+               __m256d vfcrange0123_2 = i + 2 < batch_size ? _mm256_loadu_pd(fcs[i + 2]->range) : _mm256_setzero_pd();
+               __m256d vfcrange0123_3 = i + 3 < batch_size ? _mm256_loadu_pd(fcs[i + 3]->range) : _mm256_setzero_pd();
+
+               __m256d vfcrange01_02 = _mm256_permute2f128_pd(vfcrange0123_0, vfcrange0123_2, (2 << 4) | 0);
+               __m256d vfcrange01_13 = _mm256_permute2f128_pd(vfcrange0123_1, vfcrange0123_3, (2 << 4) | 0);
+               __m256d vfcrange23_02 = _mm256_permute2f128_pd(vfcrange0123_0, vfcrange0123_2, (3 << 4) | 1);
+               __m256d vfcrange23_13 = _mm256_permute2f128_pd(vfcrange0123_1, vfcrange0123_3, (3 << 4) | 1);
+
+               __m256d vfcrange0 = _mm256_unpacklo_pd(vfcrange01_02, vfcrange01_13);
+               __m256d vfcrange1 = _mm256_unpackhi_pd(vfcrange01_02, vfcrange01_13);
+               __m256d vfcrange2 = _mm256_unpacklo_pd(vfcrange23_02, vfcrange23_13);
+               __m256d vfcrange3 = _mm256_unpackhi_pd(vfcrange23_02, vfcrange23_13);
+
+               __m256d vfcfreq = _mm256_set_pd(
+                       i + 3 < batch_size ? fcs[i + 3]->freq : 0.0,
+                       i + 2 < batch_size ? fcs[i + 2]->freq : 0.0,
+                       i + 1 < batch_size ? fcs[i + 1]->freq : 0.0,
+                       fcs[i]->freq
+               );
+
+               __m256d vfcreso_DB = _mm256_set_pd(
+                       i + 3 < batch_size ? fcs[i + 3]->reso_dB : 0.0,
+                       i + 2 < batch_size ? fcs[i + 2]->reso_dB : 0.0,
+                       i + 1 < batch_size ? fcs[i + 1]->reso_dB : 0.0,
+                       fcs[i]->reso_dB
+               );
+
+               __m256d vmask = _mm256_or_pd(
+                       _mm256_or_pd(_mm256_cmp_pd(vfcfreq, vfcrange0, _CMP_LT_OS), _mm256_cmp_pd(vfcfreq, vfcrange1, _CMP_GT_OS)),
+                       _mm256_or_pd(_mm256_cmp_pd(vfcreso_DB, vfcrange2, _CMP_LT_OS), _mm256_cmp_pd(vfcreso_DB, vfcrange3, _CMP_GT_OS))
+               );
+
+               int imask = _mm256_movemask_pd(vmask);
+
+               if (batch_size - i < 4)
+                       imask &= (1 << (batch_size - i)) - 1;
+
+               if (imask) {
+                       __m256d v1mmargin = _mm256_set1_pd(1.0 - ext_filter_margin);
+                       __m256d v1pmargin = _mm256_set1_pd(1.0 + ext_filter_margin);
+
+                       vfcrange0 = _mm256_mul_pd(vfcfreq, v1mmargin);
+                       vfcrange1 = _mm256_mul_pd(vfcfreq, v1pmargin);
+                       vfcrange2 = _mm256_mul_pd(vfcreso_DB, v1mmargin);
+                       vfcrange3 = _mm256_mul_pd(vfcreso_DB, v1pmargin);
+
+                       vfcrange01_02 = _mm256_unpacklo_pd(vfcrange0, vfcrange1);
+                       vfcrange01_13 = _mm256_unpackhi_pd(vfcrange0, vfcrange1);
+                       vfcrange23_02 = _mm256_unpacklo_pd(vfcrange2, vfcrange3);
+                       vfcrange23_13 = _mm256_unpackhi_pd(vfcrange2, vfcrange3);
+
+                       vfcrange0123_0 = _mm256_permute2f128_pd(vfcrange01_02, vfcrange23_02, (2 << 4) | 0);
+                       vfcrange0123_1 = _mm256_permute2f128_pd(vfcrange01_13, vfcrange23_13, (2 << 4) | 0);
+                       vfcrange0123_2 = _mm256_permute2f128_pd(vfcrange01_02, vfcrange23_02, (3 << 4) | 1);
+                       vfcrange0123_3 = _mm256_permute2f128_pd(vfcrange01_13, vfcrange23_13, (3 << 4) | 1);
+
+                       if (imask & 1)
+                               _mm256_storeu_pd(fcs[i]->range, vfcrange0123_0);
+
+                       if (imask & (1 << 1))
+                               _mm256_storeu_pd(fcs[i + 1]->range, vfcrange0123_1);
+
+                       if (imask & (1 << 2))
+                               _mm256_storeu_pd(fcs[i + 2]->range, vfcrange0123_2);
+
+                       if (imask & (1 << 3))
+                               _mm256_storeu_pd(fcs[i + 3]->range, vfcrange0123_3);
+
+                       __m256d vfcdiv_flt_rate = _mm256_set_pd(
+                               i + 3 < batch_size ? fcs[i + 3]->div_flt_rate : fcs[i]->div_flt_rate,
+                               i + 2 < batch_size ? fcs[i + 2]->div_flt_rate : fcs[i]->div_flt_rate,
+                               i + 1 < batch_size ? fcs[i + 1]->div_flt_rate : fcs[i]->div_flt_rate,
+                               fcs[i]->div_flt_rate
+                       );
+
+                       __m256d v0_5 = _mm256_set1_pd(0.5);
+                       __m256d v0_8 = _mm256_set1_pd(0.8);
+                       __m256d v1 = _mm256_set1_pd(1.0);
+                       __m256d v2 = _mm256_set1_pd(2.0);
+                       __m256d v5_6 = _mm256_set1_pd(5.6);
+
+                       __m256d vf = _mm256_mul_pd(_mm256_mul_pd(v2, vfcfreq), vfcdiv_flt_rate);
+                       __m256d vp = _mm256_sub_pd(v1, vf);
+
+                       FLOAT_T reso_db_cf_m = RESO_DB_CF_M(fcs[i]->reso_dB);
+
+                       __m256d vreso_db_cf_m = _mm256_set_pd(
+                               i + 3 < batch_size ? RESO_DB_CF_M(fcs[i + 3]->reso_dB) : reso_db_cf_m,
+                               i + 2 < batch_size ? RESO_DB_CF_M(fcs[i + 2]->reso_dB) : reso_db_cf_m,
+                               i + 1 < batch_size ? RESO_DB_CF_M(fcs[i + 1]->reso_dB) : reso_db_cf_m,
+                               reso_db_cf_m
+                       );
+
+                       __m256d vq = _mm256_mul_pd(v0_8, _mm256_sub_pd(v1, vreso_db_cf_m));
+                       __m256d vdc0 = MM256_FMA_PD(_mm256_mul_pd(v0_8, vf), vp, vf);
+#if USE_X86_EXT_INTRIN >= 9
+                       __m256d vdc1 = _mm256_fmsub_pd(vdc0, v2, v1);
+#else
+                       __m256d vdc1 = _mm256_sub_pd(_mm256_add_pd(vdc0, vdc0), v1);
+#endif
+                       __m256d vdc2 = _mm256_mul_pd(vq, MM256_FMA_PD(_mm256_mul_pd(v0_5, vp), _mm256_sub_pd(MM256_FMA_PD(_mm256_mul_pd(v5_6, vp), vp, v1), vp), v1));
+                       __m256d vdc3 = _mm256_setzero_pd();
+
+                       __m256d vdc01_02 = _mm256_unpacklo_pd(vdc0, vdc1);
+                       __m256d vdc01_13 = _mm256_unpackhi_pd(vdc0, vdc1);
+                       __m256d vdc23_02 = _mm256_unpacklo_pd(vdc2, vdc3);
+                       __m256d vdc23_13 = _mm256_unpackhi_pd(vdc2, vdc3);
+
+                       __m256d vdc0123_0 = _mm256_permute2f128_pd(vdc01_02, vdc23_02, (2 << 4) | 0);
+                       __m256d vdc0123_1 = _mm256_permute2f128_pd(vdc01_13, vdc23_13, (2 << 4) | 0);
+                       __m256d vdc0123_2 = _mm256_permute2f128_pd(vdc01_02, vdc23_02, (3 << 4) | 1);
+                       __m256d vdc0123_3 = _mm256_permute2f128_pd(vdc01_13, vdc23_13, (3 << 4) | 1);
+
+                       if (imask & 1)
+                               _mm256_storeu_pd(&fcs[i]->dc[0], vdc0123_0);
+
+                       if (imask & (1 << 1))
+                               _mm256_storeu_pd(&fcs[i + 1]->dc[0], vdc0123_1);
+
+                       if (imask & (1 << 2))
+                               _mm256_storeu_pd(&fcs[i + 2]->dc[0], vdc0123_2);
+
+                       if (imask & (1 << 3))
+                               _mm256_storeu_pd(&fcs[i + 3]->dc[0], vdc0123_3);
+               }
+       }
+}
+
+#elif (USE_X86_EXT_INTRIN >= 3) && defined(DATA_T_DOUBLE) && defined(FLOAT_T_DOUBLE)
 
 static void recalc_filter_LPF24_batch(int batch_size, FilterCoefficients **fcs)
 {