turbo_tasks_bytes/
stream.rs

1use std::{
2    fmt,
3    pin::Pin,
4    sync::{Arc, Mutex},
5    task::{Context as TaskContext, Poll},
6};
7
8use anyhow::Result;
9use futures::{Stream as StreamTrait, StreamExt, TryStreamExt};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
12/// Streams allow for streaming values from source to sink.
13///
14/// A Stream implements both a reader (which implements the Stream trait), and a
15/// writer (which can be sent to another thread). As new values are written, any
16/// pending readers will be woken up to receive the new value.
17#[derive(Clone, Debug)]
18pub struct Stream<T: Clone + Send> {
19    inner: Arc<Mutex<StreamState<T>>>,
20}
21
22/// The StreamState actually holds the data of a Stream.
23struct StreamState<T> {
24    source: Option<Pin<Box<dyn StreamTrait<Item = T> + Send>>>,
25    pulled: Vec<T>,
26}
27
28impl<T: Clone + Send> Stream<T> {
29    /// Constructs a new Stream, and immediately closes it with only the passed
30    /// values.
31    pub fn new_closed(pulled: Vec<T>) -> Self {
32        Self {
33            inner: Arc::new(Mutex::new(StreamState {
34                source: None,
35                pulled,
36            })),
37        }
38    }
39
40    /// Creates a new Stream, which will lazily pull from the source stream.
41    pub fn new_open(
42        pulled: Vec<T>,
43        source: Box<dyn StreamTrait<Item = T> + Send + 'static>,
44    ) -> Self {
45        Self {
46            inner: Arc::new(Mutex::new(StreamState {
47                source: Some(Box::into_pin(source)),
48                pulled,
49            })),
50        }
51    }
52
53    /// Returns a [StreamTrait] implementation to poll values out of our Stream.
54    pub fn read(&self) -> StreamRead<T> {
55        StreamRead {
56            source: self.clone(),
57            index: 0,
58        }
59    }
60
61    pub async fn into_single(&self) -> SingleValue<T> {
62        let mut stream = self.read();
63        let Some(first) = stream.next().await else {
64            return SingleValue::None;
65        };
66
67        if stream.next().await.is_some() {
68            return SingleValue::Multiple;
69        }
70
71        SingleValue::Single(first)
72    }
73}
74
75impl<T: Clone + Send, E: Clone + Send> Stream<Result<T, E>> {
76    /// Converts a TryStream into a single value when possible.
77    pub async fn try_into_single(&self) -> Result<SingleValue<T>, E> {
78        let mut stream = self.read();
79        let Some(first) = stream.try_next().await? else {
80            return Ok(SingleValue::None);
81        };
82
83        if stream.try_next().await?.is_some() {
84            return Ok(SingleValue::Multiple);
85        }
86
87        Ok(SingleValue::Single(first))
88    }
89}
90
/// How many values a [`Stream`] held, as reported by [`Stream::into_single`]
/// and [`Stream::try_into_single`].
pub enum SingleValue<T> {
    /// The Stream did not hold a value.
    None,

    /// The Stream held multiple values.
    Multiple,

    /// The Stream held only a single value.
    Single(T),
}
101
102impl<T: fmt::Debug> fmt::Debug for SingleValue<T> {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        match self {
105            SingleValue::None => f.debug_struct("SingleValue::None").finish(),
106            SingleValue::Multiple => f.debug_struct("SingleValue::Multiple").finish(),
107            SingleValue::Single(v) => f.debug_tuple("SingleValue::Single").field(v).finish(),
108        }
109    }
110}
111
112impl<T: Clone + Send, S: StreamTrait<Item = T> + Send + Unpin + 'static> From<S> for Stream<T> {
113    fn from(source: S) -> Self {
114        Self::new_open(vec![], Box::new(source))
115    }
116}
117
118impl<T: Clone + Send> Default for Stream<T> {
119    fn default() -> Self {
120        Self::new_closed(vec![])
121    }
122}
123
124impl<T: Clone + PartialEq + Send> PartialEq for Stream<T> {
125    // A Stream is equal if it's the same internal pointer, or both streams are
126    // closed with equivalent values.
127    fn eq(&self, other: &Self) -> bool {
128        Arc::ptr_eq(&self.inner, &other.inner) || {
129            let left = self.inner.lock().unwrap();
130            let right = other.inner.lock().unwrap();
131
132            match (&*left, &*right) {
133                (
134                    StreamState {
135                        pulled: a,
136                        source: None,
137                    },
138                    StreamState {
139                        pulled: b,
140                        source: None,
141                    },
142                ) => a == b,
143                _ => false,
144            }
145        }
146    }
147}
148impl<T: Clone + Eq + Send> Eq for Stream<T> {}
149
150impl<T: Clone + Serialize + Send> Serialize for Stream<T> {
151    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
152        use serde::ser::Error;
153        let lock = self.inner.lock().map_err(Error::custom)?;
154        match &*lock {
155            StreamState {
156                pulled,
157                source: None,
158            } => pulled.serialize(serializer),
159            _ => Err(Error::custom("cannot serialize open stream")),
160        }
161    }
162}
163
164impl<'de, T: Clone + Send + Deserialize<'de>> Deserialize<'de> for Stream<T> {
165    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
166        let data = <Vec<T>>::deserialize(deserializer)?;
167        Ok(Stream::new_closed(data))
168    }
169}
170
171impl<T: Clone + fmt::Debug> fmt::Debug for StreamState<T> {
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        f.debug_struct("StreamState")
174            .field("pulled", &self.pulled)
175            .finish()
176    }
177}
178
179/// Implements [StreamTrait] over our Stream.
180#[derive(Debug)]
181pub struct StreamRead<T: Clone + Send> {
182    index: usize,
183    source: Stream<T>,
184}
185
186impl<T: Clone + Send> StreamTrait for StreamRead<T> {
187    type Item = T;
188
189    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
190        let this = self.get_mut();
191        let index = this.index;
192        let mut inner = this.source.inner.lock().unwrap();
193
194        if let Some(v) = inner.pulled.get(index) {
195            // If the current reader can be satisfied by a value we've already pulled, then
196            // just do that.
197            this.index += 1;
198            return Poll::Ready(Some(v.clone()));
199        };
200
201        let Some(source) = &mut inner.source else {
202            // If the source has been closed, there's nothing left to pull.
203            return Poll::Ready(None);
204        };
205
206        match source.poll_next_unpin(cx) {
207            // If the source stream is ready to give us a new value, we can immediately store that
208            // and return it to the caller. Any other readers will be able to read the value from
209            // the already-pulled data.
210            Poll::Ready(Some(v)) => {
211                this.index += 1;
212                inner.pulled.push(v.clone());
213                Poll::Ready(Some(v))
214            }
215            // If the source stream is finished, then we can transition to the closed state
216            // to drop the source stream.
217            Poll::Ready(None) => {
218                inner.source.take();
219                Poll::Ready(None)
220            }
221            // Else, we need to wait for the source stream to give us a new value. The
222            // source stream will be responsible for waking the TaskContext.
223            Poll::Pending => Poll::Pending,
224        }
225    }
226}