turbo_tasks_bytes/
stream.rs1use std::{
2 fmt,
3 pin::Pin,
4 sync::{Arc, Mutex},
5 task::{Context as TaskContext, Poll},
6};
7
8use anyhow::Result;
9use bincode::{
10 BorrowDecode, Decode, Encode,
11 de::{BorrowDecoder, Decoder},
12 enc::Encoder,
13 error::{DecodeError, EncodeError},
14};
15use futures::{Stream as StreamTrait, StreamExt};
16
17#[derive(Clone, Debug)]
23pub struct Stream<T: Clone + Send> {
24 inner: Arc<Mutex<StreamState<T>>>,
25}
26
27struct StreamState<T> {
29 source: Option<Pin<Box<dyn StreamTrait<Item = T> + Send>>>,
30 pulled: Vec<T>,
31}
32
33impl<T: Clone + Send> Stream<T> {
34 pub fn new_closed(pulled: Vec<T>) -> Self {
37 Self {
38 inner: Arc::new(Mutex::new(StreamState {
39 source: None,
40 pulled,
41 })),
42 }
43 }
44
45 pub fn new_open(
47 pulled: Vec<T>,
48 source: Box<dyn StreamTrait<Item = T> + Send + 'static>,
49 ) -> Self {
50 Self {
51 inner: Arc::new(Mutex::new(StreamState {
52 source: Some(Box::into_pin(source)),
53 pulled,
54 })),
55 }
56 }
57
58 pub fn read(&self) -> StreamRead<T> {
60 StreamRead {
61 source: self.clone(),
62 index: 0,
63 }
64 }
65
66 pub async fn into_single(&self) -> SingleValue<T> {
67 let mut stream = self.read();
68 let Some(first) = stream.next().await else {
69 return SingleValue::None;
70 };
71
72 if stream.next().await.is_some() {
73 return SingleValue::Multiple;
74 }
75
76 SingleValue::Single(first)
77 }
78}
79
/// Outcome of `Stream::into_single`: how many items the stream contains.
pub enum SingleValue<T> {
    /// The stream yielded no items.
    None,

    /// The stream yielded more than one item.
    Multiple,

    /// The stream yielded exactly one item.
    Single(T),
}

impl<T: fmt::Debug> fmt::Debug for SingleValue<T> {
    /// Formats as the fully qualified variant name, e.g. `SingleValue::None`
    /// or `SingleValue::Single(42)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // An empty `debug_struct(..).finish()` writes just the name, so a
            // plain `write_str` produces identical output.
            Self::None => f.write_str("SingleValue::None"),
            Self::Multiple => f.write_str("SingleValue::Multiple"),
            Self::Single(value) => f.debug_tuple("SingleValue::Single").field(value).finish(),
        }
    }
}
100
101impl<T: Clone + Send, S: StreamTrait<Item = T> + Send + Unpin + 'static> From<S> for Stream<T> {
102 fn from(source: S) -> Self {
103 Self::new_open(vec![], Box::new(source))
104 }
105}
106
107impl<T: Clone + Send> Default for Stream<T> {
108 fn default() -> Self {
109 Self::new_closed(vec![])
110 }
111}
112
113impl<T: Clone + PartialEq + Send> PartialEq for Stream<T> {
114 fn eq(&self, other: &Self) -> bool {
117 Arc::ptr_eq(&self.inner, &other.inner) || {
118 let left = self.inner.lock().unwrap();
119 let right = other.inner.lock().unwrap();
120
121 match (&*left, &*right) {
122 (
123 StreamState {
124 pulled: a,
125 source: None,
126 },
127 StreamState {
128 pulled: b,
129 source: None,
130 },
131 ) => a == b,
132 _ => false,
133 }
134 }
135 }
136}
137impl<T: Clone + Eq + Send> Eq for Stream<T> {}
138
139impl<T: Clone + Encode + Send> Encode for Stream<T> {
140 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
141 let lock = self
142 .inner
143 .lock()
144 .map_err(|_| EncodeError::Other("failed to lock stream, lock is poisoned?"))?;
145 match &*lock {
146 StreamState {
147 pulled,
148 source: None,
149 } => Encode::encode(pulled, encoder),
150 _ => Err(EncodeError::Other("cannot encode open stream")),
151 }
152 }
153}
154
155impl<Context, T: Clone + Send + Decode<Context>> Decode<Context> for Stream<T> {
156 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
157 let data = <Vec<T>>::decode(decoder)?;
158 Ok(Stream::new_closed(data))
159 }
160}
161
162impl<'de, Context, T: Clone + Send + BorrowDecode<'de, Context>> BorrowDecode<'de, Context>
163 for Stream<T>
164{
165 fn borrow_decode<D: BorrowDecoder<'de, Context = Context>>(
166 decoder: &mut D,
167 ) -> Result<Self, DecodeError> {
168 let data = <Vec<T>>::borrow_decode(decoder)?;
169 Ok(Stream::new_closed(data))
170 }
171}
172
173impl<T: Clone + fmt::Debug> fmt::Debug for StreamState<T> {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 f.debug_struct("StreamState")
176 .field("pulled", &self.pulled)
177 .finish()
178 }
179}
180
181#[derive(Debug)]
183pub struct StreamRead<T: Clone + Send> {
184 index: usize,
185 source: Stream<T>,
186}
187
188impl<T: Clone + Send> StreamTrait for StreamRead<T> {
189 type Item = T;
190
191 fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
192 let this = self.get_mut();
193 let index = this.index;
194 let mut inner = this.source.inner.lock().unwrap();
195
196 if let Some(v) = inner.pulled.get(index) {
197 this.index += 1;
200 return Poll::Ready(Some(v.clone()));
201 };
202
203 let Some(source) = &mut inner.source else {
204 return Poll::Ready(None);
206 };
207
208 match source.poll_next_unpin(cx) {
209 Poll::Ready(Some(v)) => {
213 this.index += 1;
214 inner.pulled.push(v.clone());
215 Poll::Ready(Some(v))
216 }
217 Poll::Ready(None) => {
220 inner.source.take();
221 Poll::Ready(None)
222 }
223 Poll::Pending => Poll::Pending,
226 }
227 }
228}