turbo_tasks_bytes/
stream.rs1use std::{
2 fmt,
3 pin::Pin,
4 sync::{Arc, Mutex},
5 task::{Context as TaskContext, Poll},
6};
7
8use anyhow::Result;
9use bincode::{
10 BorrowDecode, Decode, Encode,
11 de::{BorrowDecoder, Decoder},
12 enc::Encoder,
13 error::{DecodeError, EncodeError},
14};
15use futures::{Stream as StreamTrait, StreamExt, TryStreamExt};
16
17#[derive(Clone, Debug)]
23pub struct Stream<T: Clone + Send> {
24 inner: Arc<Mutex<StreamState<T>>>,
25}
26
27struct StreamState<T> {
29 source: Option<Pin<Box<dyn StreamTrait<Item = T> + Send>>>,
30 pulled: Vec<T>,
31}
32
33impl<T: Clone + Send> Stream<T> {
34 pub fn new_closed(pulled: Vec<T>) -> Self {
37 Self {
38 inner: Arc::new(Mutex::new(StreamState {
39 source: None,
40 pulled,
41 })),
42 }
43 }
44
45 pub fn new_open(
47 pulled: Vec<T>,
48 source: Box<dyn StreamTrait<Item = T> + Send + 'static>,
49 ) -> Self {
50 Self {
51 inner: Arc::new(Mutex::new(StreamState {
52 source: Some(Box::into_pin(source)),
53 pulled,
54 })),
55 }
56 }
57
58 pub fn read(&self) -> StreamRead<T> {
60 StreamRead {
61 source: self.clone(),
62 index: 0,
63 }
64 }
65
66 pub async fn into_single(&self) -> SingleValue<T> {
67 let mut stream = self.read();
68 let Some(first) = stream.next().await else {
69 return SingleValue::None;
70 };
71
72 if stream.next().await.is_some() {
73 return SingleValue::Multiple;
74 }
75
76 SingleValue::Single(first)
77 }
78}
79
80impl<T: Clone + Send, E: Clone + Send> Stream<Result<T, E>> {
81 pub async fn try_into_single(&self) -> Result<SingleValue<T>, E> {
83 let mut stream = self.read();
84 let Some(first) = stream.try_next().await? else {
85 return Ok(SingleValue::None);
86 };
87
88 if stream.try_next().await?.is_some() {
89 return Ok(SingleValue::Multiple);
90 }
91
92 Ok(SingleValue::Single(first))
93 }
94}
95
/// How many values a stream turned out to contain (see `into_single`).
pub enum SingleValue<T> {
    /// The stream yielded no values.
    None,

    /// The stream yielded more than one value.
    Multiple,

    /// The stream yielded exactly one value.
    Single(T),
}

impl<T: fmt::Debug> fmt::Debug for SingleValue<T> {
    /// Renders each variant with its enum-qualified name, e.g. `SingleValue::None`
    /// or `SingleValue::Single(42)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::None => "SingleValue::None",
            Self::Multiple => "SingleValue::Multiple",
            Self::Single(value) => {
                return f.debug_tuple("SingleValue::Single").field(value).finish();
            }
        };
        // A field-less `debug_struct(..).finish()` writes exactly its name.
        f.write_str(label)
    }
}
116
117impl<T: Clone + Send, S: StreamTrait<Item = T> + Send + Unpin + 'static> From<S> for Stream<T> {
118 fn from(source: S) -> Self {
119 Self::new_open(vec![], Box::new(source))
120 }
121}
122
123impl<T: Clone + Send> Default for Stream<T> {
124 fn default() -> Self {
125 Self::new_closed(vec![])
126 }
127}
128
129impl<T: Clone + PartialEq + Send> PartialEq for Stream<T> {
130 fn eq(&self, other: &Self) -> bool {
133 Arc::ptr_eq(&self.inner, &other.inner) || {
134 let left = self.inner.lock().unwrap();
135 let right = other.inner.lock().unwrap();
136
137 match (&*left, &*right) {
138 (
139 StreamState {
140 pulled: a,
141 source: None,
142 },
143 StreamState {
144 pulled: b,
145 source: None,
146 },
147 ) => a == b,
148 _ => false,
149 }
150 }
151 }
152}
153impl<T: Clone + Eq + Send> Eq for Stream<T> {}
154
155impl<T: Clone + Encode + Send> Encode for Stream<T> {
156 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
157 let lock = self
158 .inner
159 .lock()
160 .map_err(|_| EncodeError::Other("failed to lock stream, lock is poisoned?"))?;
161 match &*lock {
162 StreamState {
163 pulled,
164 source: None,
165 } => Encode::encode(pulled, encoder),
166 _ => Err(EncodeError::Other("cannot encode open stream")),
167 }
168 }
169}
170
171impl<Context, T: Clone + Send + Decode<Context>> Decode<Context> for Stream<T> {
172 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
173 let data = <Vec<T>>::decode(decoder)?;
174 Ok(Stream::new_closed(data))
175 }
176}
177
178impl<'de, Context, T: Clone + Send + BorrowDecode<'de, Context>> BorrowDecode<'de, Context>
179 for Stream<T>
180{
181 fn borrow_decode<D: BorrowDecoder<'de, Context = Context>>(
182 decoder: &mut D,
183 ) -> Result<Self, DecodeError> {
184 let data = <Vec<T>>::borrow_decode(decoder)?;
185 Ok(Stream::new_closed(data))
186 }
187}
188
189impl<T: Clone + fmt::Debug> fmt::Debug for StreamState<T> {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.debug_struct("StreamState")
192 .field("pulled", &self.pulled)
193 .finish()
194 }
195}
196
197#[derive(Debug)]
199pub struct StreamRead<T: Clone + Send> {
200 index: usize,
201 source: Stream<T>,
202}
203
204impl<T: Clone + Send> StreamTrait for StreamRead<T> {
205 type Item = T;
206
207 fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
208 let this = self.get_mut();
209 let index = this.index;
210 let mut inner = this.source.inner.lock().unwrap();
211
212 if let Some(v) = inner.pulled.get(index) {
213 this.index += 1;
216 return Poll::Ready(Some(v.clone()));
217 };
218
219 let Some(source) = &mut inner.source else {
220 return Poll::Ready(None);
222 };
223
224 match source.poll_next_unpin(cx) {
225 Poll::Ready(Some(v)) => {
229 this.index += 1;
230 inner.pulled.push(v.clone());
231 Poll::Ready(Some(v))
232 }
233 Poll::Ready(None) => {
236 inner.source.take();
237 Poll::Ready(None)
238 }
239 Poll::Pending => Poll::Pending,
242 }
243 }
244}