Skip to main content

json_rpc/
server.rs

1use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
2
3use serde::{Serialize, Serializer, de::DeserializeOwned};
4use serde_json::value::RawValue;
5
6use crate::{Error, ErrorCode, Id, Request, RequestMessage, Response};
7
8trait MethodHandler<C>: Send + Sync {
9    fn call(&self, ctx: C, params: &RawValue) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>>;
10}
11
/// Adapter that pairs a concrete handler `H` with the types it works on.
///
/// The type parameters are: context `C`, handler `H`, decoded parameters `P`,
/// success value `R`, error value `E` and the future `F` returned by the
/// handler.
struct MethodHandlerImpl<C, H, P, R, E, F> {
    /// The user-supplied callable.
    handler: H,
    /// Records the otherwise unused type parameters. A function-pointer type
    /// is used so that no value of `C`, `P`, `R`, `E` or `F` is treated as
    /// being stored in the struct.
    _phantom: PhantomData<fn(C, P) -> (R, E, F)>,
}
16
17impl<C, P, R, E, F, H> MethodHandler<C> for MethodHandlerImpl<C, H, P, R, E, F>
18where
19    C: Send + 'static,
20    P: DeserializeOwned + Send,
21    R: Serialize + Send,
22    E: Into<Error> + Send,
23    F: Future<Output = Result<R, E>> + Send + 'static,
24    H: Fn(C, P) -> F + Send + Sync,
25{
26    fn call(
27        &self,
28        ctx: C,
29        raw_params: &RawValue,
30    ) -> Pin<Box<dyn Future<Output = Result<Box<RawValue>, Error>> + Send>> {
31        let params: P = match serde_json::from_str(raw_params.get()) {
32            Ok(p) => p,
33            Err(e) => {
34                return Box::pin(async move { Err(Error::invalid_params(e.to_string())) });
35            }
36        };
37        let fut = (self.handler)(ctx, params);
38        Box::pin(async move {
39            match fut.await {
40                Ok(result) => serde_json::value::to_raw_value(&result)
41                    .map_err(|e| Error::new(ErrorCode::INTERNAL_ERROR, e.to_string())),
42                Err(e) => Err(e.into()),
43            }
44        })
45    }
46}
47
48/// The output of [`Server::handle`].
49///
50/// A notification or all-notification batch results in a `None` return
51/// from [`Server::handle`], meaning nothing should be sent back.
52#[derive(Clone, Debug)]
53pub enum ResponseMessage {
54    /// A single response.
55    Single(Response),
56    /// A batch of responses.
57    Batch(Vec<Response>),
58}
59
60impl Serialize for ResponseMessage {
61    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
62        match self {
63            Self::Single(resp) => resp.serialize(serializer),
64            Self::Batch(resps) => resps.serialize(serializer),
65        }
66    }
67}
68
/// Settings that control how a [`Server`] behaves.
///
/// [`ServerConfig::default`] gives the standard setup, in which batch
/// requests are accepted.
///
/// # Example
///
/// ```rust
/// use json_rpc::{Server, ServerConfig, Error};
///
/// let config = ServerConfig {
///     accept_batches: false,
/// };
/// let mut server = Server::<()>::new(config);
/// server.register("ping", |_: (), ()| async move { Ok::<_, Error>("pong".to_string()) });
/// ```
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Controls whether batch requests are processed.
    ///
    /// With `false`, [`Server::handle`] answers every batch request with one
    /// `-32600` Invalid Request error response instead of running its
    /// entries, following the JSON-RPC 2.0 rule that a batch which is not
    /// recognized is answered with a single error.
    ///
    /// Default: `true` (batches are accepted).
    pub accept_batches: bool,
}
95
96impl Default for ServerConfig {
97    fn default() -> Self {
98        Self {
99            accept_batches: true,
100        }
101    }
102}
103
104/// A JSON-RPC 2.0 server.
105///
106/// Generic over a context type `C` that is cloned once per handler invocation.
107///
108/// # Example
109///
110/// ```rust
111/// use json_rpc::{Server, ServerConfig, Error};
112///
113/// let mut server = Server::<()>::new(ServerConfig::default());
114/// server.register("add", |_: (), (a, b): (i64, i64)| async move {
115///     Ok::<_, Error>(a + b)
116/// });
117/// ```
118pub struct Server<C> {
119    config: ServerConfig,
120    methods: HashMap<String, Arc<dyn MethodHandler<C>>>,
121    empty_params: Box<RawValue>,
122}
123
124impl<C: Send + Sync + 'static> Server<C> {
125    /// Creates a new server with the given configuration and no registered methods.
126    ///
127    /// Use [`ServerConfig::default`] to accept batches (the default behavior).
128    pub fn new(config: ServerConfig) -> Self {
129        Self {
130            config,
131            methods: HashMap::new(),
132            empty_params: RawValue::from_string("{}".to_owned()).expect("{} is valid JSON"),
133        }
134    }
135
136    /// Registers an async handler for the given method name.
137    ///
138    /// The handler receives an owned clone of the context and deserialized
139    /// method parameters, and returns a future.
140    pub fn register<P, R, E, F>(
141        &mut self,
142        method: impl Into<String>,
143        handler: impl Fn(C, P) -> F + Send + Sync + 'static,
144    ) where
145        P: DeserializeOwned + Send + 'static,
146        R: Serialize + Send + 'static,
147        E: Into<Error> + Send + 'static,
148        F: Future<Output = Result<R, E>> + Send + 'static,
149    {
150        let entry = MethodHandlerImpl::<C, _, P, R, E, F> {
151            handler,
152            _phantom: PhantomData,
153        };
154        self.methods.insert(method.into(), Arc::new(entry));
155    }
156
157    /// Handles a request message and returns the corresponding response message.
158    ///
159    /// The context `ctx` is consumed and, for batches, cloned once per handler invocation.
160    ///
161    /// Returns `None` when the message was a notification (or an all-notification batch)
162    /// — nothing should be sent back on the transport.
163    ///
164    /// When the server was configured with [`ServerConfig::accept_batches`] set to `false`,
165    /// batch requests are rejected with a single `-32600` Invalid Request error.
166    pub async fn handle(&self, ctx: C, message: RequestMessage) -> Option<ResponseMessage>
167    where
168        C: Clone,
169    {
170        match message {
171            RequestMessage::Single(req) => self.handle_single(ctx, req).await,
172            RequestMessage::Batch(entries) => {
173                if !self.config.accept_batches {
174                    return Some(ResponseMessage::Single(Response::Error {
175                        error: Error::invalid_request("batch requests are not accepted"),
176                        id: Id::Null,
177                    }));
178                }
179                self.handle_batch(ctx, entries).await
180            }
181        }
182    }
183
184    async fn handle_single(&self, ctx: C, req: Request) -> Option<ResponseMessage> {
185        let Some(id) = req.id.into_id() else {
186            let _ = self
187                .dispatch(ctx, &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
188                .await;
189            return None;
190        };
191
192        let params = req.params.as_deref().unwrap_or(&self.empty_params);
193        match self.dispatch(ctx, &req.method, params).await {
194            Ok(result) => Some(ResponseMessage::Single(Response::Success {
195                result,
196                id,
197            })),
198            Err(error) => Some(ResponseMessage::Single(Response::Error {
199                error,
200                id,
201            })),
202        }
203    }
204
205    async fn handle_batch(&self, ctx: C, entries: Vec<Request>) -> Option<ResponseMessage>
206    where
207        C: Clone,
208    {
209        if entries.is_empty() {
210            return Some(ResponseMessage::Single(Response::Error {
211                error: Error::invalid_request("empty batch"),
212                id: Id::Null,
213            }));
214        }
215
216        let mut responses: Vec<Response> = Vec::with_capacity(entries.len());
217
218        for req in entries {
219            let Some(id) = req.id.into_id() else {
220                let _ = self
221                    .dispatch(ctx.clone(), &req.method, req.params.as_deref().unwrap_or(&self.empty_params))
222                    .await;
223                continue;
224            };
225
226            let params = req.params.as_deref().unwrap_or(&self.empty_params);
227            match self.dispatch(ctx.clone(), &req.method, params).await {
228                Ok(result) => responses.push(Response::Success {
229                    result,
230                    id,
231                }),
232                Err(error) => responses.push(Response::Error {
233                    error,
234                    id,
235                }),
236            }
237        }
238
239        if responses.is_empty() {
240            None
241        } else {
242            Some(ResponseMessage::Batch(responses))
243        }
244    }
245
246    async fn dispatch(&self, ctx: C, method: &str, params: &RawValue) -> Result<Box<RawValue>, Error> {
247        let callback = self
248            .methods
249            .get(method)
250            .ok_or_else(|| Error::method_not_found(method))?;
251        callback.call(ctx, params).await
252    }
253}
254
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::{ErrorCode, RequestId};

    /// Builds a request for `method`. Passing `id: None` produces a request
    /// without an id, which the server treats as a notification.
    fn make_request(method: &str, params: Option<&str>, id: Option<i64>) -> Request {
        Request {
            jsonrpc: "2.0".into(),
            method: method.into(),
            params: params.map(|s| RawValue::from_string(s.to_owned()).unwrap()),
            id: RequestId(id.map(Id::Number)),
        }
    }

    /// Server with a single `add` method that sums two positional integers.
    fn adder() -> Server<()> {
        let mut server = Server::new(ServerConfig::default());
        server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });
        server
    }

    /// Server with a single method `name` that counts its invocations in the
    /// shared counter used as context.
    fn counting_server(name: &str) -> Server<Arc<AtomicUsize>> {
        let mut server = Server::new(ServerConfig::default());
        server.register(name, |calls: Arc<AtomicUsize>, _message: (String,)| async move {
            calls.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(())
        });
        server
    }

    /// Unwraps a single success response into its id and raw result.
    #[track_caller]
    fn expect_success(message: ResponseMessage) -> (Id, Box<RawValue>) {
        match message {
            ResponseMessage::Single(Response::Success {
                result,
                id,
            }) => (id, result),
            other => panic!("expected success response, got {other:?}"),
        }
    }

    /// Unwraps a single error response into its id and error.
    #[track_caller]
    fn expect_error(message: ResponseMessage) -> (Id, Error) {
        match message {
            ResponseMessage::Single(Response::Error {
                error,
                id,
            }) => (id, error),
            other => panic!("expected error response, got {other:?}"),
        }
    }

    /// Decodes a raw JSON result as an integer.
    fn as_i64(raw: &RawValue) -> i64 {
        serde_json::from_str(raw.get()).unwrap()
    }

    #[tokio::test]
    async fn test_simple_handler() {
        let server = adder();
        let req = make_request("add", Some("[3, 4]"), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        let (id, result) = expect_success(message);
        assert_eq!(id, Id::Number(1));
        assert_eq!(as_i64(&result), 7);
    }

    #[tokio::test]
    async fn test_handler_with_error() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("div", |_: (), (a, b): (i64, i64)| async move {
            if b == 0 {
                Err(Error::new(-32000, "division by zero"))
            } else {
                Ok(a / b)
            }
        });

        let req = make_request("div", Some("[4, 0]"), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        let (id, error) = expect_error(message);
        assert_eq!(id, Id::Number(1));
        assert_eq!(error.code, -32000);
        assert_eq!(error.message, "division by zero");
    }

    #[tokio::test]
    async fn test_method_not_found() {
        let server: Server<()> = Server::new(ServerConfig::default());
        let req = make_request("unknown", None, Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        let (id, error) = expect_error(message);
        assert_eq!(id, Id::Number(1));
        assert_eq!(error.code, ErrorCode::METHOD_NOT_FOUND);
    }

    #[tokio::test]
    async fn test_unknown_method_notification_is_silent() {
        let server: Server<()> = Server::new(ServerConfig::default());
        let req = make_request("unknown", None, None);

        // Even a failing dispatch must not produce a response for a notification.
        assert!(server.handle((), RequestMessage::Single(req)).await.is_none());
    }

    #[tokio::test]
    async fn test_invalid_params() {
        let server = adder();
        let req = make_request("add", Some(r#""not_an_array""#), Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        let (id, error) = expect_error(message);
        assert_eq!(id, Id::Number(1));
        assert_eq!(error.code, ErrorCode::INVALID_PARAMS);
    }

    #[tokio::test]
    async fn test_missing_params_default_to_empty_object() {
        let mut server: Server<()> = Server::new(ServerConfig::default());
        server.register("count", |_: (), params: HashMap<String, i64>| async move {
            Ok::<_, Error>(params.len())
        });

        let req = make_request("count", None, Some(1));
        let message = server.handle((), RequestMessage::Single(req)).await.unwrap();

        let (id, result) = expect_success(message);
        assert_eq!(id, Id::Number(1));
        assert_eq!(as_i64(&result), 0);
    }

    #[tokio::test]
    async fn test_notification_is_silent() {
        let server = counting_server("log");
        let calls = Arc::new(AtomicUsize::new(0));

        let req = make_request("log", Some(r#"["hello"]"#), None);
        let message = server.handle(Arc::clone(&calls), RequestMessage::Single(req)).await;

        assert!(message.is_none());
        // Silent does not mean skipped: the handler must still have run.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_empty_batch() {
        let server: Server<()> = Server::new(ServerConfig::default());
        let message = server.handle((), RequestMessage::Batch(vec![])).await.unwrap();

        let (id, error) = expect_error(message);
        assert_eq!(id, Id::Null);
        assert_eq!(error.code, ErrorCode::INVALID_REQUEST);
    }

    #[tokio::test]
    async fn test_batch_mixed() {
        let server = adder();
        let entries = vec![
            make_request("add", Some("[1, 2]"), Some(1)),
            make_request("add", Some("[3, 4]"), None),
            make_request("add", Some("[5, 6]"), Some(2)),
        ];

        let message = server.handle((), RequestMessage::Batch(entries)).await.unwrap();

        let responses = match message {
            ResponseMessage::Batch(responses) => responses,
            other => panic!("expected batch response, got {other:?}"),
        };

        // The notification in the middle is dropped; the rest keep request order.
        let sums: Vec<(Id, i64)> = responses
            .into_iter()
            .map(|response| match response {
                Response::Success {
                    result,
                    id,
                } => (id, as_i64(&result)),
                other => panic!("expected success entry, got {other:?}"),
            })
            .collect();
        assert_eq!(sums, vec![(Id::Number(1), 3), (Id::Number(2), 11)]);
    }

    // NOTE(review): the test below is disabled in the original file; the
    // reason is not recorded here. Presumably a batch entry that is not a
    // request object (the bare `42`) cannot be represented or deserialized
    // yet — confirm before re-enabling or deleting it.
    //
    // #[tokio::test]
    // async fn test_batch_with_invalid_entry() {
    //     let mut server: Server<()> = Server::new(ServerConfig::default());
    //     server.register("add", |_: (), (a, b): (i64, i64)| async move { Ok::<_, Error>(a + b) });

    //     let json = r#"[
    //         {"jsonrpc":"2.0","method":"add","params":[1,2],"id":1},
    //         42,
    //         {"jsonrpc":"2.0","method":"add","params":[3,4],"id":2}
    //     ]"#;
    //     let message: RequestMessage = serde_json::from_str(json).unwrap();
    //     let message = server.handle((), message).await.unwrap();

    //     match message {
    //         ResponseMessage::Batch(responses) => {
    //             assert_eq!(responses.len(), 2);
    //             assert!(responses[0].is_success());
    //             assert!(responses[1].is_success());
    //         }
    //         other => panic!("expected batch response, got {other:?}"),
    //     }
    // }

    #[tokio::test]
    async fn test_all_notification_batch_is_empty() {
        let server = counting_server("notify");
        let calls = Arc::new(AtomicUsize::new(0));

        let entries = vec![
            make_request("notify", Some(r#"["a"]"#), None),
            make_request("notify", Some(r#"["b"]"#), None),
        ];

        let message = server.handle(Arc::clone(&calls), RequestMessage::Batch(entries)).await;

        assert!(message.is_none());
        // Both notifications must have reached the handler.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_response_message_to_json_single() {
        let resp = Response::success(Id::Number(1), 42).unwrap();
        let message = ResponseMessage::Single(resp);
        let json = serde_json::to_string(&message).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["result"], serde_json::json!(42));
    }

    #[test]
    fn test_response_message_to_json_batch() {
        let resps = vec![
            Response::success(Id::Number(1), 10).unwrap(),
            Response::success(Id::Number(2), 20).unwrap(),
        ];
        let message = ResponseMessage::Batch(resps);
        let json = serde_json::to_string(&message).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(v.is_array());
        assert_eq!(v.as_array().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_handler_with_context() {
        #[derive(Clone)]
        struct State {
            base: i64,
        }

        let mut server: Server<State> = Server::new(ServerConfig::default());
        server.register("add", |ctx: State, (x,): (i64,)| async move { Ok::<_, Error>(ctx.base + x) });

        let state = State {
            base: 100,
        };
        let req = make_request("add", Some("[5]"), Some(1));
        let message = server.handle(state, RequestMessage::Single(req)).await.unwrap();

        let (_, result) = expect_success(message);
        assert_eq!(as_i64(&result), 105);
    }

    #[tokio::test]
    async fn test_batch_rejected_when_not_accepted() {
        let server: Server<()> = Server::new(ServerConfig {
            accept_batches: false,
        });

        let entries = vec![make_request("add", Some("[1, 2]"), Some(1))];

        let message = server.handle((), RequestMessage::Batch(entries)).await.unwrap();

        let (id, error) = expect_error(message);
        assert_eq!(id, Id::Null);
        assert_eq!(error.code, ErrorCode::INVALID_REQUEST);
        assert!(error.message.contains("batch requests are not accepted"));
    }
}