diff mbox series

[bpf-next,v3,03/16] bpf: verifier support JMP32

Message ID 1548375028-8308-4-git-send-email-jiong.wang@netronome.com
State Changes Requested
Delegated to: BPF Maintainers
Headers show
Series bpf: propose new jmp32 instructions | expand

Commit Message

Jiong Wang Jan. 25, 2019, 12:10 a.m. UTC
This patch teach verifier about the new BPF_JMP32 instruction class.
Verifier need to treat it similar as the existing BPF_JMP class.
A BPF_JMP32 insn needs to go through all checks that have been done on
BPF_JMP.

Also, verifier is doing runtime optimizations based on the extra info
conditional jump instruction could offer, especially when the comparison is
between constant and register that the value range of the register could be
improved based on the comparison results. These code are updated
accordingly.

Acked-by: Jakub Kicinski <jakub.kicinski@netronome.com>
Signed-off-by: Jiong Wang <jiong.wang@netronome.com>
---
 kernel/bpf/core.c     |   3 +-
 kernel/bpf/verifier.c | 200 ++++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 170 insertions(+), 33 deletions(-)

Comments

Daniel Borkmann Jan. 26, 2019, 12:36 a.m. UTC | #1
On 01/25/2019 01:10 AM, Jiong Wang wrote:
> This patch teach verifier about the new BPF_JMP32 instruction class.
> Verifier need to treat it similar as the existing BPF_JMP class.
> A BPF_JMP32 insn needs to go through all checks that have been done on
> BPF_JMP.
> 
> Also, verifier is doing runtime optimizations based on the extra info
> conditional jump instruction could offer, especially when the comparison is
> between constant and register that the value range of the register could be
> improved based on the comparison results. These code are updated
> accordingly.
> 
> Acked-by: Jakub Kicinski <jakub.kicinski@netronome.com>
> Signed-off-by: Jiong Wang <jiong.wang@netronome.com>

Series looks good to me, but if I spot this correctly one thing that has
not been addressed here is proper rebase on top of Jakub's dead code
removal, e.g. in opt_hard_wire_dead_code_branches() where we check in
insn_is_cond_jump() for jump opcodes it still only tests for BPF_JMP
class whereas BPF_JMP32 handling needs to be taught here as well.

Thanks,
Daniel
Jiong Wang Jan. 26, 2019, 11:02 a.m. UTC | #2
Daniel Borkmann writes:

> On 01/25/2019 01:10 AM, Jiong Wang wrote:
>> This patch teach verifier about the new BPF_JMP32 instruction class.
>> Verifier need to treat it similar as the existing BPF_JMP class.
>> A BPF_JMP32 insn needs to go through all checks that have been done on
>> BPF_JMP.
>> 
>> Also, verifier is doing runtime optimizations based on the extra info
>> conditional jump instruction could offer, especially when the comparison is
>> between constant and register that the value range of the register could be
>> improved based on the comparison results. These code are updated
>> accordingly.
>> 
>> Acked-by: Jakub Kicinski <jakub.kicinski@netronome.com>
>> Signed-off-by: Jiong Wang <jiong.wang@netronome.com>
>
> Series looks good to me, but if I spot this correctly one thing that has
> not been addressed here is proper rebase on top of Jakub's dead code
> removal, e.g. in opt_hard_wire_dead_code_branches() where we check in
> insn_is_cond_jump() for jump opcodes it still only tests for BPF_JMP
> class whereas BPF_JMP32 handling needs to be taught here as well.

Thanks for catching this. Yes, insn_is_cond_jump() should be updated for
JMP32 as well. JMP32 is guaranteed to be with condition jump operation only
otherwise the earlier do_check will complain use of reserved encoding bits.

I am going to teach insn_is_cond_jump to return true for JMP32. And search
the commits, there is another similar new helper function in nfp driver
jit.

Will fix both places, and re-spin v4.

Thanks.

Regards,
Jiong
diff mbox series

Patch

diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index 2a81b8a..1e443ba 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -362,7 +362,8 @@  static int bpf_adj_branches(struct bpf_prog *prog, u32 pos, s32 end_old,
 			insn = prog->insnsi + end_old;
 		}
 		code = insn->code;
