turbo_tasks_bytes/
stream.rs

1use std::{
2    fmt,
3    pin::Pin,
4    sync::{Arc, Mutex},
5    task::{Context as TaskContext, Poll},
6};
7
8use anyhow::Result;
9use bincode::{
10    BorrowDecode, Decode, Encode,
11    de::{BorrowDecoder, Decoder},
12    enc::Encoder,
13    error::{DecodeError, EncodeError},
14};
15use futures::{Stream as StreamTrait, StreamExt, TryStreamExt};
16
17/// Streams allow for streaming values from source to sink.
18///
19/// A Stream implements both a reader (which implements the Stream trait), and a
20/// writer (which can be sent to another thread). As new values are written, any
21/// pending readers will be woken up to receive the new value.
22#[derive(Clone, Debug)]
23pub struct Stream<T: Clone + Send> {
24    inner: Arc<Mutex<StreamState<T>>>,
25}
26
27/// The StreamState actually holds the data of a Stream.
28struct StreamState<T> {
29    source: Option<Pin<Box<dyn StreamTrait<Item = T> + Send>>>,
30    pulled: Vec<T>,
31}
32
33impl<T: Clone + Send> Stream<T> {
34    /// Constructs a new Stream, and immediately closes it with only the passed
35    /// values.
36    pub fn new_closed(pulled: Vec<T>) -> Self {
37        Self {
38            inner: Arc::new(Mutex::new(StreamState {
39                source: None,
40                pulled,
41            })),
42        }
43    }
44
45    /// Creates a new Stream, which will lazily pull from the source stream.
46    pub fn new_open(
47        pulled: Vec<T>,
48        source: Box<dyn StreamTrait<Item = T> + Send + 'static>,
49    ) -> Self {
50        Self {
51            inner: Arc::new(Mutex::new(StreamState {
52                source: Some(Box::into_pin(source)),
53                pulled,
54            })),
55        }
56    }
57
58    /// Returns a [StreamTrait] implementation to poll values out of our Stream.
59    pub fn read(&self) -> StreamRead<T> {
60        StreamRead {
61            source: self.clone(),
62            index: 0,
63        }
64    }
65
66    pub async fn into_single(&self) -> SingleValue<T> {
67        let mut stream = self.read();
68        let Some(first) = stream.next().await else {
69            return SingleValue::None;
70        };
71
72        if stream.next().await.is_some() {
73            return SingleValue::Multiple;
74        }
75
76        SingleValue::Single(first)
77    }
78}
79
80impl<T: Clone + Send, E: Clone + Send> Stream<Result<T, E>> {
81    /// Converts a TryStream into a single value when possible.
82    pub async fn try_into_single(&self) -> Result<SingleValue<T>, E> {
83        let mut stream = self.read();
84        let Some(first) = stream.try_next().await? else {
85            return Ok(SingleValue::None);
86        };
87
88        if stream.try_next().await?.is_some() {
89            return Ok(SingleValue::Multiple);
90        }
91
92        Ok(SingleValue::Single(first))
93    }
94}
95
/// Result of collapsing a Stream down to at most one value.
pub enum SingleValue<T> {
    /// The Stream did not hold any value.
    None,

    /// The Stream held more than one value.
    Multiple,

