turbo_tasks/
join_iter_ext.rs

1use std::{
2    future::{Future, IntoFuture},
3    pin::Pin,
4    task::Poll,
5};
6
7use anyhow::Result;
8use futures::{
9    FutureExt,
10    future::{JoinAll, join_all},
11};
12use pin_project_lite::pin_project;
13
14pin_project! {
15    /// Future for the [JoinIterExt::join] method.
16    pub struct Join<F>
17    where
18        F: Future,
19    {
20        #[pin]
21        inner: JoinAll<F>,
22    }
23}
24
25impl<T, F> Future for Join<F>
26where
27    F: Future<Output = T>,
28{
29    type Output = Vec<T>;
30
31    fn poll(
32        self: std::pin::Pin<&mut Self>,
33        cx: &mut std::task::Context<'_>,
34    ) -> std::task::Poll<Self::Output> {
35        self.project().inner.poll(cx)
36    }
37}
38
39pub trait JoinIterExt<T, F>: Iterator
40where
41    F: Future<Output = T>,
42{
43    /// Returns a future that resolves to a vector of the outputs of the futures
44    /// in the iterator.
45    fn join(self) -> Join<F>;
46}
47
48pin_project! {
49    /// Future for the [TryJoinIterExt::try_join] method.
50    #[must_use]
51    pub struct TryJoin<F>
52    where
53        F: Future,
54    {
55        #[pin]
56        inner: JoinAll<F>,
57    }
58}
59
60impl<T, F> Future for TryJoin<F>
61where
62    F: Future<Output = Result<T>>,
63{
64    type Output = Result<Vec<T>>;
65
66    fn poll(
67        self: std::pin::Pin<&mut Self>,
68        cx: &mut std::task::Context<'_>,
69    ) -> std::task::Poll<Self::Output> {
70        match self.project().inner.poll_unpin(cx) {
71            std::task::Poll::Ready(res) => {
72                std::task::Poll::Ready(res.into_iter().collect::<Result<Vec<_>>>())
73            }
74            std::task::Poll::Pending => std::task::Poll::Pending,
75        }
76    }
77}
78
79pub trait TryJoinIterExt<T, F>: Iterator
80where
81    F: Future<Output = Result<T>>,
82{
83    /// Returns a future that resolves to a vector of the outputs of the futures
84    /// in the iterator, or to an error if one of the futures fail.
85    ///
86    /// Unlike `Futures::future::try_join_all`, this returns the Error that
87    /// occurs first in the list of futures, not the first to fail in time.
88    fn try_join(self) -> TryJoin<F>;
89}
90
91impl<T, F, IF, It> JoinIterExt<T, F> for It
92where
93    F: Future<Output = T>,
94    IF: IntoFuture<Output = T, IntoFuture = F>,
95    It: Iterator<Item = IF>,
96{
97    fn join(self) -> Join<F> {
98        Join {
99            inner: join_all(self.map(|f| f.into_future())),
100        }
101    }
102}
103
104impl<T, F, IF, It> TryJoinIterExt<T, F> for It
105where
106    F: Future<Output = Result<T>>,
107    IF: IntoFuture<Output = Result<T>, IntoFuture = F>,
108    It: Iterator<Item = IF>,
109{
110    fn try_join(self) -> TryJoin<F> {
111        TryJoin {
112            inner: join_all(self.map(|f| f.into_future())),
113        }
114    }
115}
116
117pin_project! {
118    /// Future for the [TryFlatJoinIterExt::try_flat_join] method.
119    pub struct TryFlatJoin<F>
120    where
121        F: Future,
122    {
123        #[pin]
124        inner: JoinAll<F>,
125    }
126}
127
128impl<F, I, U> Future for TryFlatJoin<F>
129where
130    F: Future<Output = Result<I>>,
131    I: IntoIterator<IntoIter = U, Item = U::Item>,
132    U: Iterator,
133{
134    type Output = Result<Vec<U::Item>>;
135
136    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
137        match self.project().inner.poll_unpin(cx) {
138            Poll::Ready(res) => {
139                let mut v = Vec::new();
140                for r in res {
141                    v.extend(r?);
142                }
143
144                Poll::Ready(Ok(v))
145            }
146            Poll::Pending => Poll::Pending,
147        }
148    }
149}
150
151pub trait TryFlatJoinIterExt<F, I, U>: Iterator
152where
153    F: Future<Output = Result<I>>,
154    I: IntoIterator<IntoIter = U, Item = U::Item>,
155    U: Iterator,
156{
157    /// Returns a future that resolves to a vector of the outputs of the futures
158    /// in the iterator, or to an error if one of the futures fail.
159    ///
160    /// It also flattens the result.
161    ///
162    /// Unlike `Futures::future::try_join_all`, this returns the Error that
163    /// occurs first in the list of futures, not the first to fail in time.
164    fn try_flat_join(self) -> TryFlatJoin<F>;
165}
166
167impl<F, IF, It, I, U> TryFlatJoinIterExt<F, I, U> for It
168where
169    F: Future<Output = Result<I>>,
170    IF: IntoFuture<Output = Result<I>, IntoFuture = F>,
171    It: Iterator<Item = IF>,
172    I: IntoIterator<IntoIter = U, Item = U::Item>,
173    U: Iterator,
174{
175    fn try_flat_join(self) -> TryFlatJoin<F> {
176        TryFlatJoin {
177            inner: join_all(self.map(|f| f.into_future())),
178        }
179    }
180}