diff mbox series

i386: Auto vectorize usdot_prod, udot_prod with AVXVNNIINT16 instruction.

Message ID 20230714062413.2277485-1-haochen.jiang@intel.com
State New
Headers show
Series i386: Auto vectorize usdot_prod, udot_prod with AVXVNNIINT16 instruction. | expand

Commit Message

Jiang, Haochen July 14, 2023, 6:24 a.m. UTC
Hi all,

This patch aims to auto vectorize usdot_prod and udot_prod with newly
introduced AVX-VNNI-INT16.

Also I refined the redundant mode iterator in the patch.

Regtested on x86_64-pc-linux-gnu. Ok for trunk after AVX-VNNI-INT16 patch
checked in?

BRs,
Haochen

gcc/ChangeLog:

	* config/i386/sse.md (VI2_AVX2): Delete V32HI since we actually
	have the same iterator. Also renaming all the occurence to
	VI2_AVX2_AVX512BW.
	(usdot_prod<mode>): New define_expand.
	(udot_prod<mode>): Ditto.

gcc/testsuite/ChangeLog:

	* gcc.target/i386/vnniint16-auto-vectorize-1.c: New test.
	* gcc.target/i386/vnniint16-auto-vectorize-2.c: Ditto.
---
 gcc/config/i386/sse.md                        | 98 +++++++++++++------
 .../i386/vnniint16-auto-vectorize-1.c         | 28 ++++++
 .../i386/vnniint16-auto-vectorize-2.c         | 76 ++++++++++++++
 3 files changed, 172 insertions(+), 30 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c
 create mode 100644 gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c

Comments

Uros Bizjak July 14, 2023, 7:14 a.m. UTC | #1
On Fri, Jul 14, 2023 at 8:24 AM Haochen Jiang <haochen.jiang@intel.com> wrote:
>
> Hi all,
>
> This patch aims to auto vectorize usdot_prod and udot_prod with newly
> introduced AVX-VNNI-INT16.
>
> Also I refined the redundant mode iterator in the patch.
>
> Regtested on x86_64-pc-linux-gnu. Ok for trunk after AVX-VNNI-INT16 patch
> checked in?
>
> BRs,
> Haochen
>
> gcc/ChangeLog:
>
>         * config/i386/sse.md (VI2_AVX2): Delete V32HI since we actually
>         have the same iterator. Also renaming all the occurence to
>         VI2_AVX2_AVX512BW.
>         (usdot_prod<mode>): New define_expand.
>         (udot_prod<mode>): Ditto.
>
> gcc/testsuite/ChangeLog:
>
>         * gcc.target/i386/vnniint16-auto-vectorize-1.c: New test.
>         * gcc.target/i386/vnniint16-auto-vectorize-2.c: Ditto.

OK with two changes below.

Thanks,
Uros.

