WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit e40abcf

Browse files
authored
Add Websockets support
Signed-off-by: Tomasz Pietrek <[email protected]>
1 parent 1fbcfdd commit e40abcf

File tree

8 files changed

+306
-23
lines changed

8 files changed

+306
-23
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ jobs:
164164

165165
- name: Install msrv Rust on ubuntu-latest
166166
id: install-rust
167-
uses: dtolnay/rust-toolchain@1.70.0
167+
uses: dtolnay/rust-toolchain@1.79.0
168168
- name: Cache the build artifacts
169169
uses: Swatinem/rust-cache@v2
170170
with:

async-nats/Cargo.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "async-nats"
33
authors = ["Tomasz Pietrek <[email protected]>", "Casper Beyer <[email protected]>"]
44
version = "0.37.0"
55
edition = "2021"
6-
rust = "1.74.0"
6+
rust = "1.79.0"
77
description = "A async Rust NATS client"
88
license = "Apache-2.0"
99
documentation = "https://docs.rs/async-nats"
@@ -41,6 +41,8 @@ ring = { version = "0.17", optional = true }
4141
rand = "0.8"
4242
webpki = { package = "rustls-webpki", version = "0.102" }
4343
portable-atomic = "1"
44+
tokio-websockets = { version = "0.10", features = ["client", "rand", "rustls-native-roots"], optional = true }
45+
pin-project = "1.0"
4446

4547
[dev-dependencies]
4648
ring = "0.17"
@@ -57,13 +59,13 @@ jsonschema = "0.17.1"
5759
# for -Z minimal-versions
5860
num = "0.4.1"
5961

60-
6162
[features]
6263
default = ["server_2_10", "ring"]
6364
# Enables Service API for the client.
6465
service = []
65-
aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs"]
66-
ring = ["dep:ring", "tokio-rustls/ring"]
66+
websockets = ["dep:tokio-websockets"]
67+
aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs", "tokio-websockets/aws-lc-rs"]
68+
ring = ["dep:ring", "tokio-rustls/ring", "tokio-websockets/ring"]
6769
fips = ["aws-lc-rs", "tokio-rustls/fips"]
6870
# All experimental features are part of this feature flag.
6971
experimental = ["service"]

async-nats/src/connection.rs

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ use std::sync::atomic::Ordering;
2323
use std::sync::Arc;
2424
use std::task::{Context, Poll};
2525

26+
#[cfg(feature = "websockets")]
27+
use {
28+
futures::{SinkExt, StreamExt},
29+
pin_project::pin_project,
30+
tokio::io::ReadBuf,
31+
tokio_websockets::WebSocketStream,
32+
};
33+
2634
use bytes::{Buf, Bytes, BytesMut};
2735
use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
2836

@@ -683,6 +691,108 @@ impl Connection {
683691
}
684692
}
685693

