use std::{
    collections::VecDeque,
    io::{self, BufRead, Read, Write},
    sync::{Arc, Condvar, Mutex},
    thread,
};

use paths::Utf8PathBuf;
use proc_macro_api::{
    ServerError,
    bidirectional_protocol::msg::{
        BidirectionalMessage, Request as BiRequest, Response as BiResponse, SubRequest,
        SubResponse,
    },
    legacy_protocol::msg::{FlatTree, Message, Request, Response, SpanDataIndexMap},
};
use span::{Edition, EditionedFileId, FileId, Span, SpanAnchor, SyntaxContext, TextRange};
use tt::{Delimiter, DelimiterKind, TopSubtreeBuilder};

/// Shared state for an in-memory byte channel.
///
/// Written bytes are appended to the back of `buffer` and consumed from its
/// front. `closed` is the hang-up flag: once set, writes are rejected and a
/// reader reports end-of-file after the remaining bytes are drained.
#[derive(Default)]
struct ChannelState {
    buffer: VecDeque<u8>,
    closed: bool,
}

/// The mutex-protected channel state, paired with the condition variable that
/// wakes readers blocked waiting for data or for the channel to close.
type InMemoryChannel = Arc<(Mutex<ChannelState>, Condvar)>;

/// Writer end of an in-memory channel.
///
/// Dropping the writer marks the channel closed and wakes any waiting reader.
pub(crate) struct ChannelWriter {
    state: InMemoryChannel,
}

impl Write for ChannelWriter {
    /// Appends the whole of `buf` to the channel and wakes waiting readers.
    ///
    /// # Errors
    /// Returns an [`io::ErrorKind::BrokenPipe`] error if the channel has been
    /// marked closed.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let (lock, cvar) = &*self.state;
        let mut state = lock.lock().unwrap();
        if state.closed {
            return Err(io::Error::new(io::ErrorKind::BrokenPipe, "channel closed"));
        }
        state.buffer.extend(buf);
        cvar.notify_all();
        Ok(buf.len())
    }

    /// No-op: bytes become visible to the reader as soon as `write` returns.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl Drop for ChannelWriter {
    fn drop(&mut self) {
        let (lock, cvar) = &*self.state;
        // Recover from a poisoned mutex instead of unwrapping: this destructor
        // also runs while a failing test unwinds, and panicking again there
        // would abort the whole test process.
        let mut state = lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        state.closed = true;
        cvar.notify_all();
    }
}

/// Reader end of an in-memory channel.
///
/// Reads block until unread bytes are available or the writer has closed the
/// channel. Dropping the reader marks the channel closed, so later writes fail
/// with `BrokenPipe` instead of buffering bytes nobody will ever read.
pub(crate) struct ChannelReader {
    state: InMemoryChannel,
    internal_buf: Vec<u8>,
}

impl ChannelReader {
    /// Locks `channel` and blocks until it holds unread bytes or has been
    /// closed, returning the held lock. An empty buffer on return means EOF.
    ///
    /// Takes the channel rather than `&self` so callers can still mutate
    /// `internal_buf` while the returned guard is alive.
    fn wait_for_data(channel: &InMemoryChannel) -> std::sync::MutexGuard<'_, ChannelState> {
        let (lock, cvar) = &**channel;
        cvar.wait_while(lock.lock().unwrap(), |state| state.buffer.is_empty() && !state.closed)
            .unwrap()
    }
}

impl Read for ChannelReader {
    /// Moves up to `buf.len()` bytes out of the channel, blocking until at
    /// least one byte is available; returns `Ok(0)` once the channel is
    /// closed and fully drained.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut state = Self::wait_for_data(&self.state);
        // After waiting, an empty buffer can only mean the channel is closed,
        // in which case this copies nothing and reports EOF.
        let count = buf.len().min(state.buffer.len());
        for (dst, src) in buf.iter_mut().zip(state.buffer.drain(..count)) {
            *dst = src;
        }
        Ok(count)
    }
}

impl BufRead for ChannelReader {
    /// Returns a snapshot of every byte currently pending in the channel,
    /// blocking until there is at least one (or the channel is closed, in
    /// which case the slice is empty). Nothing is removed until `consume`.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        let state = Self::wait_for_data(&self.state);
        self.internal_buf.clear();
        self.internal_buf.extend(&state.buffer);
        Ok(&self.internal_buf)
    }

    /// Removes up to `amt` bytes from the front of the shared channel buffer.
    fn consume(&mut self, amt: usize) {
        let (lock, _) = &*self.state;
        let mut state = lock.lock().unwrap();
        let to_drain = amt.min(state.buffer.len());
        drop(state.buffer.drain(..to_drain));
    }
}

impl Drop for ChannelReader {
    fn drop(&mut self) {
        let (lock, _) = &*self.state;
        // Signal hang-up to the writer. Poisoning is tolerated so that a drop
        // during unwinding never panics a second time.
        let mut state = lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        state.closed = true;
    }
}

