@@ -5,14 +5,16 @@ use async_tungstenite::{tungstenite::Message, WebSocketStream};
5
5
use futures_util:: stream:: StreamExt ;
6
6
use serde:: { Deserialize , Serialize } ;
7
7
use std:: io:: Write ;
8
- use tokio:: time:: { sleep, timeout, Duration } ;
8
+ use tokio:: time:: { interval , sleep, timeout, Duration } ;
9
9
use url:: Url ;
10
10
11
11
use crate :: commands:: {
12
12
SSH_CONNECTION_TIMEOUT_SECS , SSH_MAX_EMPTY_MESSAGES , SSH_MAX_RECONNECT_ATTEMPTS ,
13
13
SSH_MESSAGE_TIMEOUT_SECS , SSH_RECONNECT_DELAY_SECS ,
14
14
} ;
15
15
16
+ const SSH_PING_INTERVAL_SECS : u64 = 10 ;
17
+
16
18
#[ derive( Clone , Debug ) ]
17
19
pub struct SSHConnectParams {
18
20
pub project_id : String ,
@@ -190,70 +192,96 @@ impl TerminalClient {
190
192
. map_err ( |e| anyhow:: anyhow!( "Failed to send signal: {}" , e) ) ?;
191
193
Ok ( ( ) )
192
194
}
195
+
196
+ async fn send_ping ( & mut self ) -> Result < ( ) > {
197
+ self . send_message ( Message :: Ping ( vec ! [ ] ) )
198
+ . await
199
+ . map_err ( |e| anyhow:: anyhow!( "Failed to send ping: {}" , e) ) ?;
200
+ Ok ( ( ) )
201
+ }
202
+
193
203
pub async fn handle_server_messages ( & mut self ) -> Result < ( ) > {
194
204
let mut consecutive_empty_messages = 0 ;
195
205
196
- while let Some ( msg) = self . ws_stream . next ( ) . await {
197
- let msg = msg. map_err ( |e| anyhow:: anyhow!( "WebSocket error: {}" , e) ) ?;
206
+ let mut ping_interval = interval ( Duration :: from_secs ( SSH_PING_INTERVAL_SECS ) ) ;
198
207
199
- match msg {
200
- Message :: Text ( text) => {
201
- let server_msg: ServerMessage = serde_json:: from_str ( & text)
202
- . map_err ( |e| anyhow:: anyhow!( "Failed to parse server message: {}" , e) ) ?;
208
+ loop {
209
+ tokio:: select! {
210
+ msg_option = self . ws_stream. next( ) => {
211
+ match msg_option {
212
+ Some ( msg_result) => {
213
+ let msg = msg_result. map_err( |e| anyhow:: anyhow!( "WebSocket error: {}" , e) ) ?;
203
214
204
- match server_msg. r#type . as_str ( ) {
205
- "session_data" => match server_msg. payload . data {
206
- DataPayload :: String ( text) => {
207
- consecutive_empty_messages = 0 ;
208
- print ! ( "{}" , text) ;
209
- std:: io:: stdout ( ) . flush ( ) ?;
210
- }
211
- DataPayload :: Buffer { data } => {
212
- consecutive_empty_messages = 0 ;
213
- std:: io:: stdout ( ) . write_all ( & data) ?;
214
- std:: io:: stdout ( ) . flush ( ) ?;
215
- }
216
- DataPayload :: Empty { } => {
217
- consecutive_empty_messages += 1 ;
218
- if consecutive_empty_messages >= SSH_MAX_EMPTY_MESSAGES {
219
- bail ! ( "Received too many empty messages in a row, connection may be stale" ) ;
215
+ match msg {
216
+ Message :: Text ( text) => {
217
+ let server_msg: ServerMessage = serde_json:: from_str( & text)
218
+ . map_err( |e| anyhow:: anyhow!( "Failed to parse server message: {}" , e) ) ?;
219
+
220
+ match server_msg. r#type. as_str( ) {
221
+ "session_data" => match server_msg. payload. data {
222
+ DataPayload :: String ( text) => {
223
+ consecutive_empty_messages = 0 ;
224
+ print!( "{}" , text) ;
225
+ std:: io:: stdout( ) . flush( ) ?;
226
+ }
227
+ DataPayload :: Buffer { data } => {
228
+ consecutive_empty_messages = 0 ;
229
+ std:: io:: stdout( ) . write_all( & data) ?;
230
+ std:: io:: stdout( ) . flush( ) ?;
231
+ }
232
+ DataPayload :: Empty { } => {
233
+ consecutive_empty_messages += 1 ;
234
+ if consecutive_empty_messages >= SSH_MAX_EMPTY_MESSAGES {
235
+ bail!( "Received too many empty messages in a row, connection may be stale" ) ;
236
+ }
237
+ }
238
+ } ,
239
+ "error" => {
240
+ bail!( server_msg. payload. message) ;
241
+ }
242
+ "pty_closed" => {
243
+ return Ok ( ( ) ) ;
244
+ }
245
+ unknown_type => {
246
+ eprintln!( "Warning: Received unknown message type: {}" , unknown_type) ;
247
+ }
248
+ }
249
+ }
250
+ Message :: Close ( frame) => {
251
+ if let Some ( frame) = frame {
252
+ bail!(
253
+ "WebSocket closed with code {}: {}" ,
254
+ frame. code,
255
+ frame. reason
256
+ ) ;
257
+ } else {
258
+ bail!( "WebSocket closed unexpectedly" ) ;
259
+ }
260
+ }
261
+ Message :: Ping ( data) => {
262
+ self . send_message( Message :: Pong ( data) ) . await ?;
263
+ }
264
+ Message :: Pong ( data) => {
265
+ // Pong recevied
266
+ }
267
+ Message :: Binary ( _) => {
268
+ eprintln!( "Warning: Unexpected binary message received" ) ;
269
+ }
270
+ Message :: Frame ( _) => {
271
+ eprintln!( "Warning: Unexpected raw frame received" ) ;
220
272
}
221
273
}
222
274
} ,
223
- "error" => {
224
- bail ! ( server_msg. payload. message) ;
225
- }
226
- "pty_closed" => {
227
- return Ok ( ( ) ) ;
228
- }
229
- unknown_type => {
230
- eprintln ! ( "Warning: Received unknown message type: {}" , unknown_type) ;
275
+ None => {
276
+ bail!( "WebSocket connection closed unexpectedly" ) ;
231
277
}
232
278
}
233
- }
234
- Message :: Close ( frame) => {
235
- if let Some ( frame) = frame {
236
- bail ! (
237
- "WebSocket closed with code {}: {}" ,
238
- frame. code,
239
- frame. reason
240
- ) ;
241
- } else {
242
- bail ! ( "WebSocket closed unexpectedly" ) ;
243
- }
244
- }
245
- Message :: Ping ( _) | Message :: Pong ( _) => {
246
- // Just acknowledge these silently...they keep the connection alive
247
- }
248
- Message :: Binary ( _) => {
249
- eprintln ! ( "Warning: Unexpected binary message received" ) ;
250
- }
251
- Message :: Frame ( _) => {
252
- eprintln ! ( "Warning: Unexpected raw frame received" ) ;
279
+ } ,
280
+
281
+ _ = ping_interval. tick( ) => {
282
+ self . send_ping( ) . await ?;
253
283
}
254
284
}
255
285
}
256
-
257
- bail ! ( "WebSocket connection closed unexpectedly" ) ;
258
286
}
259
287
}
0 commit comments