694+
#[cfg(feature = "websockets")]
695+
#[pin_project]
696+
pub(crate) struct WebSocketAdapter<T> {
697+
#[pin]
698+
pub(crate) inner: WebSocketStream<T>,
699+
pub(crate) read_buf: BytesMut,
700+
}
701+
702+
#[cfg(feature = "websockets")]
703+
impl<T> WebSocketAdapter<T> {
704+
pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
705+
Self {
706+
inner,
707+
read_buf: BytesMut::new(),
708+
}
709+
}
710+
}
711+
712+
#[cfg(feature = "websockets")]
713+
impl<T> AsyncRead for WebSocketAdapter<T>
714+
where
715+
T: AsyncRead + AsyncWrite + Unpin,
716+
{
717+
fn poll_read(
718+
self: Pin<&mut Self>,
719+
cx: &mut Context<'_>,
720+
buf: &mut ReadBuf<'_>,
721+
) -> Poll<std::io::Result<()>> {
722+
let mut this = self.project();
723+
724+
loop {
725+
// If we have data in the read buffer, let's move it to the output buffer.
726+
if !this.read_buf.is_empty() {
727+
let len = std::cmp::min(buf.remaining(), this.read_buf.len());
728+
buf.put_slice(&this.read_buf.split_to(len));
729+
return Poll::Ready(Ok(()));
730+
}
731+
732+
match this.inner.poll_next_unpin(cx) {
733+
Poll::Ready(Some(Ok(message))) => {
734+
this.read_buf.extend_from_slice(message.as_payload());
735+
}
736+
Poll::Ready(Some(Err(e))) => {
737+
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
738+
}
739+
Poll::Ready(None) => {
740+
return Poll::Ready(Err(std::io::Error::new(
741+
std::io::ErrorKind::UnexpectedEof,
742+
"WebSocket closed",
743+
)));
744+
}
745+
Poll::Pending => {
746+
return Poll::Pending;
747+
}
748+
}
749+
}
750+
}
751+
}
752+
753+
#[cfg(feature = "websockets")]
754+
impl<T> AsyncWrite for WebSocketAdapter<T>
755+
where
756+
T: AsyncRead + AsyncWrite + Unpin,
757+
{
758+
fn poll_write(
759+
self: Pin<&mut Self>,
760+
cx: &mut Context<'_>,
761+
buf: &[u8],
762+
) -> Poll<std::io::Result<usize>> {
763+
let mut this = self.project();
764+
765+
let data = buf.to_vec();
766+
match this.inner.poll_ready_unpin(cx) {
767+
Poll::Ready(Ok(())) => match this
768+
.inner
769+
.start_send_unpin(tokio_websockets::Message::binary(data))
770+
{
771+
Ok(()) => Poll::Ready(Ok(buf.len())),
772+
Err(e) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))),
773+
},
774+
Poll::Ready(Err(e)) => {
775+
Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
776+
}
777+
Poll::Pending => Poll::Pending,
778+
}
779+
}
780+
781+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
782+
self.project()
783+
.inner
784+
.poll_flush_unpin(cx)
785+
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
786+
}
787+
788+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
789+
self.project()
790+
.inner
791+
.poll_close_unpin(cx)
792+
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
793+
}
794+
}
795+
686796
#[cfg(test)]
687797
mod read_op {
688798
use std::sync::Arc;

async-nats/src/connector.rs

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use crate::auth::Auth;
1515
use crate::client::Statistics;
1616
use crate::connection::Connection;
1717
use crate::connection::State;
18+
#[cfg(feature = "websockets")]
19+
use crate::connection::WebSocketAdapter;
1820
use crate::options::CallbackArg1;
1921
use crate::tls;
2022
use crate::AuthError;
@@ -168,7 +170,11 @@ impl Connector {
168170
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
169171
for socket_addr in socket_addrs {
170172
match self
171-
.try_connect_to(&socket_addr, server_addr.tls_required(), server_addr.host())
173+
.try_connect_to(
174+
&socket_addr,
175+
server_addr.tls_required(),
176+
server_addr.clone(),
177+
)
172178
.await
173179
{
174180
Ok((server_info, mut connection)) => {
@@ -321,22 +327,76 @@ impl Connector {
321327
&self,
322328
socket_addr: &SocketAddr,
323329
tls_required: bool,
324-
tls_host: &str,
330+
server_addr: ServerAddr,
325331
) -> Result<(ServerInfo, Connection), ConnectError> {
326-
let tcp_stream = tokio::time::timeout(
327-
self.options.connection_timeout,
328-
TcpStream::connect(socket_addr),
329-
)
330-
.await
331-
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??;
332-
333-
tcp_stream.set_nodelay(true)?;
332+
let mut connection = match server_addr.scheme() {
333+
#[cfg(feature = "websockets")]
334+
"ws" => {
335+
let ws = tokio::time::timeout(
336+
self.options.connection_timeout,
337+
tokio_websockets::client::Builder::new()
338+
.uri(format!("{}://{}", server_addr.scheme(), socket_addr).as_str())
339+
.map_err(|err| {
340+
ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
341+
})?
342+
.connect(),
343+
)
344+
.await
345+
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
346+
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
334347

335-
let mut connection = Connection::new(
336-
Box::new(tcp_stream),
337-
self.options.read_buffer_capacity.into(),
338-
self.connect_stats.clone(),
339-
);
348+
let con = WebSocketAdapter::new(ws.0);
349+
Connection::new(Box::new(con), 0, self.connect_stats.clone())
350+
}
351+
#[cfg(feature = "websockets")]
352+
"wss" => {
353+
let domain = webpki::types::ServerName::try_from(server_addr.host())
354+
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
355+
let tls_config =
356+
Arc::new(tls::config_tls(&self.options).await.map_err(|err| {
357+
ConnectError::with_source(crate::ConnectErrorKind::Tls, err)
358+
})?);
359+
let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
360+
let ws = tokio::time::timeout(
361+
self.options.connection_timeout,
362+
tokio_websockets::client::Builder::new()
363+
.connector(&tokio_websockets::Connector::Rustls(tls_connector))
364+
.uri(
365+
format!(
366+
"{}://{}:{}",
367+
server_addr.scheme(),
368+
domain.to_str(),
369+
server_addr.port()
370+
)
371+
.as_str(),
372+
)
373+
.map_err(|err| {
374+
ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
375+
})?
376+
.connect(),
377+
)
378+
.await
379+
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
380+
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
381+
let con = WebSocketAdapter::new(ws.0);
382+
Connection::new(Box::new(con), 0, self.connect_stats.clone())
383+
}
384+
_ => {
385+
let tcp_stream = tokio::time::timeout(
386+
self.options.connection_timeout,
387+
TcpStream::connect(socket_addr),
388+
)
389+
.await
390+
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??;
391+
tcp_stream.set_nodelay(true)?;
392+
393+
Connection::new(
394+
Box::new(tcp_stream),
395+
self.options.read_buffer_capacity.into(),
396+
self.connect_stats.clone(),
397+
)
398+
}
399+
};
340400

341401
let tls_connection = |connection: Connection| async {
342402
let tls_config = Arc::new(
@@ -346,7 +406,7 @@ impl Connector {
346406
);
347407
let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
348408

349-
let domain = webpki::types::ServerName::try_from(tls_host)
409+
let domain = webpki::types::ServerName::try_from(server_addr.host())
350410
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
351411

352412
let tls_stream = tls_connector
@@ -363,7 +423,7 @@ impl Connector {
363423
// If `tls_first` was set, establish TLS connection before getting INFO.
364424
// There is no point in checking if tls is required, because
365425
// the connection has to be be upgraded to TLS anyway as it's different flow.
366-
if self.options.tls_first {
426+
if self.options.tls_first && !server_addr.is_websocket() {
367427
connection = tls_connection(connection).await?;
368428
}
369429

@@ -386,6 +446,7 @@ impl Connector {
386446

387447
// If `tls_first` was not set, establish TLS connection if it is required.
388448
if !self.options.tls_first
449+
&& !server_addr.is_websocket()
389450
&& (self.options.tls_required || info.tls_required || tls_required)
390451
{
391452
connection = tls_connection(connection).await?;

async-nats/src/lib.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1455,7 +1455,11 @@ impl FromStr for ServerAddr {
14551455
impl ServerAddr {
14561456
/// Check if the URL is a valid NATS server address.
14571457
pub fn from_url(url: Url) -> io::Result<Self> {
1458-
if url.scheme() != "nats" && url.scheme() != "tls" {
1458+
if url.scheme() != "nats"
1459+
&& url.scheme() != "tls"
1460+
&& url.scheme() != "ws"
1461+
&& url.scheme() != "wss"
1462+
{
14591463
return Err(std::io::Error::new(
14601464
ErrorKind::InvalidInput,
14611465
format!("invalid scheme for NATS server URL: {}", url.scheme()),
@@ -1480,6 +1484,10 @@ impl ServerAddr {
14801484
self.0.username() != ""
14811485
}
14821486

1487+
pub fn scheme(&self) -> &str {
1488+
self.0.scheme()
1489+
}
1490+
14831491
/// Returns the host.
14841492
pub fn host(&self) -> &str {
14851493
match self.0.host() {
@@ -1493,6 +1501,10 @@ impl ServerAddr {
14931501
}
14941502
}
14951503

1504+
pub fn is_websocket(&self) -> bool {
1505+
self.0.scheme() == "ws" || self.0.scheme() == "wss"
1506+
}
1507+
14961508
/// Returns the port.
14971509
pub fn port(&self) -> u16 {
14981510
self.0.port().unwrap_or(4222)

async-nats/tests/configs/ws.conf

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
jetstream {}
2+
websocket {
3+
port: 8444
4+
no_tls: true
5+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
authorization {
2+
user: derek
3+
password: porkchop
4+
timeout: 1
5+
}
6+
7+
websocket {
8+
tls {
9+
10+
cert_file: "./tests/configs/certs/server-cert.pem"
11+
key_file: "./tests/configs/certs/server-key.pem"
12+
ca_file: "./tests/configs/certs/rootCA.pem"
13+
}
14+
port: 8445
15+
}

0 commit comments

Comments
 (0)