turbo_tasks/
join_iter_ext.rs1use std::{
2 future::{Future, IntoFuture},
3 pin::Pin,
4 task::Poll,
5};
6
7use anyhow::Result;
8use futures::{
9 FutureExt,
10 future::{JoinAll, join_all},
11};
12use pin_project_lite::pin_project;
13
14pin_project! {
15 pub struct Join<F>
17 where
18 F: Future,
19 {
20 #[pin]
21 inner: JoinAll<F>,
22 }
23}
24
25impl<T, F> Future for Join<F>
26where
27 F: Future<Output = T>,
28{
29 type Output = Vec<T>;
30
31 fn poll(
32 self: std::pin::Pin<&mut Self>,
33 cx: &mut std::task::Context<'_>,
34 ) -> std::task::Poll<Self::Output> {
35 self.project().inner.poll(cx)
36 }
37}
38
39pub trait JoinIterExt<T, F>: Iterator
40where
41 F: Future<Output = T>,
42{
43 fn join(self) -> Join<F>;
46}
47
48pin_project! {
49 #[must_use]
51 pub struct TryJoin<F>
52 where
53 F: Future,
54 {
55 #[pin]
56 inner: JoinAll<F>,
57 }
58}
59
60impl<T, F> Future for TryJoin<F>
61where
62 F: Future<Output = Result<T>>,
63{
64 type Output = Result<Vec<T>>;
65
66 fn poll(
67 self: std::pin::Pin<&mut Self>,
68 cx: &mut std::task::Context<'_>,
69 ) -> std::task::Poll<Self::Output> {
70 match self.project().inner.poll_unpin(cx) {
71 std::task::Poll::Ready(res) => {
72 std::task::Poll::Ready(res.into_iter().collect::<Result<Vec<_>>>())
73 }
74 std::task::Poll::Pending => std::task::Poll::Pending,
75 }
76 }
77}
78
79pub trait TryJoinIterExt<T, F>: Iterator
80where
81 F: Future<Output = Result<T>>,
82{
83 fn try_join(self) -> TryJoin<F>;
89}
90
91impl<T, F, IF, It> JoinIterExt<T, F> for It
92where
93 F: Future<Output = T>,
94 IF: IntoFuture<Output = T, IntoFuture = F>,
95 It: Iterator<Item = IF>,
96{
97 fn join(self) -> Join<F> {
98 Join {
99 inner: join_all(self.map(|f| f.into_future())),
100 }
101 }
102}
103
104impl<T, F, IF, It> TryJoinIterExt<T, F> for It
105where
106 F: Future<Output = Result<T>>,
107 IF: IntoFuture<Output = Result<T>, IntoFuture = F>,
108 It: Iterator<Item = IF>,
109{
110 fn try_join(self) -> TryJoin<F> {
111 TryJoin {
112 inner: join_all(self.map(|f| f.into_future())),
113 }
114 }
115}
116
117pin_project! {
118 pub struct TryFlatJoin<F>
120 where
121 F: Future,
122 {
123 #[pin]
124 inner: JoinAll<F>,
125 }
126}
127
128impl<F, I, U> Future for TryFlatJoin<F>
129where
130 F: Future<Output = Result<I>>,
131 I: IntoIterator<IntoIter = U, Item = U::Item>,
132 U: Iterator,
133{
134 type Output = Result<Vec<U::Item>>;
135
136 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
137 match self.project().inner.poll_unpin(cx) {
138 Poll::Ready(res) => {
139 let mut v = Vec::new();
140 for r in res {
141 v.extend(r?);
142 }
143
144 Poll::Ready(Ok(v))
145 }
146 Poll::Pending => Poll::Pending,
147 }
148 }
149}
150
151pub trait TryFlatJoinIterExt<F, I, U>: Iterator
152where
153 F: Future<Output = Result<I>>,
154 I: IntoIterator<IntoIter = U, Item = U::Item>,
155 U: Iterator,
156{
157 fn try_flat_join(self) -> TryFlatJoin<F>;
165}
166
167impl<F, IF, It, I, U> TryFlatJoinIterExt<F, I, U> for It
168where
169 F: Future<Output = Result<I>>,
170 IF: IntoFuture<Output = Result<I>, IntoFuture = F>,
171 It: Iterator<Item = IF>,
172 I: IntoIterator<IntoIter = U, Item = U::Item>,
173 U: Iterator,
174{
175 fn try_flat_join(self) -> TryFlatJoin<F> {
176 TryFlatJoin {
177 inner: join_all(self.map(|f| f.into_future())),
178 }
179 }
180}