Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/proc-macro-srv-cli/tests/common/utils.rs')
-rw-r--r--crates/proc-macro-srv-cli/tests/common/utils.rs288
1 files changed, 288 insertions, 0 deletions
diff --git a/crates/proc-macro-srv-cli/tests/common/utils.rs b/crates/proc-macro-srv-cli/tests/common/utils.rs
new file mode 100644
index 0000000000..3049e98004
--- /dev/null
+++ b/crates/proc-macro-srv-cli/tests/common/utils.rs
@@ -0,0 +1,288 @@
+use std::{
+ collections::VecDeque,
+ io::{self, BufRead, Read, Write},
+ sync::{Arc, Condvar, Mutex},
+ thread,
+};
+
+use paths::Utf8PathBuf;
+use proc_macro_api::{
+ ServerError,
+ bidirectional_protocol::msg::{
+ BidirectionalMessage, Request as BiRequest, Response as BiResponse, SubRequest, SubResponse,
+ },
+ legacy_protocol::msg::{FlatTree, Message, Request, Response, SpanDataIndexMap},
+};
+use span::{Edition, EditionedFileId, FileId, Span, SpanAnchor, SyntaxContext, TextRange};
+use tt::{Delimiter, DelimiterKind, TopSubtreeBuilder};
+
+/// Shared state for an in-memory byte channel.
+#[derive(Default)]
+struct ChannelState {
+ buffer: VecDeque<u8>,
+ closed: bool,
+}
+
+type InMemoryChannel = Arc<(Mutex<ChannelState>, Condvar)>;
+
+/// Writer end of an in-memory channel.
+pub(crate) struct ChannelWriter {
+ state: InMemoryChannel,
+}
+
+impl Write for ChannelWriter {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ let (lock, cvar) = &*self.state;
+ let mut state = lock.lock().unwrap();
+ if state.closed {
+ return Err(io::Error::new(io::ErrorKind::BrokenPipe, "channel closed"));
+ }
+ state.buffer.extend(buf);
+ cvar.notify_all();
+ Ok(buf.len())
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ Ok(())
+ }
+}
+
+impl Drop for ChannelWriter {
+ fn drop(&mut self) {
+ let (lock, cvar) = &*self.state;
+ let mut state = lock.lock().unwrap();
+ state.closed = true;
+ cvar.notify_all();
+ }
+}
+
+/// Reader end of an in-memory channel.
+pub(crate) struct ChannelReader {
+ state: InMemoryChannel,
+ internal_buf: Vec<u8>,
+}
+
+impl Read for ChannelReader {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let (lock, cvar) = &*self.state;
+ let mut state = lock.lock().unwrap();
+
+ while state.buffer.is_empty() && !state.closed {
+ state = cvar.wait(state).unwrap();
+ }
+
+ if state.buffer.is_empty() && state.closed {
+ return Ok(0);
+ }
+
+ let to_read = buf.len().min(state.buffer.len());
+ for (dst, src) in buf.iter_mut().zip(state.buffer.drain(..to_read)) {
+ *dst = src;
+ }
+ Ok(to_read)
+ }
+}
+
+impl BufRead for ChannelReader {
+ fn fill_buf(&mut self) -> io::Result<&[u8]> {
+ let (lock, cvar) = &*self.state;
+ let mut state = lock.lock().unwrap();
+
+ while state.buffer.is_empty() && !state.closed {
+ state = cvar.wait(state).unwrap();
+ }
+
+ self.internal_buf.clear();
+ self.internal_buf.extend(&state.buffer);
+ Ok(&self.internal_buf)
+ }
+
+ fn consume(&mut self, amt: usize) {
+ let (lock, _) = &*self.state;
+ let mut state = lock.lock().unwrap();
+ let to_drain = amt.min(state.buffer.len());
+ drop(state.buffer.drain(..to_drain));
+ }
+}
+
+/// Creates a connected pair of channels for bidirectional communication.
+fn create_channel_pair() -> (ChannelWriter, ChannelReader, ChannelWriter, ChannelReader) {
+ // Channel for client -> server communication
+ let client_to_server = Arc::new((
+ Mutex::new(ChannelState { buffer: VecDeque::new(), closed: false }),
+ Condvar::new(),
+ ));
+ let client_writer = ChannelWriter { state: client_to_server.clone() };
+ let server_reader = ChannelReader { state: client_to_server, internal_buf: Vec::new() };
+
+ // Channel for server -> client communication
+ let server_to_client = Arc::new((
+ Mutex::new(ChannelState { buffer: VecDeque::new(), closed: false }),
+ Condvar::new(),
+ ));
+
+ let server_writer = ChannelWriter { state: server_to_client.clone() };
+ let client_reader = ChannelReader { state: server_to_client, internal_buf: Vec::new() };
+
+ (client_writer, client_reader, server_writer, server_reader)
+}
+
+pub(crate) fn proc_macro_test_dylib_path() -> Utf8PathBuf {
+ let path = proc_macro_test::PROC_MACRO_TEST_LOCATION;
+ if path.is_empty() {
+ panic!("proc-macro-test dylib not available (requires nightly toolchain)");
+ }
+ path.into()
+}
+
+/// Creates a simple empty token tree suitable for testing.
+pub(crate) fn create_empty_token_tree(
+ version: u32,
+ span_data_table: &mut SpanDataIndexMap,
+) -> FlatTree {
+ let anchor = SpanAnchor {
+ file_id: EditionedFileId::new(FileId::from_raw(0), Edition::CURRENT),
+ ast_id: span::ROOT_ERASED_FILE_AST_ID,
+ };
+ let span = Span {
+ range: TextRange::empty(0.into()),
+ anchor,
+ ctx: SyntaxContext::root(Edition::CURRENT),
+ };
+
+ let builder = TopSubtreeBuilder::new(Delimiter {
+ open: span,
+ close: span,
+ kind: DelimiterKind::Invisible,
+ });
+ let tt = builder.build();
+
+ FlatTree::from_subtree(tt.view(), version, span_data_table)
+}
+
+pub(crate) fn with_server<F, R>(format: proc_macro_api::ProtocolFormat, test_fn: F) -> R
+where
+ F: FnOnce(&mut dyn Write, &mut dyn BufRead) -> R,
+{
+ let (mut client_writer, mut client_reader, mut server_writer, mut server_reader) =
+ create_channel_pair();
+
+ let server_handle = thread::spawn(move || {
+ proc_macro_srv_cli::main_loop::run(&mut server_reader, &mut server_writer, format)
+ });
+
+ let result = test_fn(&mut client_writer, &mut client_reader);
+
+ drop(client_writer);
+
+ match server_handle.join() {
+ Ok(Ok(())) => {}
+ Ok(Err(e)) => {
+ if !matches!(
+ e.kind(),
+ io::ErrorKind::BrokenPipe
+ | io::ErrorKind::UnexpectedEof
+ | io::ErrorKind::InvalidData
+ ) {
+ panic!("Server error: {e}");
+ }
+ }
+ Err(e) => std::panic::resume_unwind(e),
+ }
+
+ result
+}
+
+trait TestProtocol {
+ type Request;
+ type Response;
+
+ fn request(&self, writer: &mut dyn Write, req: Self::Request);
+ fn receive(&self, reader: &mut dyn BufRead, writer: &mut dyn Write) -> Self::Response;
+}
+
+#[allow(dead_code)]
+struct JsonLegacy;
+
+impl TestProtocol for JsonLegacy {
+ type Request = Request;
+ type Response = Response;
+
+ fn request(&self, writer: &mut dyn Write, req: Request) {
+ req.write(writer).expect("failed to write request");
+ }
+
+ fn receive(&self, reader: &mut dyn BufRead, _writer: &mut dyn Write) -> Response {
+ let mut buf = String::new();
+ Response::read(reader, &mut buf)
+ .expect("failed to read response")
+ .expect("no response received")
+ }
+}
+
+#[allow(dead_code)]
+struct PostcardBidirectional<F>
+where
+ F: Fn(SubRequest) -> Result<SubResponse, ServerError>,
+{
+ callback: F,
+}
+
+impl<F> TestProtocol for PostcardBidirectional<F>
+where
+ F: Fn(SubRequest) -> Result<SubResponse, ServerError>,
+{
+ type Request = BiRequest;
+ type Response = BiResponse;
+
+ fn request(&self, writer: &mut dyn Write, req: BiRequest) {
+ let msg = BidirectionalMessage::Request(req);
+ msg.write(writer).expect("failed to write request");
+ }
+
+ fn receive(&self, reader: &mut dyn BufRead, writer: &mut dyn Write) -> BiResponse {
+ let mut buf = Vec::new();
+
+ loop {
+ let msg = BidirectionalMessage::read(reader, &mut buf)
+ .expect("failed to read message")
+ .expect("no message received");
+
+ match msg {
+ BidirectionalMessage::Response(resp) => return resp,
+ BidirectionalMessage::SubRequest(sr) => {
+ let reply = (self.callback)(sr).expect("subrequest callback failed");
+ let msg = BidirectionalMessage::SubResponse(reply);
+ msg.write(writer).expect("failed to write subresponse");
+ }
+ other => panic!("unexpected message: {other:?}"),
+ }
+ }
+ }
+}
+
+#[allow(dead_code)]
+pub(crate) fn request_legacy(
+ writer: &mut dyn Write,
+ reader: &mut dyn BufRead,
+ request: Request,
+) -> Response {
+ let protocol = JsonLegacy;
+ protocol.request(writer, request);
+ protocol.receive(reader, writer)
+}
+
+#[allow(dead_code)]
+pub(crate) fn request_bidirectional<F>(
+ writer: &mut dyn Write,
+ reader: &mut dyn BufRead,
+ request: BiRequest,
+ callback: F,
+) -> BiResponse
+where
+ F: Fn(SubRequest) -> Result<SubResponse, ServerError>,
+{
+ let protocol = PostcardBidirectional { callback };
+ protocol.request(writer, request);
+ protocol.receive(reader, writer)
+}