diff --git a/tower/src/buffer/worker.rs b/tower/src/buffer/worker.rs index b0a797891..cd22f8d73 100644 --- a/tower/src/buffer/worker.rs +++ b/tower/src/buffer/worker.rs @@ -24,7 +24,6 @@ pin_project_lite::pin_project! { where T: Service, { - current_message: Option>, rx: mpsc::Receiver>, service: T, finish: bool, @@ -53,7 +52,6 @@ where }; let worker = Worker { - current_message: None, finish: false, failed: None, rx, @@ -68,33 +66,18 @@ where /// /// If a `Message` is returned, the `bool` is true if this is the first time we received this /// message, and false otherwise (i.e., we tried to forward it to the backing service before). - fn poll_next_msg( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, bool)>> { + fn poll_next_msg(&mut self, cx: &mut Context<'_>) -> Poll>> { if self.finish { // We've already received None and are shutting down return Poll::Ready(None); } tracing::trace!("worker polling for next message"); - if let Some(msg) = self.current_message.take() { - // If the oneshot sender is closed, then the receiver is dropped, - // and nobody cares about the response. If this is the case, we - // should continue to the next request. - if !msg.tx.is_closed() { - tracing::trace!("resuming buffered request"); - return Poll::Ready(Some((msg, false))); - } - - tracing::trace!("dropping cancelled buffered request"); - } - // Get the next request while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) { if !msg.tx.is_closed() { tracing::trace!("processing new request"); - return Poll::Ready(Some((msg, true))); + return Poll::Ready(Some(msg)); } // Otherwise, request is canceled, so pop the next one. tracing::trace!("dropping cancelled request"); @@ -150,8 +133,24 @@ where } loop { + if self.failed.is_none() { + match self.service.poll_ready(cx) { + Poll::Pending => { + tracing::trace!(service.ready = false); + return Poll::Pending; + } + Poll::Ready(Err(e)) => { + let error = e.into(); + tracing::debug!({ %error }, "service failed"); + self.failed(error); + } + Poll::Ready(Ok(())) => { + tracing::debug!(service.ready = true); + } + } + } match ready!(self.poll_next_msg(cx)) { - Some((msg, first)) => { + Some(msg) => { let _guard = msg.span.enter(); if let Some(ref failed) = self.failed { tracing::trace!("notifying caller about worker failure"); @@ -159,42 +158,15 @@ where continue; } - // Wait for the service to be ready - tracing::trace!( - resumed = !first, - message = "worker received request; waiting for service readiness" - ); - match self.service.poll_ready(cx) { - Poll::Ready(Ok(())) => { - tracing::debug!(service.ready = true, message = "processing request"); - let response = self.service.call(msg.request); - - // Send the response future back to the sender. - // - // An error means the request had been canceled in-between - // our calls, the response future will just be dropped. - tracing::trace!("returning response future"); - let _ = msg.tx.send(Ok(response)); - } - Poll::Pending => { - tracing::trace!(service.ready = false, message = "delay"); - // Put out current message back in its slot. - drop(_guard); - self.current_message = Some(msg); - return Poll::Pending; - } - Poll::Ready(Err(e)) => { - let error = e.into(); - tracing::debug!({ %error }, "service failed"); - drop(_guard); - self.failed(error); - let _ = msg.tx.send(Err(self - .failed - .as_ref() - .expect("Worker::failed did not set self.failed?") - .clone())); - } - } + tracing::debug!(service.ready = true, message = "processing request"); + let response = self.service.call(msg.request); + + // Send the response future back to the sender. + // + // An error means the request had been canceled in-between + // our calls, the response future will just be dropped. + tracing::trace!("returning response future"); + let _ = msg.tx.send(Ok(response)); } None => { // No more more requests _ever_. diff --git a/tower/tests/buffer/main.rs b/tower/tests/buffer/main.rs index ee238f11c..5f31d03d7 100644 --- a/tower/tests/buffer/main.rs +++ b/tower/tests/buffer/main.rs @@ -172,14 +172,13 @@ async fn waits_for_channel_capacity() { assert_ready_ok!(service.poll_ready()); let mut response2 = task::spawn(service.call("hello")); - assert_pending!(worker.poll()); - - assert_ready_ok!(service.poll_ready()); - let mut response3 = task::spawn(service.call("hello")); assert_pending!(service.poll_ready()); assert_pending!(worker.poll()); + // wake up worker's service (i.e. Mock), now it's ready to make progress handle.allow(1); + // process the request(i.e. send to handle), return the response + // and then poll worker's service::poll_ready in next loop. assert_pending!(worker.poll()); handle @@ -192,10 +191,10 @@ async fn waits_for_channel_capacity() { assert_ready_ok!(response1.poll()); assert_ready_ok!(service.poll_ready()); - let mut response4 = task::spawn(service.call("hello")); + let mut response3 = task::spawn(service.call("hello")); assert_pending!(worker.poll()); - handle.allow(3); + handle.allow(2); assert_pending!(worker.poll()); handle @@ -216,16 +215,6 @@ async fn waits_for_channel_capacity() { .send_response("world"); assert_pending!(worker.poll()); assert_ready_ok!(response3.poll()); - - assert_pending!(worker.poll()); - handle - .next_request() - .await - .unwrap() - .1 - .send_response("world"); - assert_pending!(worker.poll()); - assert_ready_ok!(response4.poll()); } #[tokio::test(flavor = "current_thread")] @@ -243,14 +232,13 @@ async fn wakes_pending_waiters_on_close() { assert_pending!(worker.poll()); let mut response = task::spawn(service1.call("hello")); - assert!(worker.is_woken(), "worker task should be woken by request"); + assert!( + !worker.is_woken(), + "worker task would NOT be woken by request until worker's service is ready" + ); assert_pending!(worker.poll()); // fill the channel so all subsequent requests will wait for capacity - let service1 = assert_ready_ok!(task::spawn(service.ready()).poll()); - assert_pending!(worker.poll()); - let mut response2 = task::spawn(service1.call("world")); - let mut service1 = service.clone(); let mut ready1 = task::spawn(service1.ready()); assert_pending!(worker.poll()); @@ -271,13 +259,6 @@ async fn wakes_pending_waiters_on_close() { err ); - let err = assert_ready_err!(response2.poll()); - assert!( - err.is::(), - "response should fail with a Closed, got: {:?}", - err - ); - assert!( ready1.is_woken(), "dropping worker should wake ready task 1" @@ -316,14 +297,13 @@ async fn wakes_pending_waiters_on_failure() { assert_pending!(worker.poll()); let mut response = task::spawn(service1.call("hello")); - assert!(worker.is_woken(), "worker task should be woken by request"); + assert!( + !worker.is_woken(), + "worker task would NOT be woken by request until worker's service is ready" + ); assert_pending!(worker.poll()); // fill the channel so all subsequent requests will wait for capacity - let service1 = assert_ready_ok!(task::spawn(service.ready()).poll()); - assert_pending!(worker.poll()); - let mut response2 = task::spawn(service1.call("world")); - let mut service1 = service.clone(); let mut ready1 = task::spawn(service1.ready()); assert_pending!(worker.poll()); @@ -336,6 +316,8 @@ async fn wakes_pending_waiters_on_failure() { // fail the inner service handle.send_error("foobar"); + // consume the in-flight request and send an Err response, then run + // next loop until read None. // worker task terminates assert_ready!(worker.poll()); @@ -345,12 +327,6 @@ async fn wakes_pending_waiters_on_failure() { "response should fail with a ServiceError, got: {:?}", err ); - let err = assert_ready_err!(response2.poll()); - assert!( - err.is::(), - "response should fail with a ServiceError, got: {:?}", - err - ); assert!( ready1.is_woken(), @@ -375,25 +351,6 @@ async fn wakes_pending_waiters_on_failure() { ); } -#[tokio::test(flavor = "current_thread")] -async fn propagates_trace_spans() { - use tower::util::ServiceExt; - use tracing::Instrument; - - let _t = support::trace_init(); - - let span = tracing::info_span!("my_span"); - - let service = support::AssertSpanSvc::new(span.clone()); - let (service, worker) = Buffer::pair(service, 5); - let worker = tokio::spawn(worker); - - let result = tokio::spawn(service.oneshot(()).instrument(span)); - - result.await.expect("service panicked").expect("failed"); - worker.await.expect("worker panicked"); -} - #[tokio::test(flavor = "current_thread")] async fn doesnt_leak_permits() { let _t = support::trace_init();