-		if (BPF_CLASS(code) != BPF_JMP ||
+		if ((BPF_CLASS(code) != BPF_JMP &&
+		     BPF_CLASS(code) != BPF_JMP32) ||
 		    BPF_OP(code) == BPF_EXIT)
 			continue;
 		/* Adjust offset of jmps if we cross patch boundaries. */
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index eae6cb1..2dbd757 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -1095,7 +1095,7 @@  static int check_subprogs(struct bpf_verifier_env *env)
 	for (i = 0; i < insn_cnt; i++) {
 		u8 code = insn[i].code;
 
-		if (BPF_CLASS(code) != BPF_JMP)
+		if (BPF_CLASS(code) != BPF_JMP && BPF_CLASS(code) != BPF_JMP32)
 			goto next;
 		if (BPF_OP(code) == BPF_EXIT || BPF_OP(code) == BPF_CALL)
 			goto next;
@@ -4031,14 +4031,49 @@  static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
  *  0 - branch will not be taken and fall-through to next insn
  * -1 - unknown. Example: "if (reg < 5)" is unknown when register value range [0,10]
  */
-static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
+static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode,
+			   bool is_jmp32)
 {
+	struct bpf_reg_state reg_lo;
 	s64 sval;
 
 	if (__is_pointer_value(false, reg))
 		return -1;
 
-	sval = (s64)val;
+	if (is_jmp32) {
+		reg_lo = *reg;
+		reg = &reg_lo;
+		/* For JMP32, only low 32 bits are compared, coerce_reg_to_size
+		 * could truncate high bits and update umin/umax according to
+		 * information of low bits.
+		 */
+		coerce_reg_to_size(reg, 4);
+		/* smin/smax need special handling. For example, after coerce,
+		 * if smin_value is 0x00000000ffffffffLL, the value is -1 when
+		 * used as operand to JMP32. It is a negative number from s32's
+		 * point of view, while it is a positive number when seen as
+		 * s64. The smin/smax are kept as s64, therefore, when used with
+		 * JMP32, they need to be transformed into s32, then sign
+		 * extended back to s64.
+		 *
+		 * Also, smin/smax were copied from umin/umax. If umin/umax has
+		 * different sign bit, then min/max relationship doesn't
+		 * maintain after casting into s32, for this case, set smin/smax
+		 * to safest range.
+		 */
+		if ((reg->umax_value ^ reg->umin_value) &
+		    (1ULL << 31)) {
+			reg->smin_value = S32_MIN;
+			reg->smax_value = S32_MAX;
+		}
+		reg->smin_value = (s64)(s32)reg->smin_value;
+		reg->smax_value = (s64)(s32)reg->smax_value;
+
+		val = (u32)val;
+		sval = (s64)(s32)val;
+	} else {
+		sval = (s64)val;
+	}
 
 	switch (opcode) {
 	case BPF_JEQ:
@@ -4108,6 +4143,29 @@  static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
 	return -1;
 }
 
+/* Generate min value of the high 32-bit from TNUM info. */
+static u64 gen_hi_min(struct tnum var)
+{
+	return var.value & ~0xffffffffULL;
+}
+
+/* Generate max value of the high 32-bit from TNUM info. */
+static u64 gen_hi_max(struct tnum var)
+{
+	return (var.value | var.mask) & ~0xffffffffULL;
+}
+
+/* Return true if VAL is compared with a s64 sign extended from s32, and they
+ * are with the same signedness.
+ */
+static bool cmp_val_with_extended_s64(s64 sval, struct bpf_reg_state *reg)
+{
+	return ((s32)sval >= 0 &&
+		reg->smin_value >= 0 && reg->smax_value <= S32_MAX) ||
+	       ((s32)sval < 0 &&
+		reg->smax_value <= 0 && reg->smin_value >= S32_MIN);
+}
+
 /* Adjusts the register min/max values in the case that the dst_reg is the
  * variable register that we are working on, and src_reg is a constant or we're
  * simply doing a BPF_K check.
@@ -4115,7 +4173,7 @@  static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
  */
 static void reg_set_min_max(struct bpf_reg_state *true_reg,
 			    struct bpf_reg_state *false_reg, u64 val,
-			    u8 opcode)
+			    u8 opcode, bool is_jmp32)
 {
 	s64 sval;
 
@@ -4128,7 +4186,8 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
 	if (__is_pointer_value(false, false_reg))
 		return;
 
-	sval = (s64)val;
+	val = is_jmp32 ? (u32)val : val;
+	sval = is_jmp32 ? (s64)(s32)val : (s64)val;
 
 	switch (opcode) {
 	case BPF_JEQ:
@@ -4141,7 +4200,15 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
 		 * if it is true we know the value for sure. Likewise for
 		 * BPF_JNE.
 		 */
-		__mark_reg_known(reg, val);
+		if (is_jmp32) {
+			u64 old_v = reg->var_off.value;
+			u64 hi_mask = ~0xffffffffULL;
+
+			reg->var_off.value = (old_v & hi_mask) | val;
+			reg->var_off.mask &= hi_mask;
+		} else {
+			__mark_reg_known(reg, val);
+		}
 		break;
 	}
 	case BPF_JSET:
@@ -4157,6 +4224,10 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
 		u64 false_umax = opcode == BPF_JGT ? val    : val - 1;
 		u64 true_umin = opcode == BPF_JGT ? val + 1 : val;
 
+		if (is_jmp32) {
+			false_umax += gen_hi_max(false_reg->var_off);
+			true_umin += gen_hi_min(true_reg->var_off);
+		}
 		false_reg->umax_value = min(false_reg->umax_value, false_umax);
 		true_reg->umin_value = max(true_reg->umin_value, true_umin);
 		break;
@@ -4167,6 +4238,11 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
 		s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
 		s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
 
+		/* If the full s64 was not sign-extended from s32 then don't
+		 * deduct further info.
+		 */
+		if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+			break;
 		false_reg->smax_value = min(false_reg->smax_value, false_smax);
 		true_reg->smin_value = max(true_reg->smin_value, true_smin);
 		break;
@@ -4177,6 +4253,10 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
 		u64 false_umin = opcode == BPF_JLT ? val    : val + 1;
 		u64 true_umax = opcode == BPF_JLT ? val - 1 : val;
 
+		if (is_jmp32) {
+			false_umin += gen_hi_min(false_reg->var_off);
+			true_umax += gen_hi_max(true_reg->var_off);
+		}
 		false_reg->umin_value = max(false_reg->umin_value, false_umin);
 		true_reg->umax_value = min(true_reg->umax_value, true_umax);
 		break;
@@ -4187,6 +4267,8 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
 		s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
 		s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
 
+		if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+			break;
 		false_reg->smin_value = max(false_reg->smin_value, false_smin);
 		true_reg->smax_value = min(true_reg->smax_value, true_smax);
 		break;
@@ -4213,14 +4295,15 @@  static void reg_set_min_max(struct bpf_reg_state *true_reg,
  */
 static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 				struct bpf_reg_state *false_reg, u64 val,
-				u8 opcode)
+				u8 opcode, bool is_jmp32)
 {
 	s64 sval;
 
 	if (__is_pointer_value(false, false_reg))
 		return;
 
-	sval = (s64)val;
+	val = is_jmp32 ? (u32)val : val;
+	sval = is_jmp32 ? (s64)(s32)val : (s64)val;
 
 	switch (opcode) {
 	case BPF_JEQ:
@@ -4229,7 +4312,15 @@  static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 		struct bpf_reg_state *reg =
 			opcode == BPF_JEQ ? true_reg : false_reg;
 
-		__mark_reg_known(reg, val);
+		if (is_jmp32) {
+			u64 old_v = reg->var_off.value;
+			u64 hi_mask = ~0xffffffffULL;
+
+			reg->var_off.value = (old_v & hi_mask) | val;
+			reg->var_off.mask &= hi_mask;
+		} else {
+			__mark_reg_known(reg, val);
+		}
 		break;
 	}
 	case BPF_JSET:
@@ -4245,6 +4336,10 @@  static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 		u64 false_umin = opcode == BPF_JGT ? val    : val + 1;
 		u64 true_umax = opcode == BPF_JGT ? val - 1 : val;
 
+		if (is_jmp32) {
+			false_umin += gen_hi_min(false_reg->var_off);
+			true_umax += gen_hi_max(true_reg->var_off);
+		}
 		false_reg->umin_value = max(false_reg->umin_value, false_umin);
 		true_reg->umax_value = min(true_reg->umax_value, true_umax);
 		break;
@@ -4255,6 +4350,8 @@  static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 		s64 false_smin = opcode == BPF_JSGT ? sval    : sval + 1;
 		s64 true_smax = opcode == BPF_JSGT ? sval - 1 : sval;
 
+		if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+			break;
 		false_reg->smin_value = max(false_reg->smin_value, false_smin);
 		true_reg->smax_value = min(true_reg->smax_value, true_smax);
 		break;
@@ -4265,6 +4362,10 @@  static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 		u64 false_umax = opcode == BPF_JLT ? val    : val - 1;
 		u64 true_umin = opcode == BPF_JLT ? val + 1 : val;
 
+		if (is_jmp32) {
+			false_umax += gen_hi_max(false_reg->var_off);
+			true_umin += gen_hi_min(true_reg->var_off);
+		}
 		false_reg->umax_value = min(false_reg->umax_value, false_umax);
 		true_reg->umin_value = max(true_reg->umin_value, true_umin);
 		break;
@@ -4275,6 +4376,8 @@  static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 		s64 false_smax = opcode == BPF_JSLT ? sval    : sval - 1;
 		s64 true_smin = opcode == BPF_JSLT ? sval + 1 : sval;
 
+		if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+			break;
 		false_reg->smax_value = min(false_reg->smax_value, false_smax);
 		true_reg->smin_value = max(true_reg->smin_value, true_smin);
 		break;
@@ -4416,6 +4519,10 @@  static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 	if (BPF_SRC(insn->code) != BPF_X)
 		return false;
 
+	/* Pointers are always 64-bit. */
+	if (BPF_CLASS(insn->code) == BPF_JMP32)
+		return false;
+
 	switch (BPF_OP(insn->code)) {
 	case BPF_JGT:
 		if ((dst_reg->type == PTR_TO_PACKET &&
@@ -4508,16 +4615,18 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs;
 	struct bpf_reg_state *dst_reg, *other_branch_regs;
 	u8 opcode = BPF_OP(insn->code);
+	bool is_jmp32;
 	int err;
 
-	if (opcode > BPF_JSLE) {
-		verbose(env, "invalid BPF_JMP opcode %x\n", opcode);
+	/* Only conditional jumps are expected to reach here. */
+	if (opcode == BPF_JA || opcode > BPF_JSLE) {
+		verbose(env, "invalid BPF_JMP/JMP32 opcode %x\n", opcode);
 		return -EINVAL;
 	}
 
 	if (BPF_SRC(insn->code) == BPF_X) {
 		if (insn->imm != 0) {
-			verbose(env, "BPF_JMP uses reserved fields\n");
+			verbose(env, "BPF_JMP/JMP32 uses reserved fields\n");
 			return -EINVAL;
 		}
 
@@ -4533,7 +4642,7 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 		}
 	} else {
 		if (insn->src_reg != BPF_REG_0) {
-			verbose(env, "BPF_JMP uses reserved fields\n");
+			verbose(env, "BPF_JMP/JMP32 uses reserved fields\n");
 			return -EINVAL;
 		}
 	}
@@ -4544,9 +4653,11 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 		return err;
 
 	dst_reg = &regs[insn->dst_reg];
+	is_jmp32 = BPF_CLASS(insn->code) == BPF_JMP32;
 
 	if (BPF_SRC(insn->code) == BPF_K) {
-		int pred = is_branch_taken(dst_reg, insn->imm, opcode);
+		int pred = is_branch_taken(dst_reg, insn->imm, opcode,
+					   is_jmp32);
 
 		if (pred == 1) {
 			 /* only follow the goto, ignore fall-through */
@@ -4574,30 +4685,51 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	 * comparable.
 	 */
 	if (BPF_SRC(insn->code) == BPF_X) {
+		struct bpf_reg_state *src_reg = &regs[insn->src_reg];
+		struct bpf_reg_state lo_reg0 = *dst_reg;
+		struct bpf_reg_state lo_reg1 = *src_reg;
+		struct bpf_reg_state *src_lo, *dst_lo;
+
+		dst_lo = &lo_reg0;
+		src_lo = &lo_reg1;
+		coerce_reg_to_size(dst_lo, 4);
+		coerce_reg_to_size(src_lo, 4);
+
 		if (dst_reg->type == SCALAR_VALUE &&
-		    regs[insn->src_reg].type == SCALAR_VALUE) {
-			if (tnum_is_const(regs[insn->src_reg].var_off))
+		    src_reg->type == SCALAR_VALUE) {
+			if (tnum_is_const(src_reg->var_off) ||
+			    (is_jmp32 && tnum_is_const(src_lo->var_off)))
 				reg_set_min_max(&other_branch_regs[insn->dst_reg],
-						dst_reg, regs[insn->src_reg].var_off.value,
-						opcode);
-			else if (tnum_is_const(dst_reg->var_off))
+						dst_reg,
+						is_jmp32
+						? src_lo->var_off.value
+						: src_reg->var_off.value,
+						opcode, is_jmp32);
+			else if (tnum_is_const(dst_reg->var_off) ||
+				 (is_jmp32 && tnum_is_const(dst_lo->var_off)))
 				reg_set_min_max_inv(&other_branch_regs[insn->src_reg],
-						    &regs[insn->src_reg],
-						    dst_reg->var_off.value, opcode);
-			else if (opcode == BPF_JEQ || opcode == BPF_JNE)
+						    src_reg,
+						    is_jmp32
+						    ? dst_lo->var_off.value
+						    : dst_reg->var_off.value,
+						    opcode, is_jmp32);
+			else if (!is_jmp32 &&
+				 (opcode == BPF_JEQ || opcode == BPF_JNE))
 				/* Comparing for equality, we can combine knowledge */
 				reg_combine_min_max(&other_branch_regs[insn->src_reg],
 						    &other_branch_regs[insn->dst_reg],
-						    &regs[insn->src_reg],
-						    &regs[insn->dst_reg], opcode);
+						    src_reg, dst_reg, opcode);
 		}
 	} else if (dst_reg->type == SCALAR_VALUE) {
 		reg_set_min_max(&other_branch_regs[insn->dst_reg],
-					dst_reg, insn->imm, opcode);
+					dst_reg, insn->imm, opcode, is_jmp32);
 	}
 
-	/* detect if R == 0 where R is returned from bpf_map_lookup_elem() */
-	if (BPF_SRC(insn->code) == BPF_K &&
+	/* detect if R == 0 where R is returned from bpf_map_lookup_elem().
+	 * NOTE: these optimizations below are related with pointer comparison
+	 *       which will never be JMP32.
+	 */
+	if (!is_jmp32 && BPF_SRC(insn->code) == BPF_K &&
 	    insn->imm == 0 && (opcode == BPF_JEQ || opcode == BPF_JNE) &&
 	    reg_type_may_be_null(dst_reg->type)) {
 		/* Mark all identical registers in each branch as either
@@ -4926,7 +5058,8 @@  static int check_cfg(struct bpf_verifier_env *env)
 		goto check_state;
 	t = insn_stack[cur_stack - 1];
 
-	if (BPF_CLASS(insns[t].code) == BPF_JMP) {
+	if (BPF_CLASS(insns[t].code) == BPF_JMP ||
+	    BPF_CLASS(insns[t].code) == BPF_JMP32) {
 		u8 opcode = BPF_OP(insns[t].code);
 
 		if (opcode == BPF_EXIT) {
@@ -6082,7 +6215,7 @@  static int do_check(struct bpf_verifier_env *env)
 			if (err)
 				return err;
 
-		} else if (class == BPF_JMP) {
+		} else if (class == BPF_JMP || class == BPF_JMP32) {
 			u8 opcode = BPF_OP(insn->code);
 
 			if (opcode == BPF_CALL) {
@@ -6090,7 +6223,8 @@  static int do_check(struct bpf_verifier_env *env)
 				    insn->off != 0 ||
 				    (insn->src_reg != BPF_REG_0 &&
 				     insn->src_reg != BPF_PSEUDO_CALL) ||
-				    insn->dst_reg != BPF_REG_0) {
+				    insn->dst_reg != BPF_REG_0 ||
+				    class == BPF_JMP32) {
 					verbose(env, "BPF_CALL uses reserved fields\n");
 					return -EINVAL;
 				}
@@ -6106,7 +6240,8 @@  static int do_check(struct bpf_verifier_env *env)
 				if (BPF_SRC(insn->code) != BPF_K ||
 				    insn->imm != 0 ||
 				    insn->src_reg != BPF_REG_0 ||
-				    insn->dst_reg != BPF_REG_0) {
+				    insn->dst_reg != BPF_REG_0 ||
+				    class == BPF_JMP32) {
 					verbose(env, "BPF_JA uses reserved fields\n");
 					return -EINVAL;
 				}
@@ -6118,7 +6253,8 @@  static int do_check(struct bpf_verifier_env *env)
 				if (BPF_SRC(insn->code) != BPF_K ||
 				    insn->imm != 0 ||
 				    insn->src_reg != BPF_REG_0 ||
-				    insn->dst_reg != BPF_REG_0) {
+				    insn->dst_reg != BPF_REG_0 ||
+				    class == BPF_JMP32) {
 					verbose(env, "BPF_EXIT uses reserved fields\n");
 					return -EINVAL;
 				}