turbo_tasks_bytes/
stream.rs

use std::{
    fmt,
    pin::Pin,
    sync::{Arc, Mutex, MutexGuard, PoisonError},
    task::{Context as TaskContext, Poll, Wake, Waker},
};

use anyhow::Result;
use bincode::{
    BorrowDecode, Decode, Encode,
    de::{BorrowDecoder, Decoder},
    enc::Encoder,
    error::{DecodeError, EncodeError},
};
use futures::{Stream as StreamTrait, StreamExt};
16
17/// Streams allow for streaming values from source to sink.
18///
19/// A Stream implements both a reader (which implements the Stream trait), and a
20/// writer (which can be sent to another thread). As new values are written, any
21/// pending readers will be woken up to receive the new value.
22#[derive(Clone, Debug)]
23pub struct Stream<T: Clone + Send> {
24    inner: Arc<Mutex<StreamState<T>>>,
25}
26
27/// The StreamState actually holds the data of a Stream.
28struct StreamState<T> {
29    source: Option<Pin<Box<dyn StreamTrait<Item = T> + Send>>>,
30    pulled: Vec<T>,
31}
32
33impl<T: Clone + Send> Stream<T> {
34    /// Constructs a new Stream, and immediately closes it with only the passed
35    /// values.
36    pub fn new_closed(pulled: Vec<T>) -> Self {
37        Self {
38            inner: Arc::new(Mutex::new(StreamState {
39                source: None,
40                pulled,
41            })),
42        }
43    }
44
45    /// Creates a new Stream, which will lazily pull from the source stream.
46    pub fn new_open(
47        pulled: Vec<T>,
48        source: Box<dyn StreamTrait<Item = T> + Send + 'static>,
49    ) -> Self {
50        Self {
51            inner: Arc::new(Mutex::new(StreamState {
52                source: Some(Box::into_pin(source)),
53                pulled,
54            })),
55        }
56    }
57
58    /// Returns a [StreamTrait] implementation to poll values out of our Stream.
59    pub fn read(&self) -> StreamRead<T> {
60        StreamRead {
61            source: self.clone(),
62            index: 0,
63        }
64    }
65
66    pub async fn into_single(&self) -> SingleValue<T> {
67        let mut stream = self.read();
68        let Some(first) = stream.next().await else {
69            return SingleValue::None;
70        };
71
72        if stream.next().await.is_some() {
73            return SingleValue::Multiple;
74        }
75
76        SingleValue::Single(first)
77    }
78}
79
80pub enum SingleValue<T> {
81    /// The Stream did not hold a value.
82    None,
83
84    /// The Stream held multiple values.
85    Multiple,
86
87    /// The held only a single value.
88    Single(T),
89}
90
91impl<T: fmt::Debug> fmt::Debug for SingleValue<T> {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        match self {
94            SingleValue::None => f.debug_struct("SingleValue::None").finish(),
95            SingleValue::Multiple => f.debug_struct("SingleValue::Multiple").finish(),
96            SingleValue::Single(v) => f.debug_tuple("SingleValue::Single").field(v).finish(),
97        }
98    }
99}
100
101impl<T: Clone + Send, S: StreamTrait<Item = T> + Send + Unpin + 'static> From<S> for Stream<T> {
102    fn from(source: S) -> Self {
103        Self::new_open(vec![], Box::new(source))
104    }
105}
106
107impl<T: Clone + Send> Default for Stream<T> {
108    fn default() -> Self {
109        Self::new_closed(vec![])
110    }
111}
112
113impl<T: Clone + PartialEq + Send> PartialEq for Stream<T> {
114    // A Stream is equal if it's the same internal pointer, or both streams are
115    // closed with equivalent values.
116    fn eq(&self, other: &Self) -> bool {
117        Arc::ptr_eq(&self.inner, &other.inner) || {
118            let left = self.inner.lock().unwrap();
119            let right = other.inner.lock().unwrap();
120
121            match (&*left, &*right) {
122                (
123                    StreamState {
124                        pulled: a,
125                        source: None,
126                    },
127                    StreamState {
128                        pulled: b,
129                        source: None,
130                    },
131                ) => a == b,
132                _ => false,
133            }
134        }
135    }
136}
137impl<T: Clone + Eq + Send> Eq for Stream<T> {}
138
139impl<T: Clone + Encode + Send> Encode for Stream<T> {
140    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
141        let lock = self
142            .inner
143            .lock()
144            .map_err(|_| EncodeError::Other("failed to lock stream, lock is poisoned?"))?;
145        match &*lock {
146            StreamState {
147                pulled,
148                source: None,
149            } => Encode::encode(pulled, encoder),
150            _ => Err(EncodeError::Other("cannot encode open stream")),
151        }
152    }
153}
154
155impl<Context, T: Clone + Send + Decode<Context>> Decode<Context> for Stream<T> {
156    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
157        let data = <Vec<T>>::decode(decoder)?;
158        Ok(Stream::new_closed(data))
159    }
160}
161
162impl<'de, Context, T: Clone + Send + BorrowDecode<'de, Context>> BorrowDecode<'de, Context>
163    for Stream<T>
164{
165    fn borrow_decode<D: BorrowDecoder<'de, Context = Context>>(
166        decoder: &mut D,
167    ) -> Result<Self, DecodeError> {
168        let data = <Vec<T>>::borrow_decode(decoder)?;
169        Ok(Stream::new_closed(data))
170    }
171}
172
173impl<T: Clone + fmt::Debug> fmt::Debug for StreamState<T> {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        f.debug_struct("StreamState")
176            .field("pulled", &self.pulled)
177            .finish()
178    }
179}
180
181/// Implements [StreamTrait] over our Stream.
182#[derive(Debug)]
183pub struct StreamRead<T: Clone + Send> {
184    index: usize,
185    source: Stream<T>,
186}
187
188impl<T: Clone + Send> StreamTrait for StreamRead<T> {
189    type Item = T;
190
191    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
192        let this = self.get_mut();
193        let index = this.index;
194        let mut inner = this.source.inner.lock().unwrap();
195
196        if let Some(v) = inner.pulled.get(index) {
197            // If the current reader can be satisfied by a value we've already pulled, then
198            // just do that.
199            this.index += 1;
200            return Poll::Ready(Some(v.clone()));
201        };
202
203        let Some(source) = &mut inner.source else {
204            // If the source has been closed, there's nothing left to pull.
205            return Poll::Ready(None);
206        };
207
208        match source.poll_next_unpin(cx) {
209            // If the source stream is ready to give us a new value, we can immediately store that
210            // and return it to the caller. Any other readers will be able to read the value from
211            // the already-pulled data.
212            Poll::Ready(Some(v)) => {
213                this.index += 1;
214                inner.pulled.push(v.clone());
215                Poll::Ready(Some(v))
216            }
217            // If the source stream is finished, then we can transition to the closed state
218            // to drop the source stream.
219            Poll::Ready(None) => {
220                inner.source.take();
221                Poll::Ready(None)
222            }
223            // Else, we need to wait for the source stream to give us a new value. The
224            // source stream will be responsible for waking the TaskContext.
225            Poll::Pending => Poll::Pending,
226        }
227    }
228}