1use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
2
3use serde::{Serialize, Serializer, de::DeserializeOwned};
4use serde_json::value::RawValue;
5
6use crate::{Error, ErrorCode, Id, Request, RequestMessage, Response};
7
8trait MethodHandler<C>: Send + Sync {
9 fn call(&self, ctx: C, params: &RawValue) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>>;
10}
11
/// Adapts a typed handler closure `H` to the type-erased [`MethodHandler`] interface.
///
/// `C`, `P`, `R`, `E`, and `F` are the context, parameter, result, error, and
/// future types of `H`. They are not stored; they are recorded in `PhantomData`
/// only so a trait impl can name them. `H` is the only stored value.
struct MethodHandlerImpl<C, H, P, R, E, F> {
    /// The user-supplied handler, invoked as `handler(ctx, params)`.
    handler: H,
    // A function-pointer phantom records the type parameters without owning
    // values of those types; fn pointers are always `Send + Sync`, so the
    // struct's auto traits depend only on `H`.
    _phantom: PhantomData<fn(C, P) -> (R, E, F)>,
}
16
17impl<C, P, R, E, F, H> MethodHandler<C> for MethodHandlerImpl<C, H, P, R, E, F>
18where
19 C: Send + 'static,
20 P: DeserializeOwned + Send,
21 R: Serialize + Send,
22 E: Into<Error> + Send,
23 F: Future<Output = Result<R, E>> + Send + 'static,
24 H: Fn(C, P) -> F + Send + Sync,
25{
26 fn call(
27 &self,
28 ctx: C,
29 raw_params: &RawValue,
30 ) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>> {
31 let params: P = match serde_json::from_str(raw_params.get()) {
32 Ok(p) => p,
33 Err(e) => {
34 return Box::pin(async move { Err(Error::invalid_params(e.to_string())) });
35 }
36 };
37 let fut = (self.handler)(ctx, params);
38 Box::pin(async move {
39 match fut.await {
40 Ok(result) => serde_json::value::to_raw_value(&result)
41 .map_err(|e| Error::new(ErrorCode::INTERNAL_ERROR, e.to_string())),
42 Err(e) => Err(e.into()),
43 }
44 })
45 }
46}
47
48#[derive(Clone, Debug)]
53pub enum ResponseMessage {
54 Single(Response),
56 Batch(Vec<Response>),
58}
59
60impl Serialize for ResponseMessage {
61 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
62 match self {
63 Self::Single(resp) => resp.serialize(serializer),
64 Self::Batch(resps) => resps.serialize(serializer),
65 }
66 }
67}
68
/// Options that control how a [`Server`] processes incoming messages.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// When `false`, every batch message is answered with a single
    /// invalid-request error (with a null id) and none of its entries run.
    pub accept_batches: bool,
}
95
96impl Default for ServerConfig {
97 fn default() -> Self {
98 Self {
99 accept_batches: true,
100 }
101 }
102}
103
104pub struct Server<C> {
119 config: ServerConfig,
120 methods: HashMap<String, Arc<dyn MethodHandler<C>>>,
121 empty_params: Box<RawValue>,
122}
123
124impl<C: Send + Sync + 'static> Server<C> {
125 pub fn new(config: ServerConfig) -> Self {
129 Self {
130 config,
131 methods: HashMap::new(),
132 empty_params: RawValue::from_string("{}".to_owned()).expect("{} is valid JSON"),
133 }
134 }
135
136 pub fn register<P, R, E, F>(
141 &mut self,
142 method: impl Into<String>,
143 handler: impl Fn(C, P) -> F + Send + Sync + 'static,
144 ) where
145 P: DeserializeOwned + Send + 'static,
146 R: Serialize + Send + 'static,
147 E: Into<Error> + Send + 'static,
148 F: Future<Output = Result<R, E>> + Send + 'static,
149 {
150 let entry = MethodHandlerImpl::<C, _, P, R, E, F> {
151 handler,
152 _phantom: PhantomData,
153 };
154 self.methods.insert(method.into(), Arc::new(entry));
155 }
156
157 pub async fn handle(&self, ctx: C, message: RequestMessage) -> Option<ResponseMessage>
167 where
168 C: Clone,
169 {
170 match message {
171 RequestMessage::Single(req) => self.handle_single(ctx, req).await,
172 RequestMessage::Batch(entries) => {
173 if !self.config.accept_batches {
174 return Some(ResponseMessage::Single(Response::Error {
175 error: Error::invalid_request("batch requests are not accepted"),
176 id: Id::Null,
177 }));
178 }
179 self.handle_batch(ctx, entries).await
180 }
181 }
182 }
183
184 async fn handle_single(&self, ctx: C, req: Request) -> Option<ResponseMessage> {
185 let Some(id) = req.id.into_id() else {
186 let _ = self
187 .dispatch(ctx, &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
188 .await;
189 return None;
190 };
191
192 let params = req.params.as_deref().unwrap_or(&self.empty_params);
193 match self.dispatch(ctx, &req.method, params).await {
194 Ok(result) => Some(ResponseMessage::Single(Response::Success {
195 result,
196 id,
197 })),
198 Err(error) => Some(ResponseMessage::Single(Response::Error {
199 error,
200 id,
201 })),
202 }
203 }
204
205 async fn handle_batch(&self, ctx: C, entries: Vec<Request>) -> Option<ResponseMessage>
206 where
207 C: Clone,
208 {
209 if entries.is_empty() {
210 return Some(ResponseMessage::Single(Response::Error {
211 error: Error::invalid_request("empty batch"),
212 id: Id::Null,
213 }));
214 }
215
216 let mut responses: Vec<Response> = Vec::with_capacity(entries.len());
217
218 for req in entries {
219 let Some(id) = req.id.into_id() else {
220 let _ = self
221 .dispatch(ctx.clone(), &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
222 .await;
223 continue;
224 };
225
226 let params = req.params.as_deref().unwrap_or(&self.empty_params);
227 match self.dispatch(ctx.clone(), &req.method, params).await {
228 Ok(result) => responses.push(Response::Success {
229 result,
230 id,
231 }),
232 Err(error) => responses.push(Response::Error {
233 error,
234 id,
235 }),
236 }
237 }
238
239 if responses.is_empty() {
240 None
241 } else {
242 Some(ResponseMessage::Batch(responses))
243 }
244 }
245
246 async fn dispatch(&self, ctx: C, method: &str, params: &RawValue) -> Result<Box<RawValue>, Error> {
247 let callback = self
248 .methods
249 .get(method)
250 .ok_or_else(|| Error::method_not_found(method))?;
251 callback.call(ctx, params).await
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ErrorCode, RequestId};

    /// Builds a JSON-RPC 2.0 request; passing `id: None` yields a notification.
    fn make_request(method: &str, params: Option<&str>, id: Option<i64>) -> Request {
        Request {
            jsonrpc: "2.0".into(),
            method: method.into(),
            params: params.map(|s| RawValue::from_string(s.to_owned()).unwrap()),
            id: RequestId(id.map(Id::Number)),
        }
    }

    #[tokio::test]
    async fn test_simple_handler() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

        let req = make_request("add", Some("[3, 4]"), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Success {
                result,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                let v: i64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(v, 7);
            }
            other => panic!("expected success response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_handler_with_error() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("div", |_: (), (a, b): (i64, i64)| async move {
            if b == 0 {
                Err(Error::new(-32000, "division by zero"))
            } else {
                Ok(a / b)
            }
        });

        let req = make_request("div", Some("[4, 0]"), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, -32000);
                assert_eq!(error.message, "division by zero");
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_method_not_found() {
        let server: Server<()> = Server::new(ServerConfig::default());
        let req = make_request("unknown", None, Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, ErrorCode::METHOD_NOT_FOUND);
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_invalid_params() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

        let req = make_request("add", Some(r#""not_an_array""#), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                assert_eq!(error.code, ErrorCode::INVALID_PARAMS);
            }
            other => panic!("expected error response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_notification_is_silent() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("log", |_: (), _message: (String,)| async move { Ok::<_, Error>(()) });

        let req = make_request("log", Some(r#"["hello"]"#), None);
        let message = server.handle((), RequestMessage::Single(req)).await;

        assert!(message.is_none());
    }

    #[tokio::test]
    async fn test_notification_to_unknown_method_is_silent() {
        let server: Server<()> = Server::new(ServerConfig::default());
        let req = make_request("unknown", None, None);
        let message = server.handle((), RequestMessage::Single(req)).await;

        // A notification never gets a reply, even when dispatch fails.
        assert!(message.is_none());
    }

    #[tokio::test]
    async fn test_omitted_params_are_passed_as_empty_object() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("echo", |_: (), params: serde_json::Value| async move {
            Ok::<_, Error>(params)
        });

        let req = make_request("echo", None, Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Success {
                result,
                id,
            }) => {
                assert_eq!(id, Id::Number(1));
                let v: serde_json::Value = serde_json::from_str(result.get()).unwrap();
                assert_eq!(v, serde_json::json!({}));
            }
            other => panic!("expected success response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_empty_batch() {
        let server: Server<()> = Server::new(ServerConfig::default());
        let message = server.handle((), RequestMessage::Batch(vec![])).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Null);
                assert_eq!(error.code, ErrorCode::INVALID_REQUEST);
            }
            other => panic!("expected single error for empty batch, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_batch_mixed() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

        let entries = vec![
            make_request("add", Some("[1, 2]"), Some(1)),
            make_request("add", Some("[3, 4]"), None),
            make_request("add", Some("[5, 6]"), Some(2)),
        ];

        let message = server.handle((), RequestMessage::Batch(entries)).await.unwrap();

        match message {
            ResponseMessage::Batch(responses) => {
                // The notification (middle entry) yields no response, and the
                // remaining responses keep the order of their requests.
                assert_eq!(responses.len(), 2);
                let expected = [(Id::Number(1), 3_i64), (Id::Number(2), 11_i64)];
                for (response, (expected_id, expected_value)) in responses.into_iter().zip(expected) {
                    match response {
                        Response::Success {
                            result,
                            id,
                        } => {
                            assert_eq!(id, expected_id);
                            let v: i64 = serde_json::from_str(result.get()).unwrap();
                            assert_eq!(v, expected_value);
                        }
                        other => panic!("expected success response, got {other:?}"),
                    }
                }
            }
            other => panic!("expected batch response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_all_notification_batch_is_empty() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("notify", |_: (), _message: (String,)| async move { Ok::<_, Error>(()) });

        let entries = vec![
            make_request("notify", Some(r#"["a"]"#), None),
            make_request("notify", Some(r#"["b"]"#), None),
        ];

        let message = server.handle((), RequestMessage::Batch(entries)).await;
        assert!(message.is_none());
    }

    #[test]
    fn test_response_message_to_json_single() {
        let resp = Response::success(Id::Number(1), 42).unwrap();
        let message = ResponseMessage::Single(resp);
        let json = serde_json::to_string(&message).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["result"], serde_json::json!(42));
    }

    #[test]
    fn test_response_message_to_json_batch() {
        let resps = vec![
            Response::success(Id::Number(1), 10).unwrap(),
            Response::success(Id::Number(2), 20).unwrap(),
        ];
        let message = ResponseMessage::Batch(resps);
        let json = serde_json::to_string(&message).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(v.is_array());
        assert_eq!(v.as_array().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_handler_with_context() {
        #[derive(Clone)]
        struct State {
            base: i64,
        }

        let mut server: Server<State> = Server::new(ServerConfig::default());
        server.register("add", |ctx: State, (x,): (i64,)| async move { Ok::<_, Error>(ctx.base + x) });

        let state = State {
            base: 100,
        };
        let req = make_request("add", Some("[5]"), Some(1));
        let message = server.handle(state, RequestMessage::Single(req)).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Success {
                result, ..
            }) => {
                let v: i64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(v, 105);
            }
            other => panic!("expected success, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_batch_rejected_when_not_accepted() {
        let server: Server<()> = Server::new(ServerConfig {
            accept_batches: false,
        });

        let entries = vec![make_request("add", Some("[1, 2]"), Some(1))];

        let message = server.handle((), RequestMessage::Batch(entries)).await.unwrap();

        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => {
                assert_eq!(id, Id::Null);
                assert_eq!(error.code, ErrorCode::INVALID_REQUEST);
                assert!(error.message.contains("batch requests are not accepted"));
            }
            other => panic!("expected single error for rejected batch, got {other:?}"),
        }
    }
}