turbo_tasks_bytes/
stream.rs1use std::{
2 fmt,
3 pin::Pin,
4 sync::{Arc, Mutex},
5 task::{Context as TaskContext, Poll},
6};
7
8use anyhow::Result;
9use futures::{Stream as StreamTrait, StreamExt, TryStreamExt};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
/// A shareable, replayable stream of values.
///
/// Cloning a `Stream` only clones the `Arc` handle, so every clone shares the
/// same underlying state: the values pulled so far plus (while still open) the
/// source to pull more from. Each reader created via [`Stream::read`] starts at
/// the beginning and therefore observes the full sequence of values.
#[derive(Clone, Debug)]
pub struct Stream<T: Clone + Send> {
    inner: Arc<Mutex<StreamState<T>>>,
}
21
/// Shared state behind a [`Stream`].
struct StreamState<T> {
    /// The remaining source to pull values from; `None` once the stream is
    /// closed (either constructed closed, or after the source reported its end).
    source: Option<Pin<Box<dyn StreamTrait<Item = T> + Send>>>,
    /// Every value pulled from the source so far, in order. Readers replay
    /// these before polling the source themselves.
    pulled: Vec<T>,
}
27
28impl<T: Clone + Send> Stream<T> {
29 pub fn new_closed(pulled: Vec<T>) -> Self {
32 Self {
33 inner: Arc::new(Mutex::new(StreamState {
34 source: None,
35 pulled,
36 })),
37 }
38 }
39
40 pub fn new_open(
42 pulled: Vec<T>,
43 source: Box<dyn StreamTrait<Item = T> + Send + 'static>,
44 ) -> Self {
45 Self {
46 inner: Arc::new(Mutex::new(StreamState {
47 source: Some(Box::into_pin(source)),
48 pulled,
49 })),
50 }
51 }
52
53 pub fn read(&self) -> StreamRead<T> {
55 StreamRead {
56 source: self.clone(),
57 index: 0,
58 }
59 }
60
61 pub async fn into_single(&self) -> SingleValue<T> {
62 let mut stream = self.read();
63 let Some(first) = stream.next().await else {
64 return SingleValue::None;
65 };
66
67 if stream.next().await.is_some() {
68 return SingleValue::Multiple;
69 }
70
71 SingleValue::Single(first)
72 }
73}
74
75impl<T: Clone + Send, E: Clone + Send> Stream<Result<T, E>> {
76 pub async fn try_into_single(&self) -> Result<SingleValue<T>, E> {
78 let mut stream = self.read();
79 let Some(first) = stream.try_next().await? else {
80 return Ok(SingleValue::None);
81 };
82
83 if stream.try_next().await?.is_some() {
84 return Ok(SingleValue::Multiple);
85 }
86
87 Ok(SingleValue::Single(first))
88 }
89}
90
/// How many values a stream produced, as determined by
/// [`Stream::into_single`] / [`Stream::try_into_single`].
pub enum SingleValue<T> {
    /// The stream produced no values.
    None,

    /// The stream produced more than one value.
    Multiple,

    /// The stream produced exactly one value, carried here.
    Single(T),
}
101
102impl<T: fmt::Debug> fmt::Debug for SingleValue<T> {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 match self {
105 SingleValue::None => f.debug_struct("SingleValue::None").finish(),
106 SingleValue::Multiple => f.debug_struct("SingleValue::Multiple").finish(),
107 SingleValue::Single(v) => f.debug_tuple("SingleValue::Single").field(v).finish(),
108 }
109 }
110}
111
112impl<T: Clone + Send, S: StreamTrait<Item = T> + Send + Unpin + 'static> From<S> for Stream<T> {
113 fn from(source: S) -> Self {
114 Self::new_open(vec![], Box::new(source))
115 }
116}
117
118impl<T: Clone + Send> Default for Stream<T> {
119 fn default() -> Self {
120 Self::new_closed(vec![])
121 }
122}
123
124impl<T: Clone + PartialEq + Send> PartialEq for Stream<T> {
125 fn eq(&self, other: &Self) -> bool {
128 Arc::ptr_eq(&self.inner, &other.inner) || {
129 let left = self.inner.lock().unwrap();
130 let right = other.inner.lock().unwrap();
131
132 match (&*left, &*right) {
133 (
134 StreamState {
135 pulled: a,
136 source: None,
137 },
138 StreamState {
139 pulled: b,
140 source: None,
141 },
142 ) => a == b,
143 _ => false,
144 }
145 }
146 }
147}
148impl<T: Clone + Eq + Send> Eq for Stream<T> {}
149
150impl<T: Clone + Serialize + Send> Serialize for Stream<T> {
151 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
152 use serde::ser::Error;
153 let lock = self.inner.lock().map_err(Error::custom)?;
154 match &*lock {
155 StreamState {
156 pulled,
157 source: None,
158 } => pulled.serialize(serializer),
159 _ => Err(Error::custom("cannot serialize open stream")),
160 }
161 }
162}
163
164impl<'de, T: Clone + Send + Deserialize<'de>> Deserialize<'de> for Stream<T> {
165 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
166 let data = <Vec<T>>::deserialize(deserializer)?;
167 Ok(Stream::new_closed(data))
168 }
169}
170
171impl<T: Clone + fmt::Debug> fmt::Debug for StreamState<T> {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.debug_struct("StreamState")
174 .field("pulled", &self.pulled)
175 .finish()
176 }
177}
178
/// A reader over a [`Stream`], created by [`Stream::read`].
///
/// Each reader keeps its own position, so several readers of the same stream
/// each observe every value from the beginning.
#[derive(Debug)]
pub struct StreamRead<T: Clone + Send> {
    /// Position in the shared `pulled` buffer of the next value to yield.
    index: usize,
    /// Handle to the shared stream state being read.
    source: Stream<T>,
}
185
186impl<T: Clone + Send> StreamTrait for StreamRead<T> {
187 type Item = T;
188
189 fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
190 let this = self.get_mut();
191 let index = this.index;
192 let mut inner = this.source.inner.lock().unwrap();
193
194 if let Some(v) = inner.pulled.get(index) {
195 this.index += 1;
198 return Poll::Ready(Some(v.clone()));
199 };
200
201 let Some(source) = &mut inner.source else {
202 return Poll::Ready(None);
204 };
205
206 match source.poll_next_unpin(cx) {
207 Poll::Ready(Some(v)) => {
211 this.index += 1;
212 inner.pulled.push(v.clone());
213 Poll::Ready(Some(v))
214 }
215 Poll::Ready(None) => {
218 inner.source.take();
219 Poll::Ready(None)
220 }
221 Poll::Pending => Poll::Pending,
224 }
225 }
226}