WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOG IN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FIND SOMETHING THAT MAY NOT BE APPROPRIATE FOR EVERYONE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit f251456

Browse files
committed
added shaperQueue for transmission queue management
1 parent 489b45e commit f251456

File tree

3 files changed

+293
-8
lines changed

3 files changed

+293
-8
lines changed

session.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
package smux
2424

2525
import (
26-
"container/heap"
2726
"encoding/binary"
2827
"errors"
2928
"io"
@@ -480,22 +479,22 @@ func (s *Session) keepalive() {
480479
// shaperLoop implements a priority queue for write requests,
481480
// some control messages are prioritized over data messages
482481
func (s *Session) shaperLoop() {
483-
var reqs shaperHeap
482+
reqs := NewShaperQueue()
484483
var next writeRequest
485484
var chWrite chan writeRequest
486485
var chShaper chan writeRequest
487486

488487
for {
489488
// chWrite is not available until it has packet to send
490-
if len(reqs) > 0 {
489+
if !reqs.IsEmpty() {
491490
chWrite = s.writes
492-
next = heap.Pop(&reqs).(writeRequest)
491+
next, _ = reqs.Pop()
493492
} else {
494493
chWrite = nil
495494
}
496495

497496
// control heap size, chShaper is not available until packets are less than maximum allowed
498-
if len(reqs) >= maxShaperSize {
497+
if reqs.Len() >= maxShaperSize {
499498
chShaper = nil
500499
} else {
501500
chShaper = s.shaper
@@ -510,10 +509,10 @@ func (s *Session) shaperLoop() {
510509
case <-s.die:
511510
return
512511
case r := <-chShaper:
513-
if chWrite != nil { // next is valid, reshape
514-
heap.Push(&reqs, next)
512+
if chWrite != nil { // re-enqueue the request if there is a pending write
513+
reqs.Push(next)
515514
}
516-
heap.Push(&reqs, r)
515+
reqs.Push(r)
517516
case chWrite <- next:
518517
}
519518
}

shaper.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@
2222

2323
package smux
2424

25+
import (
26+
"container/heap"
27+
"sync"
28+
"time"
29+
)
30+
2531
// _itimediff returns the time difference between two uint32 values.
2632
// The result is a signed 32-bit integer representing the difference between 'later' and 'earlier'.
2733
func _itimediff(later, earlier uint32) int32 {
@@ -54,3 +60,93 @@ func (h *shaperHeap) Pop() interface{} {
5460
*h = old[0 : n-1]
5561
return x
5662
}
63+
64+
const (
	// streamExpireDuration is how long a stream may sit with an empty
	// heap (measured from its last Push) before Pop evicts its
	// bookkeeping entries.
	streamExpireDuration = 1 * time.Minute
)

// shaperQueue fans writeRequests out into one priority heap per stream
// id and serves them in round-robin order, so a single busy stream
// cannot starve the others on the shared write path.
type shaperQueue struct {
	streams    map[uint32]*shaperHeap // per-stream priority heaps
	lastVisits map[uint32]time.Time   // last Push time per stream; drives expiry
	allSids    []uint32               // stream ids in registration order (round-robin ring)
	nextIdx    uint32                 // ring position Pop tries first
	count      uint32                 // total queued requests across all streams
	mu         sync.Mutex             // guards every field above
}
76+
77+
func NewShaperQueue() *shaperQueue {
78+
return &shaperQueue{
79+
streams: make(map[uint32]*shaperHeap),
80+
lastVisits: make(map[uint32]time.Time),
81+
}
82+
}
83+
84+
func (sq *shaperQueue) Push(req writeRequest) {
85+
sq.mu.Lock()
86+
defer sq.mu.Unlock()
87+
sid := req.frame.sid
88+
if _, ok := sq.streams[sid]; !ok {
89+
sq.streams[sid] = new(shaperHeap)
90+
sq.allSids = append(sq.allSids, sid)
91+
}
92+
h := sq.streams[sid]
93+
heap.Push(h, req)
94+
sq.lastVisits[sid] = time.Now()
95+
sq.count++
96+
}
97+
98+
// Pop uses Round Robin to pop writeRequests from the shaperQueue.
99+
func (sq *shaperQueue) Pop() (req writeRequest, ok bool) {
100+
sq.mu.Lock()
101+
defer sq.mu.Unlock()
102+
103+
if len(sq.allSids) == 0 {
104+
return writeRequest{}, false
105+
}
106+
107+
start := sq.nextIdx % uint32(len(sq.allSids))
108+
109+
// loop through all streams in a round-robin manner
110+
for i := 0; i < len(sq.allSids); i++ {
111+
idx := (int(start) + i) % len(sq.allSids)
112+
sid := sq.allSids[idx]
113+
h := sq.streams[sid]
114+
if h == nil || h.Len() == 0 {
115+
continue
116+
}
117+
118+
// pop from the heap
119+
req := heap.Pop(h).(writeRequest)
120+
sq.count--
121+
122+
// If the heap is empty after popping, remove it from the map
123+
if h.Len() == 0 && sq.lastVisits[sid].Add(streamExpireDuration).Before(time.Now()) {
124+
delete(sq.streams, sid)
125+
delete(sq.lastVisits, sid)
126+
// copy the rest of allSids to overwrite the removed sid
127+
sq.allSids = append(sq.allSids[:idx], sq.allSids[idx+1:]...)
128+
}
129+
130+
// update nextSid for round-robin
131+
if len(sq.allSids) == 0 {
132+
sq.nextIdx = 0
133+
} else {
134+
sq.nextIdx = uint32((idx + 1) % len(sq.allSids))
135+
}
136+
return req, true
137+
}
138+
139+
return writeRequest{}, false
140+
}
141+
142+
func (sq *shaperQueue) IsEmpty() bool {
143+
sq.mu.Lock()
144+
defer sq.mu.Unlock()
145+
return sq.count == 0
146+
}
147+
148+
func (sq *shaperQueue) Len() int {
149+
sq.mu.Lock()
150+
defer sq.mu.Unlock()
151+
return int(sq.count)
152+
}

shaper_test.go

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ package smux
2424

2525
import (
2626
"container/heap"
27+
"fmt"
28+
"math/rand"
29+
"sync"
2730
"testing"
31+
"time"
2832
)
2933

3034
func TestShaper(t *testing.T) {
@@ -70,3 +74,189 @@ func TestShaper2(t *testing.T) {
7074
t.Log("sid:", w.frame.sid, "seq:", w.seq)
7175
}
7276
}
77+
78+
func TestShaperQueueFairness(t *testing.T) {
79+
rand.Seed(time.Now().UnixNano())
80+
81+
sq := NewShaperQueue()
82+
83+
const streams = 10
84+
const testDuration = 10 * time.Second
85+
86+
var wg sync.WaitGroup
87+
sendCount := make([]uint64, streams)
88+
89+
stop := make(chan struct{})
90+
91+
// Producers: each stream pushes packets
92+
for sid := 0; sid < streams; sid++ {
93+
sid := sid
94+
wg.Add(1)
95+
go func() {
96+
defer wg.Done()
97+
seq := uint32(0)
98+
for {
99+
select {
100+
case <-stop:
101+
return
102+
default:
103+
}
104+
sq.Push(writeRequest{
105+
frame: Frame{sid: uint32(sid)},
106+
seq: seq,
107+
})
108+
seq++
109+
time.Sleep(time.Duration(rand.Intn(300)) * time.Microsecond)
110+
}
111+
}()
112+
}
113+
114+
// Consumer: slow network, 1 pop every 10ms
115+
wg.Add(1)
116+
go func() {
117+
defer wg.Done()
118+
ticker := time.NewTicker(10 * time.Millisecond)
119+
for {
120+
select {
121+
case <-stop:
122+
return
123+
case <-ticker.C:
124+
req, ok := sq.Pop()
125+
if ok {
126+
sendCount[req.frame.sid]++
127+
}
128+
}
129+
}
130+
}()
131+
132+
// ---- NEW: periodic live report ----
133+
go func() {
134+
ticker := time.NewTicker(500 * time.Millisecond)
135+
for {
136+
select {
137+
case <-stop:
138+
return
139+
case <-ticker.C:
140+
fmt.Printf("[DEBUG] Current counts: %v\n", sendCount)
141+
}
142+
}
143+
}()
144+
145+
// run test
146+
time.Sleep(testDuration)
147+
close(stop)
148+
wg.Wait()
149+
150+
// ---- final report ----
151+
fmt.Println("=== FINAL COUNTS ===")
152+
fmt.Println(sendCount)
153+
154+
// ---- fairness check ----
155+
total := uint64(0)
156+
for _, c := range sendCount {
157+
total += c
158+
}
159+
avg := total / streams
160+
tolerance := avg / 4 // 25%
161+
162+
for sid, c := range sendCount {
163+
if c < avg-tolerance || c > avg+tolerance {
164+
t.Errorf("stream %d unfair: got %d, avg %d", sid, c, avg)
165+
}
166+
}
167+
}
168+
169+
func TestShaperQueue_FastWriteSlowRead(t *testing.T) {
170+
rand.Seed(time.Now().UnixNano())
171+
172+
const (
173+
streams = 10
174+
duration = 10 * time.Second
175+
producerWait = 1 * time.Microsecond // super fast writing
176+
consumerWait = 15 * time.Millisecond // super slow reading
177+
)
178+
179+
sq := NewShaperQueue()
180+
181+
sendCount := make([]uint64, streams)
182+
stop := make(chan struct{})
183+
var wg sync.WaitGroup
184+
185+
// Producers: extremely fast writers
186+
for sid := 0; sid < streams; sid++ {
187+
sid := sid
188+
wg.Add(1)
189+
go func() {
190+
defer wg.Done()
191+
seq := uint32(0)
192+
for {
193+
select {
194+
case <-stop:
195+
return
196+
default:
197+
}
198+
199+
sq.Push(writeRequest{
200+
frame: Frame{sid: uint32(sid)},
201+
seq: seq,
202+
})
203+
seq++
204+
time.Sleep(producerWait)
205+
}
206+
}()
207+
}
208+
209+
// Consumer: very slow reader
210+
wg.Add(1)
211+
go func() {
212+
defer wg.Done()
213+
for {
214+
select {
215+
case <-stop:
216+
return
217+
default:
218+
}
219+
220+
req, ok := sq.Pop()
221+
if ok {
222+
sendCount[req.frame.sid]++
223+
}
224+
225+
time.Sleep(consumerWait)
226+
}
227+
}()
228+
229+
// Periodic monitor
230+
go func() {
231+
ticker := time.NewTicker(500 * time.Millisecond)
232+
for {
233+
select {
234+
case <-stop:
235+
return
236+
case <-ticker.C:
237+
fmt.Printf("[DEBUG] Queue size=%d, counts=%v\n", sq.count, sendCount)
238+
}
239+
}
240+
}()
241+
242+
// Run test
243+
time.Sleep(duration)
244+
close(stop)
245+
wg.Wait()
246+
247+
fmt.Printf("=== FINAL ===\ncounts=%v\nqueue remaining=%d\n", sendCount, sq.count)
248+
249+
// Check fairness
250+
total := uint64(0)
251+
for _, v := range sendCount {
252+
total += v
253+
}
254+
avg := total / streams
255+
tolerance := avg / 3 // allow 33%
256+
257+
for sid, c := range sendCount {
258+
if c < avg-tolerance || c > avg+tolerance {
259+
t.Errorf("stream %d unfair: got %d, avg %d", sid, c, avg)
260+
}
261+
}
262+
}

0 commit comments

Comments
 (0)