WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 53cc394

Browse files
committed
PROTON-2879: [Python] Convenience iterators for sessions and links
1 parent 96e5b8d commit 53cc394

File tree

2 files changed

+150
-4
lines changed

2 files changed

+150
-4
lines changed

python/proton/_endpoints.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ def session(self) -> 'Session':
434434
else:
435435
return Session(ssn)
436436

437-
def session_head(self, mask: int) -> Optional['Session']:
437+
def session_head(self, mask: EndpointState) -> Optional['Session']:
438438
"""
439439
Retrieve the first session from a given connection that matches the
440440
specified state mask.
@@ -452,7 +452,19 @@ def session_head(self, mask: int) -> Optional['Session']:
452452
"""
453453
return Session.wrap(pn_session_head(self._impl, mask))
454454

455-
def link_head(self, mask: int) -> Optional[Union['Sender', 'Receiver']]:
455+
def sessions(self, mask: EndpointState) -> Iterator['Session']:
456+
"""
457+
Returns a generator of sessions owned by the connection with the
458+
given state mask.
459+
460+
:return: Generator of sessions.
461+
"""
462+
session = self.session_head(mask)
463+
while session:
464+
yield session
465+
session = session.next(mask)
466+
467+
def link_head(self, mask: EndpointState) -> Optional['Link']:
456468
"""
457469
Retrieve the first link that matches the given state mask.
458470
@@ -469,6 +481,18 @@ def link_head(self, mask: int) -> Optional[Union['Sender', 'Receiver']]:
469481
"""
470482
return Link.wrap(pn_link_head(self._impl, mask))
471483

484+
def links(self, mask: EndpointState) -> Iterator['Link']:
485+
"""
486+
Returns a generator of links owned by this connection with the
487+
given state mask.
488+
489+
:return: Generator of links.
490+
"""
491+
link = self.link_head(mask)
492+
while link:
493+
yield link
494+
link = link.next(mask)
495+
472496
@property
473497
def error(self):
474498
"""
@@ -619,7 +643,7 @@ def close(self) -> None:
619643
self._update_cond()
620644
pn_session_close(self._impl)
621645

622-
def next(self, mask):
646+
def next(self, mask: EndpointState) -> Optional['Session']:
623647
"""
624648
Retrieve the next session for this connection that matches the
625649
specified state mask.
@@ -935,7 +959,7 @@ def queued(self) -> int:
935959
"""
936960
return pn_link_queued(self._impl)
937961

938-
def next(self, mask: int) -> Optional[Union['Sender', 'Receiver']]:
962+
def next(self, mask: EndpointState) -> Optional['Link']:
939963
"""
940964
Retrieve the next link that matches the given state mask.
941965

python/tests/proton_tests/engine.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,62 @@ def test_set_get_outgoing_window(self):
510510
self.ssn.outgoing_window = 1024
511511
assert self.ssn.outgoing_window == 1024
512512

513+
def test_multiple_iterator(self):
514+
ssn1 = self.ssn
515+
ssn2 = self.c1.session()
516+
ssn3 = self.c1.session()
517+
518+
# Check that the iterator gets all sessions for no mask
519+
ssns = [ssn1, ssn2, ssn3]
520+
for ssn in self.c1.sessions(0):
521+
assert ssn in ssns, ssn
522+
ssns.remove(ssn)
523+
assert not ssns, ssns
524+
525+
# Check that every session starts uninitialized local and remote
526+
ssns = [ssn1, ssn2, ssn3]
527+
for ssn in self.c1.sessions(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_UNINIT):
528+
assert ssn in ssns, ssn
529+
ssns.remove(ssn)
530+
assert not ssns, ssns
531+
532+
for ssn in self.c1.sessions(0):
533+
ssn.open()
534+
535+
self.pump()
536+
537+
ssns = [ssn1, ssn2, ssn3]
538+
for ssn in self.c1.sessions(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_UNINIT):
539+
assert ssn in ssns, ssn
540+
ssns.remove(ssn)
541+
assert not ssns, ssns
542+
543+
ssns = [ssn for ssn in self.c2.sessions(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_ACTIVE)]
544+
assert len(ssns) == 3, ssns
545+
546+
for ssn in self.c2.sessions(0):
547+
ssn.open()
548+
549+
self.pump()
550+
551+
# Check that every session is now active local and remote
552+
ssns = [ssn1, ssn2, ssn3]
553+
for ssn in self.c1.sessions(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE):
554+
assert ssn in ssns, ssn
555+
ssns.remove(ssn)
556+
assert not ssns, ssns
557+
558+
for ssn in self.c2.sessions(0):
559+
ssn.close()
560+
561+
self.pump()
562+
563+
# Check that every session is now closed local and remote
564+
ssns = [ssn1, ssn2, ssn3]
565+
for ssn in self.c1.sessions(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_CLOSED):
566+
assert ssn in ssns, ssn
567+
ssns.remove(ssn)
568+
513569

514570
class LinkTest(Test):
515571

@@ -621,6 +677,72 @@ def test_multiple(self):
621677
conn.close()
622678
self.pump()
623679

680+
def test_multiple_iterator(self):
681+
snd1 = self.snd
682+
sess1 = self.snd.session
683+
snd2 = sess1.sender('sender2')
684+
snd3 = sess1.sender('sender3')
685+
686+
# Check that the iterator gets all senders for no mask, and all senders
687+
# are uninitialized local and remote
688+
snds = [snd1, snd2, snd3]
689+
for snd in sess1.connection.links(0):
690+
assert snd.state == Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_UNINIT, snd.state
691+
assert snd in snds, snd
692+
snds.remove(snd)
693+
assert not snds, snds
694+
695+
for snd in sess1.connection.links(0):
696+
snd.open()
697+
698+
self.pump()
699+
700+
# Check that every sender starts uninitialized local and remote
701+
snds = [snd1, snd2, snd3]
702+
for snd in sess1.connection.links(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_UNINIT):
703+
assert snd in snds, f"{snd}, {snd.state} not in {snds}"
704+
snds.remove(snd)
705+
assert not snds, snds
706+
707+
rcvs = [rcv for rcv in self.rcv.connection.links(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_ACTIVE)]
708+
assert len(rcvs) == 3, rcvs
709+
710+
for rcv in self.rcv.connection.links(0):
711+
rcv.open()
712+
713+
self.pump()
714+
715+
# Check that every session is now active local and remote
716+
snds = [snd1, snd2, snd3]
717+
for snd in sess1.connection.links(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE):
718+
assert snd in snds, f"{snd}, {snd.state} not in {snds}"
719+
snds.remove(snd)
720+
assert not snds, snds
721+
722+
for snd in sess1.connection.links(0):
723+
snd.close()
724+
725+
self.pump()
726+
727+
# Check that every session is now closed local and active remote
728+
snds = [snd1, snd2, snd3]
729+
for snd in sess1.connection.links(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_ACTIVE):
730+
assert snd in snds, f"{snd}, {snd.state} not in {snds}"
731+
snds.remove(snd)
732+
assert not snds, snds
733+
734+
for rcv in self.rcv.connection.links(0):
735+
rcv.close()
736+
737+
self.pump()
738+
739+
# Check that every session is now closed local and remote
740+
snds = [snd1, snd2, snd3]
741+
for snd in sess1.connection.links(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_CLOSED):
742+
assert snd in snds, f"{snd}, {snd.state} not in {snds}"
743+
snds.remove(snd)
744+
assert not snds, snds
745+
624746
def test_closing_session(self):
625747
self.snd.open()
626748
self.rcv.open()

0 commit comments

Comments
 (0)