/// Creates a connected pair of channels for bidirectional communication.
fn create_channel_pair() -> (ChannelWriter, ChannelReader, ChannelWriter, ChannelReader) { // Channel for client -> server communication let client_to_server = Arc::new(( Mutex::new(ChannelState { buffer: VecDeque::new(), closed: false }), Condvar::new(), )); let client_writer = ChannelWriter { state: client_to_server.clone() }; let server_reader = ChannelReader { state: client_to_server, internal_buf: Vec::new() }; // Channel for server -> client communication let server_to_client = Arc::new(( Mutex::new(ChannelState { buffer: VecDeque::new(), closed: false }), Condvar::new(), )); let server_writer = ChannelWriter { state: server_to_client.clone() }; let client_reader = ChannelReader { state: server_to_client, internal_buf: Vec::new() }; (client_writer, client_reader, server_writer, server_reader) } pub(crate) fn proc_macro_test_dylib_path() -> Utf8PathBuf { let path = proc_macro_test::PROC_MACRO_TEST_LOCATION; if path.is_empty() { panic!("proc-macro-test dylib not available (requires nightly toolchain)"); } path.into() } /// Creates a simple empty token tree suitable for testing. 
pub(crate) fn create_empty_token_tree( version: u32, span_data_table: &mut SpanDataIndexMap, ) -> FlatTree { let anchor = SpanAnchor { file_id: EditionedFileId::new(FileId::from_raw(0), Edition::CURRENT), ast_id: span::ROOT_ERASED_FILE_AST_ID, }; let span = Span { range: TextRange::empty(0.into()), anchor, ctx: SyntaxContext::root(Edition::CURRENT), }; let builder = TopSubtreeBuilder::new(Delimiter { open: span, close: span, kind: DelimiterKind::Invisible, }); let tt = builder.build(); FlatTree::from_subtree(tt.view(), version, span_data_table) } pub(crate) fn with_server(format: proc_macro_api::ProtocolFormat, test_fn: F) -> R where F: FnOnce(&mut dyn Write, &mut dyn BufRead) -> R, { let (mut client_writer, mut client_reader, mut server_writer, mut server_reader) = create_channel_pair(); let server_handle = thread::spawn(move || { proc_macro_srv_cli::main_loop::run(&mut server_reader, &mut server_writer, format) }); let result = test_fn(&mut client_writer, &mut client_reader); drop(client_writer); match server_handle.join() { Ok(Ok(())) => {} Ok(Err(e)) => { if !matches!( e.kind(), io::ErrorKind::BrokenPipe | io::ErrorKind::UnexpectedEof | io::ErrorKind::InvalidData ) { panic!("Server error: {e}"); } } Err(e) => std::panic::resume_unwind(e), } result } trait TestProtocol { type Request; type Response; fn request(&self, writer: &mut dyn Write, req: Self::Request); fn receive(&self, reader: &mut dyn BufRead, writer: &mut dyn Write) -> Self::Response; } #[allow(dead_code)] struct JsonLegacy; impl TestProtocol for JsonLegacy { type Request = Request; type Response = Response; fn request(&self, writer: &mut dyn Write, req: Request) { req.write(writer).expect("failed to write request"); } fn receive(&self, reader: &mut dyn BufRead, _writer: &mut dyn Write) -> Response { let mut buf = String::new(); Response::read(reader, &mut buf) .expect("failed to read response") .expect("no response received") } } #[allow(dead_code)] struct PostcardBidirectional where F: 
Fn(SubRequest) -> Result, { callback: F, } impl TestProtocol for PostcardBidirectional where F: Fn(SubRequest) -> Result, { type Request = BiRequest; type Response = BiResponse; fn request(&self, writer: &mut dyn Write, req: BiRequest) { let msg = BidirectionalMessage::Request(req); msg.write(writer).expect("failed to write request"); } fn receive(&self, reader: &mut dyn BufRead, writer: &mut dyn Write) -> BiResponse { let mut buf = Vec::new(); loop { let msg = BidirectionalMessage::read(reader, &mut buf) .expect("failed to read message") .expect("no message received"); match msg { BidirectionalMessage::Response(resp) => return resp, BidirectionalMessage::SubRequest(sr) => { let reply = (self.callback)(sr).expect("subrequest callback failed"); let msg = BidirectionalMessage::SubResponse(reply); msg.write(writer).expect("failed to write subresponse"); } other => panic!("unexpected message: {other:?}"), } } } } #[allow(dead_code)] pub(crate) fn request_legacy( writer: &mut dyn Write, reader: &mut dyn BufRead, request: Request, ) -> Response { let protocol = JsonLegacy; protocol.request(writer, request); protocol.receive(reader, writer) } #[allow(dead_code)] pub(crate) fn request_bidirectional( writer: &mut dyn Write, reader: &mut dyn BufRead, request: BiRequest, callback: F, ) -> BiResponse where F: Fn(SubRequest) -> Result, { let protocol = PostcardBidirectional { callback }; protocol.request(writer, request); protocol.receive(reader, writer) }