diff --git a/include/net/tcp.h b/include/net/tcp.h
index 220f54c..ea81572 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -1037,10 +1037,15 @@ static inline void tcp_mib_init(struct net *net)
 }
 
 /* from STCP */
-static inline void tcp_clear_all_retrans_hints(struct tcp_sock *tp)
+static inline void tcp_clear_retrans_hints_partial(struct tcp_sock *tp)
 {
 	tp->lost_skb_hint = NULL;
 	tp->scoreboard_skb_hint = NULL;
+}
+
+static inline void tcp_clear_all_retrans_hints(struct tcp_sock *tp)
+{
+	tcp_clear_retrans_hints_partial(tp);
 	tp->retransmit_skb_hint = NULL;
 }
 
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index d017aed..44a4fff 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -2925,7 +2925,9 @@ static int tcp_clean_rtx_queue(struct sock *sk, int prior_fackets)
 
 		tcp_unlink_write_queue(skb, sk);
 		sk_wmem_free_skb(sk, skb);
-		tcp_clear_all_retrans_hints(tp);
+		tcp_clear_retrans_hints_partial(tp);
+		if (skb == tp->retransmit_skb_hint)
+			tp->retransmit_skb_hint = NULL;
 	}
 
 	if (skb && (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED))
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 97873cc..63b9ce1 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -748,7 +748,7 @@ int tcp_fragment(struct sock *sk, struct sk_buff *skb, u32 len,
 
 	BUG_ON(len > skb->len);
 
-	tcp_clear_all_retrans_hints(tp);
+	tcp_clear_retrans_hints_partial(tp);
 	nsize = skb_headlen(skb) - len;
 	if (nsize < 0)
 		nsize = 0;
@@ -1821,7 +1821,9 @@ static void tcp_retrans_try_collapse(struct sock *sk, struct sk_buff *skb,
 	tp->packets_out -= tcp_skb_pcount(next_skb);
 
 	/* changed transmit queue under us so clear hints */
-	tcp_clear_all_retrans_hints(tp);
+	tcp_clear_retrans_hints_partial(tp);
+	if (next_skb == tp->retransmit_skb_hint)
+		tp->retransmit_skb_hint = skb;
 
 	sk_wmem_free_skb(sk, next_skb);
 }
@@ -1851,7 +1853,7 @@ void tcp_simple_retransmit(struct sock *sk)
 		}
 	}
 
-	tcp_clear_all_retrans_hints(tp);
+	tcp_clear_retrans_hints_partial(tp);
 
 	if (prior_lost == tp->lost_out)
 		return;
