@@ -15,6 +15,8 @@ use crate::auth::Auth;
1515use crate :: client:: Statistics ;
1616use crate :: connection:: Connection ;
1717use crate :: connection:: State ;
18+ #[ cfg( feature = "websockets" ) ]
19+ use crate :: connection:: WebSocketAdapter ;
1820use crate :: options:: CallbackArg1 ;
1921use crate :: tls;
2022use crate :: AuthError ;
@@ -168,7 +170,11 @@ impl Connector {
168170 . map_err ( |err| ConnectError :: with_source ( crate :: ConnectErrorKind :: Dns , err) ) ?;
169171 for socket_addr in socket_addrs {
170172 match self
171- . try_connect_to ( & socket_addr, server_addr. tls_required ( ) , server_addr. host ( ) )
173+ . try_connect_to (
174+ & socket_addr,
175+ server_addr. tls_required ( ) ,
176+ server_addr. clone ( ) ,
177+ )
172178 . await
173179 {
174180 Ok ( ( server_info, mut connection) ) => {
@@ -321,22 +327,76 @@ impl Connector {
321327 & self ,
322328 socket_addr : & SocketAddr ,
323329 tls_required : bool ,
324- tls_host : & str ,
330+ server_addr : ServerAddr ,
325331 ) -> Result < ( ServerInfo , Connection ) , ConnectError > {
326- let tcp_stream = tokio:: time:: timeout (
327- self . options . connection_timeout ,
328- TcpStream :: connect ( socket_addr) ,
329- )
330- . await
331- . map_err ( |_| ConnectError :: new ( crate :: ConnectErrorKind :: TimedOut ) ) ??;
332-
333- tcp_stream. set_nodelay ( true ) ?;
332+ let mut connection = match server_addr. scheme ( ) {
333+ #[ cfg( feature = "websockets" ) ]
334+ "ws" => {
335+ let ws = tokio:: time:: timeout (
336+ self . options . connection_timeout ,
337+ tokio_websockets:: client:: Builder :: new ( )
338+ . uri ( format ! ( "{}://{}" , server_addr. scheme( ) , socket_addr) . as_str ( ) )
339+ . map_err ( |err| {
340+ ConnectError :: with_source ( crate :: ConnectErrorKind :: ServerParse , err)
341+ } ) ?
342+ . connect ( ) ,
343+ )
344+ . await
345+ . map_err ( |_| ConnectError :: new ( crate :: ConnectErrorKind :: TimedOut ) ) ?
346+ . map_err ( |err| ConnectError :: with_source ( crate :: ConnectErrorKind :: Io , err) ) ?;
334347
335- let mut connection = Connection :: new (
336- Box :: new ( tcp_stream) ,
337- self . options . read_buffer_capacity . into ( ) ,
338- self . connect_stats . clone ( ) ,
339- ) ;
348+ let con = WebSocketAdapter :: new ( ws. 0 ) ;
349+ Connection :: new ( Box :: new ( con) , 0 , self . connect_stats . clone ( ) )
350+ }
351+ #[ cfg( feature = "websockets" ) ]
352+ "wss" => {
353+ let domain = webpki:: types:: ServerName :: try_from ( server_addr. host ( ) )
354+ . map_err ( |err| ConnectError :: with_source ( crate :: ConnectErrorKind :: Tls , err) ) ?;
355+ let tls_config =
356+ Arc :: new ( tls:: config_tls ( & self . options ) . await . map_err ( |err| {
357+ ConnectError :: with_source ( crate :: ConnectErrorKind :: Tls , err)
358+ } ) ?) ;
359+ let tls_connector = tokio_rustls:: TlsConnector :: from ( tls_config) ;
360+ let ws = tokio:: time:: timeout (
361+ self . options . connection_timeout ,
362+ tokio_websockets:: client:: Builder :: new ( )
363+ . connector ( & tokio_websockets:: Connector :: Rustls ( tls_connector) )
364+ . uri (
365+ format ! (
366+ "{}://{}:{}" ,
367+ server_addr. scheme( ) ,
368+ domain. to_str( ) ,
369+ server_addr. port( )
370+ )
371+ . as_str ( ) ,
372+ )
373+ . map_err ( |err| {
374+ ConnectError :: with_source ( crate :: ConnectErrorKind :: ServerParse , err)
375+ } ) ?
376+ . connect ( ) ,
377+ )
378+ . await
379+ . map_err ( |_| ConnectError :: new ( crate :: ConnectErrorKind :: TimedOut ) ) ?
380+ . map_err ( |err| ConnectError :: with_source ( crate :: ConnectErrorKind :: Io , err) ) ?;
381+ let con = WebSocketAdapter :: new ( ws. 0 ) ;
382+ Connection :: new ( Box :: new ( con) , 0 , self . connect_stats . clone ( ) )
383+ }
384+ _ => {
385+ let tcp_stream = tokio:: time:: timeout (
386+ self . options . connection_timeout ,
387+ TcpStream :: connect ( socket_addr) ,
388+ )
389+ . await
390+ . map_err ( |_| ConnectError :: new ( crate :: ConnectErrorKind :: TimedOut ) ) ??;
391+ tcp_stream. set_nodelay ( true ) ?;
392+
393+ Connection :: new (
394+ Box :: new ( tcp_stream) ,
395+ self . options . read_buffer_capacity . into ( ) ,
396+ self . connect_stats . clone ( ) ,
397+ )
398+ }
399+ } ;
340400
341401 let tls_connection = |connection : Connection | async {
342402 let tls_config = Arc :: new (
@@ -346,7 +406,7 @@ impl Connector {
346406 ) ;
347407 let tls_connector = tokio_rustls:: TlsConnector :: from ( tls_config) ;
348408
349- let domain = webpki:: types:: ServerName :: try_from ( tls_host )
409+ let domain = webpki:: types:: ServerName :: try_from ( server_addr . host ( ) )
350410 . map_err ( |err| ConnectError :: with_source ( crate :: ConnectErrorKind :: Tls , err) ) ?;
351411
352412 let tls_stream = tls_connector
@@ -363,7 +423,7 @@ impl Connector {
363423 // If `tls_first` was set, establish TLS connection before getting INFO.
364424 // There is no point in checking if tls is required, because
365425 // the connection has to be be upgraded to TLS anyway as it's different flow.
366- if self . options . tls_first {
426+ if self . options . tls_first && !server_addr . is_websocket ( ) {
367427 connection = tls_connection ( connection) . await ?;
368428 }
369429
@@ -386,6 +446,7 @@ impl Connector {
386446
387447 // If `tls_first` was not set, establish TLS connection if it is required.
388448 if !self . options . tls_first
449+ && !server_addr. is_websocket ( )
389450 && ( self . options . tls_required || info. tls_required || tls_required)
390451 {
391452 connection = tls_connection ( connection) . await ?;
0 commit comments