@@ -192,7 +192,9 @@ struct sk_buff *__udp_gso_segment(struct sk_buff *gso_skb,
unsigned int mss, __sum16 check)
{
struct udphdr *uh = udp_hdr(gso_skb);
- struct sk_buff *segs;
+ struct sock *sk = gso_skb->sk;
+ struct sk_buff *segs, *seg;
+ unsigned int sum_truesize = 0;
unsigned int hdrlen;
if (gso_skb->len <= sizeof(*uh) + mss)
@@ -203,9 +205,23 @@ struct sk_buff *__udp_gso_segment(struct sk_buff *gso_skb,
skb_pull(gso_skb, sizeof(*uh));
hdrlen = gso_skb->data - skb_mac_header(gso_skb);
+ /* clear destructor to avoid skb_segment assigning it to tail */
+ WARN_ON_ONCE(gso_skb->destructor != sock_wfree);
+ gso_skb->destructor = NULL;
+
segs = skb_segment(gso_skb, features);
- if (unlikely(IS_ERR_OR_NULL(segs)))
+ if (unlikely(IS_ERR_OR_NULL(segs))) {
+ gso_skb->destructor = sock_wfree;
return segs;
+ }
+
+ for (seg = segs; seg; seg = seg->next) {
+ seg->destructor = sock_wfree;
+ seg->sk = sk;
+ sum_truesize += seg->truesize;
+ }
+
+ refcount_add(sum_truesize - gso_skb->truesize, &sk->sk_wmem_alloc);
/* If last packet is not full, fix up its header */
if (segs->prev->len != hdrlen + mss) {