Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions src/proto/h1/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::upgrade::OnUpgrade;
pub(crate) struct Dispatcher<D, Bs: Body, I, T> {
conn: Conn<I, Bs::Data, T>,
dispatch: D,
body_tx: Option<crate::body::Sender>,
body_tx: SenderDropGuard,
body_rx: Pin<Box<Option<Bs>>>,
is_closing: bool,
}
Expand Down Expand Up @@ -81,7 +81,7 @@ where
Dispatcher {
conn,
dispatch,
body_tx: None,
body_tx: SenderDropGuard::none(),
body_rx: Box::pin(None),
is_closing: false,
}
Expand Down Expand Up @@ -126,7 +126,8 @@ where
should_shutdown: bool,
) -> Poll<crate::Result<Dispatched>> {
Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| {
// Be sure to alert a streaming body of the failure.
// Be sure to alert a streaming body of the failure with a
// more specific error than the drop guard would provide.
if let Some(mut body) = self.body_tx.take() {
body.send_error(crate::Error::new_body("connection error"));
}
Expand Down Expand Up @@ -226,7 +227,7 @@ where
match body.poll_ready(cx) {
Poll::Ready(Ok(())) => (),
Poll::Pending => {
self.body_tx = Some(body);
self.body_tx.set(body);
return Poll::Pending;
}
Poll::Ready(Err(_canceled)) => {
Expand All @@ -243,7 +244,7 @@ where
let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
match body.try_send_data(chunk) {
Ok(()) => {
self.body_tx = Some(body);
self.body_tx.set(body);
}
Err(_canceled) => {
if self.conn.can_read_body() {
Expand All @@ -257,7 +258,7 @@ where
frame.into_trailers().unwrap_or_else(|_| unreachable!());
match body.try_send_trailers(trailers) {
Ok(()) => {
self.body_tx = Some(body);
self.body_tx.set(body);
}
Err(_canceled) => {
if self.conn.can_read_body() {
Expand All @@ -275,7 +276,7 @@ where
// just drop, the body will close automatically
}
Poll::Pending => {
self.body_tx = Some(body);
self.body_tx.set(body);
return Poll::Pending;
}
Poll::Ready(Some(Err(e))) => {
Expand Down Expand Up @@ -310,7 +311,7 @@ where
other => {
let (tx, rx) =
IncomingBody::new_channel(other, wants.contains(Wants::EXPECT));
self.body_tx = Some(tx);
self.body_tx.set(tx);
rx
}
};
Expand Down Expand Up @@ -524,6 +525,38 @@ impl<T> Drop for OptGuard<'_, T> {
}
}

// ===== impl SenderDropGuard =====

/// A drop guard for the body `Sender`.
///
/// If the `Dispatcher` future is dropped (e.g. the runtime driving the
/// connection is shut down) while it still owns a body `Sender`, the guard
/// sends an incomplete-message error so the receiver sees an error instead
/// of a silent, clean end-of-stream.
struct SenderDropGuard(Option<crate::body::Sender>);

impl SenderDropGuard {
fn none() -> Self {
SenderDropGuard(None)
}

fn set(&mut self, sender: crate::body::Sender) {
self.0 = Some(sender);
}

fn take(&mut self) -> Option<crate::body::Sender> {
self.0.take()
}
}

impl Drop for SenderDropGuard {
fn drop(&mut self) {
if let Some(mut sender) = self.0.take() {
sender.send_error(crate::Error::new_incomplete());
}
}
}

// ===== impl Server =====

cfg_server! {
Expand Down
55 changes: 55 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,61 @@ mod conn {
assert_eq!(chunk.data_ref().unwrap().len(), 5);
}

#[tokio::test]
async fn dropped_conn_sends_incomplete_body_error() {
let (listener, addr) = setup_tk_test_server().await;
let (release_tx, release_rx) = oneshot::channel();

let server = async move {
let mut sock = listener.accept().await.unwrap().0;
let mut buf = [0; 4096];
let n = sock.read(&mut buf).await.expect("read 1");

let expected = "GET / HTTP/1.1\r\n\r\n";
assert_eq!(s(&buf[..n]), expected);

sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n")
.await
.unwrap();

release_rx.await.expect("release server");
};

let client = async move {
let tcp = tcp_connect(&addr).await.expect("connect");
let (mut client, conn) = conn::http1::handshake(tcp).await.expect("handshake");

let conn = tokio::task::spawn(async move {
conn.await.expect("http conn");
});

let req = Request::builder()
.uri("/")
.body(Empty::<Bytes>::new())
.unwrap();
let mut res = client.send_request(req).await.expect("send_request");
assert_eq!(res.status(), hyper::StatusCode::OK);
assert_eq!(res.body().size_hint().exact(), Some(5));
assert!(!res.body().is_end_stream());

conn.abort();
let err = conn.await.expect_err("conn task should be aborted");
assert!(err.is_cancelled(), "{err:?}");

let err = res
.body_mut()
.frame()
.await
.expect("body frame")
.unwrap_err();
assert!(err.is_incomplete_message(), "{err:?}");

release_tx.send(()).expect("release server");
};

future::join(server, client).await;
}

#[test]
fn aborted_body_isnt_completed() {
let _ = ::pretty_env_logger::try_init();
Expand Down
Loading