    /// The Stream held exactly one value, carried here.
    Single(T),
}
106
107impl<T: fmt::Debug> fmt::Debug for SingleValue<T> {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match self {
110            SingleValue::None => f.debug_struct("SingleValue::None").finish(),
111            SingleValue::Multiple => f.debug_struct("SingleValue::Multiple").finish(),
112            SingleValue::Single(v) => f.debug_tuple("SingleValue::Single").field(v).finish(),
113        }
114    }
115}
116
117impl<T: Clone + Send, S: StreamTrait<Item = T> + Send + Unpin + 'static> From<S> for Stream<T> {
118    fn from(source: S) -> Self {
119        Self::new_open(vec![], Box::new(source))
120    }
121}
122
123impl<T: Clone + Send> Default for Stream<T> {
124    fn default() -> Self {
125        Self::new_closed(vec![])
126    }
127}
128
129impl<T: Clone + PartialEq + Send> PartialEq for Stream<T> {
130    // A Stream is equal if it's the same internal pointer, or both streams are
131    // closed with equivalent values.
132    fn eq(&self, other: &Self) -> bool {
133        Arc::ptr_eq(&self.inner, &other.inner) || {
134            let left = self.inner.lock().unwrap();
135            let right = other.inner.lock().unwrap();
136
137            match (&*left, &*right) {
138                (
139                    StreamState {
140                        pulled: a,
141                        source: None,
142                    },
143                    StreamState {
144                        pulled: b,
145                        source: None,
146                    },
147                ) => a == b,
148                _ => false,
149            }
150        }
151    }
152}
153impl<T: Clone + Eq + Send> Eq for Stream<T> {}
154
155impl<T: Clone + Encode + Send> Encode for Stream<T> {
156    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
157        let lock = self
158            .inner
159            .lock()
160            .map_err(|_| EncodeError::Other("failed to lock stream, lock is poisoned?"))?;
161        match &*lock {
162            StreamState {
163                pulled,
164                source: None,
165            } => Encode::encode(pulled, encoder),
166            _ => Err(EncodeError::Other("cannot encode open stream")),
167        }
168    }
169}
170
171impl<Context, T: Clone + Send + Decode<Context>> Decode<Context> for Stream<T> {
172    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
173        let data = <Vec<T>>::decode(decoder)?;
174        Ok(Stream::new_closed(data))
175    }
176}
177
178impl<'de, Context, T: Clone + Send + BorrowDecode<'de, Context>> BorrowDecode<'de, Context>
179    for Stream<T>
180{
181    fn borrow_decode<D: BorrowDecoder<'de, Context = Context>>(
182        decoder: &mut D,
183    ) -> Result<Self, DecodeError> {
184        let data = <Vec<T>>::borrow_decode(decoder)?;
185        Ok(Stream::new_closed(data))
186    }
187}
188
189impl<T: Clone + fmt::Debug> fmt::Debug for StreamState<T> {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        f.debug_struct("StreamState")
192            .field("pulled", &self.pulled)
193            .finish()
194    }
195}
196
197/// Implements [StreamTrait] over our Stream.
198#[derive(Debug)]
199pub struct StreamRead<T: Clone + Send> {
200    index: usize,
201    source: Stream<T>,
202}
203
204impl<T: Clone + Send> StreamTrait for StreamRead<T> {
205    type Item = T;
206
207    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
208        let this = self.get_mut();
209        let index = this.index;
210        let mut inner = this.source.inner.lock().unwrap();
211
212        if let Some(v) = inner.pulled.get(index) {
213            // If the current reader can be satisfied by a value we've already pulled, then
214            // just do that.
215            this.index += 1;
216            return Poll::Ready(Some(v.clone()));
217        };
218
219        let Some(source) = &mut inner.source else {
220            // If the source has been closed, there's nothing left to pull.
221            return Poll::Ready(None);
222        };
223
224        match source.poll_next_unpin(cx) {
225            // If the source stream is ready to give us a new value, we can immediately store that
226            // and return it to the caller. Any other readers will be able to read the value from
227            // the already-pulled data.
228            Poll::Ready(Some(v)) => {
229                this.index += 1;
230                inner.pulled.push(v.clone());
231                Poll::Ready(Some(v))
232            }
233            // If the source stream is finished, then we can transition to the closed state
234            // to drop the source stream.
235            Poll::Ready(None) => {
236                inner.source.take();
237                Poll::Ready(None)
238            }
239            // Else, we need to wait for the source stream to give us a new value. The
240            // source stream will be responsible for waking the TaskContext.
241            Poll::Pending => Poll::Pending,
242        }
243    }
244}