> ---
>  gcc/config/i386/sse.md                        | 98 +++++++++++++------
>  .../i386/vnniint16-auto-vectorize-1.c         | 28 ++++++
>  .../i386/vnniint16-auto-vectorize-2.c         | 76 ++++++++++++++
>  3 files changed, 172 insertions(+), 30 deletions(-)
>  create mode 100644 gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c
>  create mode 100644 gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c
>
> diff --git a/gcc/config/i386/sse.md b/gcc/config/i386/sse.md
> index 7471932b27e..98e7f9334bc 100644
> --- a/gcc/config/i386/sse.md
> +++ b/gcc/config/i386/sse.md
> @@ -545,6 +545,9 @@
>     V32HI (V16HI "TARGET_AVX512VL")])
>
>  (define_mode_iterator VI2_AVX2
> +  [(V16HI "TARGET_AVX2") V8HI])
> +
> +(define_mode_iterator VI2_AVX2_AVX512BW
>    [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI])
>
>  (define_mode_iterator VI2_AVX512F
> @@ -637,9 +640,6 @@
>     (V16HI "TARGET_AVX2") V8HI
>     (V8SI "TARGET_AVX2") V4SI])
>
> -(define_mode_iterator VI2_AVX2_AVX512BW
> -  [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI])
> -
>  (define_mode_iterator VI248_AVX512VL
>    [V32HI V16SI V8DI
>     (V16HI "TARGET_AVX512VL") (V8SI "TARGET_AVX512VL")
> @@ -15298,16 +15298,16 @@
>  })
>
>  (define_expand "mul<mode>3<mask_name>"
> -  [(set (match_operand:VI2_AVX2 0 "register_operand")
> -       (mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand")
> -                      (match_operand:VI2_AVX2 2 "vector_operand")))]
> +  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand")
> +       (mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand")
> +                      (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand")))]
>    "TARGET_SSE2 && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
>    "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);")
>
>  (define_insn "*mul<mode>3<mask_name>"
> -  [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>")
> -       (mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>")
> -                      (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))]
> +  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>")
> +       (mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>")
> +                      (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))]
>    "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2]))
>     && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
>    "@
> @@ -15320,28 +15320,28 @@
>     (set_attr "mode" "<sseinsnmode>")])
>
>  (define_expand "<s>mul<mode>3_highpart<mask_name>"
> -  [(set (match_operand:VI2_AVX2 0 "register_operand")
> -       (truncate:VI2_AVX2
> +  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand")
> +       (truncate:VI2_AVX2_AVX512BW
>           (lshiftrt:<ssedoublemode>
>             (mult:<ssedoublemode>
>               (any_extend:<ssedoublemode>
> -               (match_operand:VI2_AVX2 1 "vector_operand"))
> +               (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand"))
>               (any_extend:<ssedoublemode>
> -               (match_operand:VI2_AVX2 2 "vector_operand")))
> +               (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand")))
>             (const_int 16))))]
>    "TARGET_SSE2
>     && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
>    "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);")
>
>  (define_insn "*<s>mul<mode>3_highpart<mask_name>"
> -  [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>")
> -       (truncate:VI2_AVX2
> +  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>")
> +       (truncate:VI2_AVX2_AVX512BW
>           (lshiftrt:<ssedoublemode>
>             (mult:<ssedoublemode>
>               (any_extend:<ssedoublemode>
> -               (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>"))
> +               (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>"))
>               (any_extend:<ssedoublemode>
> -               (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))
> +               (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))
>             (const_int 16))))]
>    "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2]))
>     && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
> @@ -15591,8 +15591,8 @@
>  (define_insn "avx512bw_pmaddwd512<mode><mask_name>"
>    [(set (match_operand:<sseunpackmode> 0 "register_operand" "=v")
>            (unspec:<sseunpackmode>
> -            [(match_operand:VI2_AVX2 1 "register_operand" "v")
> -             (match_operand:VI2_AVX2 2 "nonimmediate_operand" "vm")]
> +            [(match_operand:VI2_AVX2_AVX512BW 1 "register_operand" "v")
> +             (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand" "vm")]
>               UNSPEC_PMADDWD512))]
>     "TARGET_AVX512BW && <mask_mode512bit_condition>"
>     "vpmaddwd\t{%2, %1, %0<mask_operand3>|%0<mask_operand3>, %1, %2}";
> @@ -21569,16 +21569,16 @@
>  })
>
>  (define_expand "smulhrs<mode>3"
> -  [(set (match_operand:VI2_AVX2 0 "register_operand")
> -       (truncate:VI2_AVX2
> +  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand")
> +       (truncate:VI2_AVX2_AVX512BW
>           (lshiftrt:<ssedoublemode>
>             (plus:<ssedoublemode>
>               (lshiftrt:<ssedoublemode>
>                 (mult:<ssedoublemode>
>                   (sign_extend:<ssedoublemode>
> -                   (match_operand:VI2_AVX2 1 "nonimmediate_operand"))
> +                   (match_operand:VI2_AVX2_AVX512BW 1 "nonimmediate_operand"))
>                   (sign_extend:<ssedoublemode>
> -                   (match_operand:VI2_AVX2 2 "nonimmediate_operand")))
> +                   (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand")))
>                 (const_int 14))
>               (match_dup 3))
>             (const_int 1))))]
> @@ -21589,18 +21589,18 @@
>  })
>
>  (define_insn "*<ssse3_avx2>_pmulhrsw<mode>3<mask_name>"
> -  [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>")
> -       (truncate:VI2_AVX2
> +  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>")
> +       (truncate:VI2_AVX2_AVX512BW
>           (lshiftrt:<ssedoublemode>
>             (plus:<ssedoublemode>
>               (lshiftrt:<ssedoublemode>
>                 (mult:<ssedoublemode>
>                   (sign_extend:<ssedoublemode>
> -                   (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>"))
> +                   (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>"))
>                   (sign_extend:<ssedoublemode>
> -                   (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))
> +                   (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))
>                 (const_int 14))
> -             (match_operand:VI2_AVX2 3 "const1_operand"))
> +             (match_operand:VI2_AVX2_AVX512BW 3 "const1_operand"))
>             (const_int 1))))]
>    "TARGET_SSSE3 && <mask_mode512bit_condition> && <mask_avx512bw_condition>
>     && !(MEM_P (operands[1]) && MEM_P (operands[2]))"
> @@ -22327,8 +22327,8 @@
>     (set_attr "mode" "<sseinsnmode>")])
>
>  (define_insn "<sse4_1_avx2>_packusdw<mask_name>"
> -  [(set (match_operand:VI2_AVX2 0 "register_operand" "=Yr,*x,<v_Yw>")
> -       (unspec:VI2_AVX2
> +  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=Yr,*x,<v_Yw>")
> +       (unspec:VI2_AVX2_AVX512BW
>           [(match_operand:<sseunpackmode> 1 "register_operand" "0,0,<v_Yw>")
>            (match_operand:<sseunpackmode> 2 "vector_operand" "YrBm,*xBm,<v_Yw>m")]
>            UNSPEC_US_TRUNCATE))]
> @@ -30340,6 +30340,44 @@
>     (UNSPEC_VPDPWSUD "wsud") (UNSPEC_VPDPWSUDS "wsuds")
>     (UNSPEC_VPDPWUUD "wuud") (UNSPEC_VPDPWUUDS "wuuds")])
>
> +(define_expand "usdot_prod<mode>"
> +  [(match_operand:<sseunpackmode> 0 "register_operand")
> +   (match_operand:VI2_AVX2 1 "register_operand")
> +   (match_operand:VI2_AVX2 2 "register_operand")
> +   (match_operand:<sseunpackmode> 3 "register_operand")]
> +  "TARGET_AVXVNNIINT16"
> +{
> +  operands[1] = lowpart_subreg (<sseunpackmode>mode,
> +                                force_reg (<MODE>mode, operands[1]),
> +                                <MODE>mode);
> +  operands[2] = lowpart_subreg (<sseunpackmode>mode,
> +                                force_reg (<MODE>mode, operands[2]),
> +                                <MODE>mode);
> +  emit_insn (gen_rtx_SET (operands[0], operands[3]));

You don't have to emit a move, the register allocator will do that for you.

> +  emit_insn (gen_vpdpwusd_<SDOT_VPDP_SUF> (operands[0], operands[3],
> +                                          operands[1], operands[2]));
> +  DONE;
> +})
> +
> +(define_expand "udot_prod<mode>"
> +  [(match_operand:<sseunpackmode> 0 "register_operand")
> +   (match_operand:VI2_AVX2 1 "register_operand")
> +   (match_operand:VI2_AVX2 2 "register_operand")
> +   (match_operand:<sseunpackmode> 3 "register_operand")]
> +  "TARGET_AVXVNNIINT16"
> +{
> +  operands[1] = lowpart_subreg (<sseunpackmode>mode,
> +                                force_reg (<MODE>mode, operands[1]),
> +                                <MODE>mode);
> +  operands[2] = lowpart_subreg (<sseunpackmode>mode,
> +                                force_reg (<MODE>mode, operands[2]),
> +                                <MODE>mode);
> +  emit_insn (gen_rtx_SET (operands[0], operands[3]));

Also here, the above is not needed.

> +  emit_insn (gen_vpdpwuud_<SDOT_VPDP_SUF> (operands[0], operands[3],
> +                                          operands[1], operands[2]));
> +  DONE;
> +})
> +
>  (define_insn "vpdp<vpdpwprodtype>_<mode>"
>    [(set (match_operand:VI4_AVX 0 "register_operand" "=x")
>         (unspec:VI4_AVX
> diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c
> new file mode 100644
> index 00000000000..73f0d3296aa
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c
> @@ -0,0 +1,28 @@
> +/* { dg-do compile } */
> +/* { dg-options "-mavxvnniint16 -O2" } */
> +/* { dg-final { scan-assembler "vpdpwusd\t" } } */
> +/* { dg-final { scan-assembler "vpdpwuud\t" } } */
> +
> +int __attribute__((noinline, noclone, optimize("tree-vectorize")))
> +usdot_prod_hi (unsigned short * restrict a, short * restrict b,
> +              int c, int n)
> +{
> +  int i;
> +  for (i = 0; i < n; i++)
> +    {
> +      c += ((int) a[i] * (int) b[i]);
> +    }
> +  return c;
> +}
> +
> +int __attribute__((noinline, noclone, optimize("tree-vectorize")))
> +udot_prod_hi (unsigned short * restrict a, unsigned short *restrict b,
> +             int c, int n)
> +{
> +  int i;
> +  for (i = 0; i < n; i++)
> +    {
> +      c += ((int) a[i] * (int) b[i]);
> +    }
> +  return c;
> +}
> diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c
> new file mode 100644
> index 00000000000..90dc0eade7e
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c
> @@ -0,0 +1,76 @@
> +/* { dg-do run } */
> +/* { dg-options "-O2 -mavxvnniint16" } */
> +/* { dg-require-effective-target avxvnniint16 } */
> +
> +#define AVXVNNIINT16
> +#ifndef CHECK
> +#define CHECK "avx-check.h"
> +#endif
> +
> +#ifndef TEST
> +#define TEST avx_test
> +#endif
> +
> +#include CHECK
> +#include "vnniint16-auto-vectorize-1.c"
> +
> +#define N 256
> +
> +short a_i16[N];
> +unsigned short b_u16[N], c_u16[N], d_u16[N];
> +int i16_exp, i16_ref;
> +
> +int __attribute__((noinline, noclone, optimize("no-tree-vectorize")))
> +udot_prod_hi_scalar (unsigned short * restrict a, unsigned short * restrict b,
> +                    int c, int n)
> +{
> +  int i;
> +  for (i = 0; i < n; i++)
> +    {
> +      c += ((int) a[i] * (int) b[i]);
> +    }
> +  return c;
> +}
> +
> +int __attribute__((noinline, noclone, optimize("no-tree-vectorize")))
> +usdot_prod_hi_scalar (unsigned short * restrict a, short *restrict b,
> +                     int c, int n)
> +{
> +  int i;
> +  for (i = 0; i < n; i++)
> +    {
> +      c += ((int) a[i] * (int) b[i]);
> +    }
> +  return c;
> +}
> +
> +void init ()
> +{
> +  int i;
> +
> +  i16_exp = i16_ref = 65535;
> +
> +  for (i = 0; i < N; i++)
> +    {
> +      a_i16[i] = -i + 2;
> +      b_u16[i] = i * 2;
> +      c_u16[i] = i * 3;
> +      d_u16[i] = i * 4;
> +    }
> +}
> +
> +void
> +TEST (void)
> +{
> +  init ();
> +  i16_exp = usdot_prod_hi (a_i16, b_u16, i16_exp, N);
> +  i16_ref = usdot_prod_hi_scalar (a_i16, b_u16, i16_ref, N);
> +  if (i16_exp != i16_ref)
> +    abort ();
> +
> +  init ();
> +  i16_exp = udot_prod_hi (c_u16, d_u16, i16_exp, N);
> +  i16_ref = udot_prod_hi_scalar (c_u16, d_u16, i16_ref, N);
> +  if (i16_exp != i16_ref)
> +    abort ();
> +}
> --
> 2.31.1
>
diff mbox series

Patch

diff --git a/gcc/config/i386/sse.md b/gcc/config/i386/sse.md
index 7471932b27e..98e7f9334bc 100644
--- a/gcc/config/i386/sse.md
+++ b/gcc/config/i386/sse.md
@@ -545,6 +545,9 @@ 
    V32HI (V16HI "TARGET_AVX512VL")])
 
 (define_mode_iterator VI2_AVX2
+  [(V16HI "TARGET_AVX2") V8HI])
+
+(define_mode_iterator VI2_AVX2_AVX512BW
   [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI])
 
 (define_mode_iterator VI2_AVX512F
@@ -637,9 +640,6 @@ 
    (V16HI "TARGET_AVX2") V8HI
    (V8SI "TARGET_AVX2") V4SI])
 
-(define_mode_iterator VI2_AVX2_AVX512BW
-  [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI])
-
 (define_mode_iterator VI248_AVX512VL
   [V32HI V16SI V8DI
    (V16HI "TARGET_AVX512VL") (V8SI "TARGET_AVX512VL")
@@ -15298,16 +15298,16 @@ 
 })
 
 (define_expand "mul<mode>3<mask_name>"
-  [(set (match_operand:VI2_AVX2 0 "register_operand")
-	(mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand")
-		       (match_operand:VI2_AVX2 2 "vector_operand")))]
+  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand")
+	(mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand")
+		       (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand")))]
   "TARGET_SSE2 && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
   "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);")
 
 (define_insn "*mul<mode>3<mask_name>"
-  [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>")
-	(mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>")
-		       (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))]
+  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>")
+	(mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>")
+		       (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))]
   "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2]))
    && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
   "@
@@ -15320,28 +15320,28 @@ 
    (set_attr "mode" "<sseinsnmode>")])
 
 (define_expand "<s>mul<mode>3_highpart<mask_name>"
-  [(set (match_operand:VI2_AVX2 0 "register_operand")
-	(truncate:VI2_AVX2
+  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand")
+	(truncate:VI2_AVX2_AVX512BW
 	  (lshiftrt:<ssedoublemode>
 	    (mult:<ssedoublemode>
 	      (any_extend:<ssedoublemode>
-		(match_operand:VI2_AVX2 1 "vector_operand"))
+		(match_operand:VI2_AVX2_AVX512BW 1 "vector_operand"))
 	      (any_extend:<ssedoublemode>
-		(match_operand:VI2_AVX2 2 "vector_operand")))
+		(match_operand:VI2_AVX2_AVX512BW 2 "vector_operand")))
 	    (const_int 16))))]
   "TARGET_SSE2
    && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
   "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);")
 
 (define_insn "*<s>mul<mode>3_highpart<mask_name>"
-  [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>")
-	(truncate:VI2_AVX2
+  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>")
+	(truncate:VI2_AVX2_AVX512BW
 	  (lshiftrt:<ssedoublemode>
 	    (mult:<ssedoublemode>
 	      (any_extend:<ssedoublemode>
-		(match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>"))
+		(match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>"))
 	      (any_extend:<ssedoublemode>
-		(match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))
+		(match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))
 	    (const_int 16))))]
   "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2]))
    && <mask_mode512bit_condition> && <mask_avx512bw_condition>"
@@ -15591,8 +15591,8 @@ 
 (define_insn "avx512bw_pmaddwd512<mode><mask_name>"
   [(set (match_operand:<sseunpackmode> 0 "register_operand" "=v")
           (unspec:<sseunpackmode>
-            [(match_operand:VI2_AVX2 1 "register_operand" "v")
-             (match_operand:VI2_AVX2 2 "nonimmediate_operand" "vm")]
+            [(match_operand:VI2_AVX2_AVX512BW 1 "register_operand" "v")
+             (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand" "vm")]
              UNSPEC_PMADDWD512))]
    "TARGET_AVX512BW && <mask_mode512bit_condition>"
    "vpmaddwd\t{%2, %1, %0<mask_operand3>|%0<mask_operand3>, %1, %2}";
@@ -21569,16 +21569,16 @@ 
 })
 
 (define_expand "smulhrs<mode>3"
-  [(set (match_operand:VI2_AVX2 0 "register_operand")
-	(truncate:VI2_AVX2
+  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand")
+	(truncate:VI2_AVX2_AVX512BW
 	  (lshiftrt:<ssedoublemode>
 	    (plus:<ssedoublemode>
 	      (lshiftrt:<ssedoublemode>
 		(mult:<ssedoublemode>
 		  (sign_extend:<ssedoublemode>
-		    (match_operand:VI2_AVX2 1 "nonimmediate_operand"))
+		    (match_operand:VI2_AVX2_AVX512BW 1 "nonimmediate_operand"))
 		  (sign_extend:<ssedoublemode>
-		    (match_operand:VI2_AVX2 2 "nonimmediate_operand")))
+		    (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand")))
 		(const_int 14))
 	      (match_dup 3))
 	    (const_int 1))))]
@@ -21589,18 +21589,18 @@ 
 })
 
 (define_insn "*<ssse3_avx2>_pmulhrsw<mode>3<mask_name>"
-  [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>")
-	(truncate:VI2_AVX2
+  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>")
+	(truncate:VI2_AVX2_AVX512BW
 	  (lshiftrt:<ssedoublemode>
 	    (plus:<ssedoublemode>
 	      (lshiftrt:<ssedoublemode>
 		(mult:<ssedoublemode>
 		  (sign_extend:<ssedoublemode>
-		    (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>"))
+		    (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>"))
 		  (sign_extend:<ssedoublemode>
-		    (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))
+		    (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))
 		(const_int 14))
-	      (match_operand:VI2_AVX2 3 "const1_operand"))
+	      (match_operand:VI2_AVX2_AVX512BW 3 "const1_operand"))
 	    (const_int 1))))]
   "TARGET_SSSE3 && <mask_mode512bit_condition> && <mask_avx512bw_condition>
    && !(MEM_P (operands[1]) && MEM_P (operands[2]))"
@@ -22327,8 +22327,8 @@ 
    (set_attr "mode" "<sseinsnmode>")])
 
 (define_insn "<sse4_1_avx2>_packusdw<mask_name>"
-  [(set (match_operand:VI2_AVX2 0 "register_operand" "=Yr,*x,<v_Yw>")
-	(unspec:VI2_AVX2
+  [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=Yr,*x,<v_Yw>")
+	(unspec:VI2_AVX2_AVX512BW
 	  [(match_operand:<sseunpackmode> 1 "register_operand" "0,0,<v_Yw>")
 	   (match_operand:<sseunpackmode> 2 "vector_operand" "YrBm,*xBm,<v_Yw>m")]
 	   UNSPEC_US_TRUNCATE))]
@@ -30340,6 +30340,44 @@ 
    (UNSPEC_VPDPWSUD "wsud") (UNSPEC_VPDPWSUDS "wsuds")
    (UNSPEC_VPDPWUUD "wuud") (UNSPEC_VPDPWUUDS "wuuds")])
 
+(define_expand "usdot_prod<mode>"
+  [(match_operand:<sseunpackmode> 0 "register_operand")
+   (match_operand:VI2_AVX2 1 "register_operand")
+   (match_operand:VI2_AVX2 2 "register_operand")
+   (match_operand:<sseunpackmode> 3 "register_operand")]
+  "TARGET_AVXVNNIINT16"
+{
+  operands[1] = lowpart_subreg (<sseunpackmode>mode,
+                                force_reg (<MODE>mode, operands[1]),
+                                <MODE>mode);
+  operands[2] = lowpart_subreg (<sseunpackmode>mode,
+                                force_reg (<MODE>mode, operands[2]),
+                                <MODE>mode);
+  emit_insn (gen_rtx_SET (operands[0], operands[3]));
+  emit_insn (gen_vpdpwusd_<SDOT_VPDP_SUF> (operands[0], operands[3],
+					   operands[1], operands[2]));
+  DONE;
+})
+
+(define_expand "udot_prod<mode>"
+  [(match_operand:<sseunpackmode> 0 "register_operand")
+   (match_operand:VI2_AVX2 1 "register_operand")
+   (match_operand:VI2_AVX2 2 "register_operand")
+   (match_operand:<sseunpackmode> 3 "register_operand")]
+  "TARGET_AVXVNNIINT16"
+{
+  operands[1] = lowpart_subreg (<sseunpackmode>mode,
+                                force_reg (<MODE>mode, operands[1]),
+                                <MODE>mode);
+  operands[2] = lowpart_subreg (<sseunpackmode>mode,
+                                force_reg (<MODE>mode, operands[2]),
+                                <MODE>mode);
+  emit_insn (gen_rtx_SET (operands[0], operands[3]));
+  emit_insn (gen_vpdpwuud_<SDOT_VPDP_SUF> (operands[0], operands[3],
+					   operands[1], operands[2]));
+  DONE;
+})
+
 (define_insn "vpdp<vpdpwprodtype>_<mode>"
   [(set (match_operand:VI4_AVX 0 "register_operand" "=x")
 	(unspec:VI4_AVX
diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c
new file mode 100644
index 00000000000..73f0d3296aa
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c
@@ -0,0 +1,28 @@ 
+/* { dg-do compile } */                                     
+/* { dg-options "-mavxvnniint16 -O2" } */
+/* { dg-final { scan-assembler "vpdpwusd\t" } } */
+/* { dg-final { scan-assembler "vpdpwuud\t" } } */
+
+int __attribute__((noinline, noclone, optimize("tree-vectorize")))
+usdot_prod_hi (unsigned short * restrict a, short * restrict b,
+	       int c, int n)
+{
+  int i;
+  for (i = 0; i < n; i++)
+    {
+      c += ((int) a[i] * (int) b[i]);
+    }
+  return c;
+}
+
+int __attribute__((noinline, noclone, optimize("tree-vectorize")))
+udot_prod_hi (unsigned short * restrict a, unsigned short *restrict b,
+	      int c, int n)
+{
+  int i;
+  for (i = 0; i < n; i++)
+    {
+      c += ((int) a[i] * (int) b[i]);
+    }
+  return c;
+}
diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c
new file mode 100644
index 00000000000..90dc0eade7e
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c
@@ -0,0 +1,76 @@ 
+/* { dg-do run } */
+/* { dg-options "-O2 -mavxvnniint16" } */
+/* { dg-require-effective-target avxvnniint16 } */
+
+#define AVXVNNIINT16
+#ifndef CHECK
+#define CHECK "avx-check.h"
+#endif
+
+#ifndef TEST
+#define TEST avx_test
+#endif
+
+#include CHECK
+#include "vnniint16-auto-vectorize-1.c"
+
+#define N 256
+
+short a_i16[N];
+unsigned short b_u16[N], c_u16[N], d_u16[N];
+int i16_exp, i16_ref;
+
+int __attribute__((noinline, noclone, optimize("no-tree-vectorize")))
+udot_prod_hi_scalar (unsigned short * restrict a, unsigned short * restrict b,
+		     int c, int n)
+{
+  int i;
+  for (i = 0; i < n; i++)
+    {
+      c += ((int) a[i] * (int) b[i]);
+    }
+  return c;
+}
+
+int __attribute__((noinline, noclone, optimize("no-tree-vectorize")))
+usdot_prod_hi_scalar (unsigned short * restrict a, short *restrict b,
+		      int c, int n)
+{
+  int i;
+  for (i = 0; i < n; i++)
+    {
+      c += ((int) a[i] * (int) b[i]);
+    }
+  return c;
+}
+
+void init ()
+{
+  int i;
+
+  i16_exp = i16_ref = 65535;
+
+  for (i = 0; i < N; i++)
+    {
+      a_i16[i] = -i + 2;
+      b_u16[i] = i * 2;
+      c_u16[i] = i * 3;
+      d_u16[i] = i * 4;
+    }
+}
+
+void
+TEST (void)
+{
+  init ();
+  i16_exp = usdot_prod_hi (a_i16, b_u16, i16_exp, N);
+  i16_ref = usdot_prod_hi_scalar (a_i16, b_u16, i16_ref, N);
+  if (i16_exp != i16_ref)
+    abort ();
+
+  init ();
+  i16_exp = udot_prod_hi (c_u16, d_u16, i16_exp, N);
+  i16_ref = udot_prod_hi_scalar (c_u16, d_u16, i16_ref, N);
+  if (i16_exp != i16_ref)
+    abort ();
+}