From e25d726da42e1542a5d7a8fb8d286d48fd892a29 Mon Sep 17 00:00:00 2001 From: Prospector <6166773+Prospector@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:54:38 -0700 Subject: [PATCH] Revert "Implement a more robust IPC system between the launcher and client (#4159)" This reverts commit 5ffcc48d75a94f93449e96b7d7ee65f2ef07eca8. --- .../src/api/oauth_utils/auth_code_reply.rs | 20 +- packages/app-lib/build.rs | 1 + packages/app-lib/java/build.gradle.kts | 1 - .../com/modrinth/theseus/MinecraftLaunch.java | 48 +++- .../com/modrinth/theseus/rpc/RpcHandlers.java | 46 ---- .../theseus/rpc/RpcMethodException.java | 9 - .../com/modrinth/theseus/rpc/TheseusRpc.java | 183 ------------- packages/app-lib/src/api/mod.rs | 5 +- packages/app-lib/src/error.rs | 3 - packages/app-lib/src/launcher/args.rs | 7 - packages/app-lib/src/launcher/mod.rs | 19 +- packages/app-lib/src/state/process.rs | 19 +- packages/app-lib/src/util/mod.rs | 2 - packages/app-lib/src/util/network.rs | 17 -- packages/app-lib/src/util/rpc.rs | 258 ------------------ 15 files changed, 70 insertions(+), 568 deletions(-) delete mode 100644 packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcHandlers.java delete mode 100644 packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcMethodException.java delete mode 100644 packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/TheseusRpc.java delete mode 100644 packages/app-lib/src/util/network.rs delete mode 100644 packages/app-lib/src/util/rpc.rs diff --git a/apps/app/src/api/oauth_utils/auth_code_reply.rs b/apps/app/src/api/oauth_utils/auth_code_reply.rs index fedffcb0..4e4a5292 100644 --- a/apps/app/src/api/oauth_utils/auth_code_reply.rs +++ b/apps/app/src/api/oauth_utils/auth_code_reply.rs @@ -11,7 +11,7 @@ //! [RFC 8252]: https://datatracker.ietf.org/doc/html/rfc8252 use std::{ - net::SocketAddr, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::{LazyLock, Mutex}, time::Duration, }; @@ -19,8 +19,10 @@ use std::{ use hyper::body::Incoming; use hyper_util::rt::{TokioIo, TokioTimer}; use theseus::ErrorKind; -use theseus::prelude::tcp_listen_any_loopback; -use tokio::sync::{broadcast, oneshot}; +use tokio::{ + net::TcpListener, + sync::{broadcast, oneshot}, +}; static SERVER_SHUTDOWN: LazyLock> = LazyLock::new(|| broadcast::channel(1024).0); @@ -33,7 +35,17 @@ static SERVER_SHUTDOWN: LazyLock> = pub async fn listen( listen_socket_tx: oneshot::Sender>, ) -> Result, theseus::Error> { - let listener = match tcp_listen_any_loopback().await { + // IPv4 is tried first for the best compatibility and performance with most systems. + // IPv6 is also tried in case IPv4 is not available. Resolving "localhost" is avoided + // to prevent failures deriving from improper name resolution setup. Any available + // ephemeral port is used to prevent conflicts with other services. This is all as per + // RFC 8252's recommendations + const ANY_LOOPBACK_SOCKET: &[SocketAddr] = &[ + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), + ]; + + let listener = match TcpListener::bind(ANY_LOOPBACK_SOCKET).await { Ok(listener) => { listen_socket_tx .send(listener.local_addr().map_err(|e| { diff --git a/packages/app-lib/build.rs b/packages/app-lib/build.rs index 10ed29b9..48da4b45 100644 --- a/packages/app-lib/build.rs +++ b/packages/app-lib/build.rs @@ -53,6 +53,7 @@ fn build_java_jars() { .arg("build") .arg("--no-daemon") .arg("--console=rich") + .arg("--info") .current_dir(dunce::canonicalize("java").unwrap()) .status() .expect("Failed to wait on Gradle build"); diff --git a/packages/app-lib/java/build.gradle.kts b/packages/app-lib/java/build.gradle.kts index 98c95c8c..a671dd6f 100644 --- a/packages/app-lib/java/build.gradle.kts +++ b/packages/app-lib/java/build.gradle.kts @@ -11,7 +11,6 @@ repositories { dependencies { implementation("org.ow2.asm:asm:9.8") implementation("org.ow2.asm:asm-tree:9.8") - implementation("com.google.code.gson:gson:2.13.1") testImplementation(libs.junit.jupiter) testRuntimeOnly("org.junit.platform:junit-platform-launcher") diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java index b474ba02..9d61a0c0 100644 --- a/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java +++ b/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java @@ -1,13 +1,11 @@ package com.modrinth.theseus; -import com.modrinth.theseus.rpc.RpcHandlers; -import com.modrinth.theseus.rpc.TheseusRpc; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; -import java.util.concurrent.CompletableFuture; public final class MinecraftLaunch { public static void main(String[] args) throws IOException, ReflectiveOperationException { @@ -15,19 +13,45 @@ public final class MinecraftLaunch { final String[] gameArgs = Arrays.copyOfRange(args, 1, args.length); System.setProperty("modrinth.process.args", String.join("\u001f", gameArgs)); + parseInput(); - final CompletableFuture waitForLaunch = new CompletableFuture<>(); - TheseusRpc.connectAndStart( - System.getProperty("modrinth.internal.ipc.host"), - Integer.getInteger("modrinth.internal.ipc.port"), - new RpcHandlers() - .handler("set_system_property", String.class, String.class, System::setProperty) - .handler("launch", () -> waitForLaunch.complete(null))); - - waitForLaunch.join(); relaunch(mainClass, gameArgs); } + private static void parseInput() throws IOException { + final ByteArrayOutputStream line = new ByteArrayOutputStream(); + while (true) { + final int b = System.in.read(); + if (b < 0) { + throw new IllegalStateException("Stdin terminated while parsing"); + } + if (b != '\n') { + line.write(b); + continue; + } + if (handleLine(line.toString("UTF-8"))) { + break; + } + line.reset(); + } + } + + private static boolean handleLine(String line) { + final String[] parts = line.split("\t", 2); + switch (parts[0]) { + case "property": { + final String[] keyValue = parts[1].split("\t", 2); + System.setProperty(keyValue[0], keyValue[1]); + return false; + } + case "launch": + return true; + } + + System.err.println("Unknown input line " + line); + return false; + } + private static void relaunch(String mainClassName, String[] args) throws ReflectiveOperationException { final int javaVersion = getJavaVersion(); final Class mainClass = Class.forName(mainClassName); diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcHandlers.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcHandlers.java deleted file mode 100644 index 257148ef..00000000 --- a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcHandlers.java +++ /dev/null @@ -1,46 +0,0 @@ -package com.modrinth.theseus.rpc; - -import com.google.gson.JsonElement; -import com.google.gson.JsonNull; -import java.util.HashMap; -import java.util.Map; -import java.util.function.BiConsumer; -import java.util.function.Function; - -public class RpcHandlers { - private final Map> handlers = new HashMap<>(); - private boolean frozen; - - public RpcHandlers handler(String functionName, Runnable handler) { - return addHandler(functionName, args -> { - handler.run(); - return JsonNull.INSTANCE; - }); - } - - public RpcHandlers handler( - String functionName, Class arg1Type, Class arg2Type, BiConsumer handler) { - return addHandler(functionName, args -> { - if (args.length != 2) { - throw new IllegalArgumentException(functionName + " expected 2 arguments"); - } - final A arg1 = TheseusRpc.GSON.fromJson(args[0], arg1Type); - final B arg2 = TheseusRpc.GSON.fromJson(args[1], arg2Type); - handler.accept(arg1, arg2); - return JsonNull.INSTANCE; - }); - } - - private RpcHandlers addHandler(String functionName, Function handler) { - if (frozen) { - throw new IllegalStateException("Cannot add handler to frozen RpcHandlers instance"); - } - handlers.put(functionName, handler); - return this; - } - - Map> build() { - frozen = true; - return handlers; - } -} diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcMethodException.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcMethodException.java deleted file mode 100644 index f9ab75a3..00000000 --- a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcMethodException.java +++ /dev/null @@ -1,9 +0,0 @@ -package com.modrinth.theseus.rpc; - -public class RpcMethodException extends RuntimeException { - private static final long serialVersionUID = 1922360184188807964L; - - public RpcMethodException(String message) { - super(message); - } -} diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/TheseusRpc.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/TheseusRpc.java deleted file mode 100644 index ff460ff8..00000000 --- a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/TheseusRpc.java +++ /dev/null @@ -1,183 +0,0 @@ -package com.modrinth.theseus.rpc; - -import com.google.gson.*; -import com.google.gson.reflect.TypeToken; -import java.io.*; -import java.net.Socket; -import java.nio.charset.StandardCharsets; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -public final class TheseusRpc { - static final Gson GSON = new GsonBuilder() - .setStrictness(Strictness.STRICT) - .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) - .disableHtmlEscaping() - .create(); - private static final TypeToken MESSAGE_TYPE = TypeToken.get(RpcMessage.class); - - private static final AtomicReference RPC = new AtomicReference<>(); - - private final BlockingQueue mainThreadQueue = new LinkedBlockingQueue<>(); - private final Map> awaitingResponse = new ConcurrentHashMap<>(); - private final Map> handlers; - private final Socket socket; - - private TheseusRpc(Socket socket, RpcHandlers handlers) { - this.socket = socket; - this.handlers = handlers.build(); - } - - public static void connectAndStart(String host, int port, RpcHandlers handlers) throws IOException { - if (RPC.get() != null) { - throw new IllegalStateException("Can only connect to RPC once"); - } - - final Socket socket = new Socket(host, port); - final TheseusRpc rpc = new TheseusRpc(socket, handlers); - final Thread mainThread = new Thread(rpc::mainThread, "Theseus RPC Main"); - final Thread readThread = new Thread(rpc::readThread, "Theseus RPC Read"); - mainThread.setDaemon(true); - readThread.setDaemon(true); - mainThread.start(); - readThread.start(); - RPC.set(rpc); - } - - public static TheseusRpc getRpc() { - final TheseusRpc rpc = RPC.get(); - if (rpc == null) { - throw new IllegalStateException("Called getRpc before RPC initialized"); - } - return rpc; - } - - public CompletableFuture callMethod(TypeToken returnType, String method, Object... args) { - final JsonElement[] jsonArgs = new JsonElement[args.length]; - for (int i = 0; i < args.length; i++) { - jsonArgs[i] = GSON.toJsonTree(args[i]); - } - - final RpcMessage message = new RpcMessage(method, jsonArgs); - final ResponseWaiter responseWaiter = new ResponseWaiter<>(returnType); - awaitingResponse.put(message.id, responseWaiter); - mainThreadQueue.add(message); - return responseWaiter.future; - } - - private void mainThread() { - try { - final Writer writer = new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8); - while (true) { - final RpcMessage message = mainThreadQueue.take(); - final RpcMessage toSend; - if (message.isForSending) { - toSend = message; - } else { - final Function handler = handlers.get(message.method); - if (handler == null) { - System.err.println("Unknown theseus RPC method " + message.method); - continue; - } - RpcMessage response; - try { - response = new RpcMessage(message.id, handler.apply(message.args)); - } catch (Exception e) { - response = new RpcMessage(message.id, e.toString()); - } - toSend = response; - } - GSON.toJson(toSend, writer); - writer.write('\n'); - writer.flush(); - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } catch (InterruptedException ignored) { - } - } - - private void readThread() { - try { - final BufferedReader reader = - new BufferedReader(new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); - while (true) { - final RpcMessage message = GSON.fromJson(reader.readLine(), MESSAGE_TYPE); - if (message.method == null) { - final ResponseWaiter waiter = awaitingResponse.get(message.id); - if (waiter != null) { - handleResponse(waiter, message); - } - } else { - mainThreadQueue.put(message); - } - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } catch (InterruptedException ignored) { - } - } - - private void handleResponse(ResponseWaiter waiter, RpcMessage message) { - if (message.error != null) { - waiter.future.completeExceptionally(new RpcMethodException(message.error)); - return; - } - try { - waiter.future.complete(GSON.fromJson(message.response, waiter.type)); - } catch (JsonSyntaxException e) { - waiter.future.completeExceptionally(e); - } - } - - private static class RpcMessage { - final UUID id; - final String method; // Optional - final JsonElement[] args; // Optional - final JsonElement response; // Optional - final String error; // Optional - final transient boolean isForSending; - - RpcMessage(String method, JsonElement[] args) { - id = UUID.randomUUID(); - this.method = method; - this.args = args; - response = null; - error = null; - isForSending = true; - } - - RpcMessage(UUID id, JsonElement response) { - this.id = id; - method = null; - args = null; - this.response = response; - error = null; - isForSending = true; - } - - RpcMessage(UUID id, String error) { - this.id = id; - method = null; - args = null; - response = null; - this.error = error; - isForSending = true; - } - } - - private static class ResponseWaiter { - final TypeToken type; - final CompletableFuture future = new CompletableFuture<>(); - - ResponseWaiter(TypeToken type) { - this.type = type; - } - } -} diff --git a/packages/app-lib/src/api/mod.rs b/packages/app-lib/src/api/mod.rs index 020afbe4..b173d035 100644 --- a/packages/app-lib/src/api/mod.rs +++ b/packages/app-lib/src/api/mod.rs @@ -35,9 +35,6 @@ pub mod prelude { jre, metadata, minecraft_auth, mr_auth, pack, process, profile::{self, Profile, create}, settings, - util::{ - io::{IOError, canonicalize}, - network::tcp_listen_any_loopback, - }, + util::io::{IOError, canonicalize}, }; } diff --git a/packages/app-lib/src/error.rs b/packages/app-lib/src/error.rs index 773d55da..75c144f5 100644 --- a/packages/app-lib/src/error.rs +++ b/packages/app-lib/src/error.rs @@ -151,9 +151,6 @@ pub enum ErrorKind { "A skin texture must have a dimension of either 64x64 or 64x32 pixels" )] InvalidSkinTexture, - - #[error("RPC error: {0}")] - RpcError(String), } #[derive(Debug)] diff --git a/packages/app-lib/src/launcher/args.rs b/packages/app-lib/src/launcher/args.rs index e2093f61..350d67c0 100644 --- a/packages/app-lib/src/launcher/args.rs +++ b/packages/app-lib/src/launcher/args.rs @@ -16,7 +16,6 @@ use daedalus::{ use dunce::canonicalize; use hashlink::LinkedHashSet; use std::io::{BufRead, BufReader}; -use std::net::SocketAddr; use std::{collections::HashMap, path::Path}; use uuid::Uuid; @@ -125,7 +124,6 @@ pub fn get_jvm_arguments( quick_play_type: &QuickPlayType, quick_play_version: QuickPlayVersion, log_config: Option<&LoggingConfiguration>, - ipc_addr: SocketAddr, ) -> crate::Result> { let mut parsed_arguments = Vec::new(); @@ -183,11 +181,6 @@ pub fn get_jvm_arguments( .to_string_lossy() )); - parsed_arguments - .push(format!("-Dmodrinth.internal.ipc.host={}", ipc_addr.ip())); - parsed_arguments - .push(format!("-Dmodrinth.internal.ipc.port={}", ipc_addr.port())); - parsed_arguments.push(format!( "-Dmodrinth.internal.quickPlay.serverVersion={}", serde_json::to_value(quick_play_version.server)? diff --git a/packages/app-lib/src/launcher/mod.rs b/packages/app-lib/src/launcher/mod.rs index 1b7a7d7e..64eb1d90 100644 --- a/packages/app-lib/src/launcher/mod.rs +++ b/packages/app-lib/src/launcher/mod.rs @@ -12,7 +12,6 @@ use crate::state::{ Credentials, JavaVersion, ProcessMetadata, ProfileInstallStage, }; use crate::util::io; -use crate::util::rpc::RpcServerBuilder; use crate::{State, get_resource_file, process, state as st}; use chrono::Utc; use daedalus as d; @@ -23,6 +22,7 @@ use serde::Deserialize; use st::Profile; use std::fmt::Write; use std::path::PathBuf; +use tokio::io::AsyncWriteExt; use tokio::process::Command; mod args; @@ -608,8 +608,6 @@ pub async fn launch_minecraft( let (main_class_keep_alive, main_class_path) = get_resource_file!(env "JAVA_JARS_DIR" / "theseus.jar")?; - let rpc_server = RpcServerBuilder::new().launch().await?; - command.args( args::get_jvm_arguments( args.get(&d::minecraft::ArgumentType::Jvm) @@ -635,7 +633,6 @@ pub async fn launch_minecraft( .logging .as_ref() .and_then(|x| x.get(&LoggingSide::Client)), - rpc_server.address(), )? .into_iter(), ); @@ -770,8 +767,7 @@ pub async fn launch_minecraft( state.directories.profile_logs_dir(&profile.path), version_info.logging.is_some(), main_class_keep_alive, - rpc_server, - async |process: &ProcessMetadata, rpc_server| { + async |process: &ProcessMetadata, stdin| { let process_start_time = process.start_time.to_rfc3339(); let profile_created_time = profile.created.to_rfc3339(); let profile_modified_time = profile.modified.to_rfc3339(); @@ -794,11 +790,14 @@ pub async fn launch_minecraft( let Some(value) = value else { continue; }; - rpc_server - .call_method_2::<()>("set_system_property", key, value) - .await?; + stdin.write_all(b"property\t").await?; + stdin.write_all(key.as_bytes()).await?; + stdin.write_u8(b'\t').await?; + stdin.write_all(value.as_bytes()).await?; + stdin.write_u8(b'\n').await?; } - rpc_server.call_method::<()>("launch").await?; + stdin.write_all(b"launch\n").await?; + stdin.flush().await?; Ok(()) }, ) diff --git a/packages/app-lib/src/state/process.rs b/packages/app-lib/src/state/process.rs index 4cff0a33..faf1c9b4 100644 --- a/packages/app-lib/src/state/process.rs +++ b/packages/app-lib/src/state/process.rs @@ -2,7 +2,6 @@ use crate::event::emit::{emit_process, emit_profile}; use crate::event::{ProcessPayloadType, ProfilePayloadType}; use crate::profile; use crate::util::io::IOError; -use crate::util::rpc::RpcServer; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; use dashmap::DashMap; use quick_xml::Reader; @@ -16,7 +15,7 @@ use std::path::{Path, PathBuf}; use std::process::ExitStatus; use tempfile::TempDir; use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::process::{Child, Command}; +use tokio::process::{Child, ChildStdin, Command}; use uuid::Uuid; const LAUNCHER_LOG_PATH: &str = "launcher_log.txt"; @@ -47,10 +46,9 @@ impl ProcessManager { logs_folder: PathBuf, xml_logging: bool, main_class_keep_alive: TempDir, - rpc_server: RpcServer, post_process_init: impl AsyncFnOnce( &ProcessMetadata, - &RpcServer, + &mut ChildStdin, ) -> crate::Result<()>, ) -> crate::Result { mc_command.stdout(std::process::Stdio::piped()); @@ -69,12 +67,14 @@ impl ProcessManager { profile_path: profile_path.to_string(), }, child: mc_proc, - rpc_server, _main_class_keep_alive: main_class_keep_alive, }; - if let Err(e) = - post_process_init(&process.metadata, &process.rpc_server).await + if let Err(e) = post_process_init( + &process.metadata, + &mut process.child.stdin.as_mut().unwrap(), + ) + .await { tracing::error!("Failed to run post-process init: {e}"); let _ = process.child.kill().await; @@ -165,10 +165,6 @@ impl ProcessManager { self.processes.get(&id).map(|x| x.metadata.clone()) } - pub fn get_rpc(&self, id: Uuid) -> Option { - self.processes.get(&id).map(|x| x.rpc_server.clone()) - } - pub fn get_all(&self) -> Vec { self.processes .iter() @@ -219,7 +215,6 @@ struct Process { metadata: ProcessMetadata, child: Child, _main_class_keep_alive: TempDir, - rpc_server: RpcServer, } #[derive(Debug, Default)] diff --git a/packages/app-lib/src/util/mod.rs b/packages/app-lib/src/util/mod.rs index 7656b4a0..67c5ede1 100644 --- a/packages/app-lib/src/util/mod.rs +++ b/packages/app-lib/src/util/mod.rs @@ -2,8 +2,6 @@ pub mod fetch; pub mod io; pub mod jre; -pub mod network; pub mod platform; pub mod protocol_version; -pub mod rpc; pub mod server_ping; diff --git a/packages/app-lib/src/util/network.rs b/packages/app-lib/src/util/network.rs deleted file mode 100644 index 2837516c..00000000 --- a/packages/app-lib/src/util/network.rs +++ /dev/null @@ -1,17 +0,0 @@ -use std::io; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use tokio::net::TcpListener; - -pub async fn tcp_listen_any_loopback() -> io::Result { - // IPv4 is tried first for the best compatibility and performance with most systems. - // IPv6 is also tried in case IPv4 is not available. Resolving "localhost" is avoided - // to prevent failures deriving from improper name resolution setup. Any available - // ephemeral port is used to prevent conflicts with other services. This is all as per - // RFC 8252's recommendations - const ANY_LOOPBACK_SOCKET: &[SocketAddr] = &[ - SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), - SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), - ]; - - TcpListener::bind(ANY_LOOPBACK_SOCKET).await -} diff --git a/packages/app-lib/src/util/rpc.rs b/packages/app-lib/src/util/rpc.rs deleted file mode 100644 index d6902bd8..00000000 --- a/packages/app-lib/src/util/rpc.rs +++ /dev/null @@ -1,258 +0,0 @@ -use crate::prelude::tcp_listen_any_loopback; -use crate::{ErrorKind, Result}; -use futures::{SinkExt, StreamExt}; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use tokio::net::TcpListener; -use tokio::sync::{mpsc, oneshot}; -use tokio::task::AbortHandle; -use tokio_util::codec::{Decoder, LinesCodec, LinesCodecError}; -use uuid::Uuid; - -type HandlerFuture = Pin>>>; -type HandlerMethod = Box) -> HandlerFuture>; -type HandlerMap = HashMap<&'static str, HandlerMethod>; -type WaitingResponsesMap = - Arc>>>>; - -pub struct RpcServerBuilder { - handlers: HandlerMap, -} - -impl RpcServerBuilder { - pub fn new() -> Self { - Self { - handlers: HashMap::new(), - } - } - - // We'll use this function in the future. Please remove this #[allow] when we do. - #[allow(dead_code)] - pub fn handler( - mut self, - function_name: &'static str, - handler: HandlerMethod, - ) -> Self { - self.handlers.insert(function_name, Box::new(handler)); - self - } - - pub async fn launch(self) -> Result { - let socket = tcp_listen_any_loopback().await?; - let address = socket.local_addr()?; - let (message_sender, message_receiver) = mpsc::unbounded_channel(); - let waiting_responses = Arc::new(Mutex::new(HashMap::new())); - - let join_handle = { - let waiting_responses = waiting_responses.clone(); - tokio::spawn(async move { - let mut server = RunningRpcServer { - message_receiver, - handlers: self.handlers, - waiting_responses: waiting_responses.clone(), - }; - if let Err(e) = server.run(socket).await { - tracing::error!("Failed to run RPC server: {e}"); - } - waiting_responses.lock().unwrap().clear(); - }) - }; - Ok(RpcServer { - address, - message_sender, - waiting_responses, - abort_handle: join_handle.abort_handle(), - }) - } -} - -#[derive(Debug, Clone)] -pub struct RpcServer { - address: SocketAddr, - message_sender: mpsc::UnboundedSender, - waiting_responses: WaitingResponsesMap, - abort_handle: AbortHandle, -} - -impl RpcServer { - pub fn address(&self) -> SocketAddr { - self.address - } - - pub async fn call_method( - &self, - method: &str, - ) -> Result { - self.call_method_any(method, vec![]).await - } - - pub async fn call_method_2( - &self, - method: &str, - arg1: impl Serialize, - arg2: impl Serialize, - ) -> Result { - self.call_method_any( - method, - vec![serde_json::to_value(arg1)?, serde_json::to_value(arg2)?], - ) - .await - } - - async fn call_method_any( - &self, - method: &str, - args: Vec, - ) -> Result { - if self.message_sender.is_closed() { - return Err(ErrorKind::RpcError( - "RPC connection closed".to_string(), - ) - .into()); - } - - let id = Uuid::new_v4(); - let (send, recv) = oneshot::channel(); - self.waiting_responses.lock().unwrap().insert(id, send); - - let message = RpcMessage { - id, - body: RpcMessageBody::Call { - method: method.to_owned(), - args, - }, - }; - if self.message_sender.send(message).is_err() { - self.waiting_responses.lock().unwrap().remove(&id); - return Err(ErrorKind::RpcError( - "RPC connection closed while sending".to_string(), - ) - .into()); - } - - tracing::debug!("Waiting on result for {id}"); - let Ok(result) = recv.await else { - self.waiting_responses.lock().unwrap().remove(&id); - return Err(ErrorKind::RpcError( - "RPC connection closed while waiting for response".to_string(), - ) - .into()); - }; - result.and_then(|x| Ok(serde_json::from_value(x)?)) - } -} - -impl Drop for RpcServer { - fn drop(&mut self) { - self.abort_handle.abort(); - } -} - -struct RunningRpcServer { - message_receiver: mpsc::UnboundedReceiver, - handlers: HandlerMap, - waiting_responses: WaitingResponsesMap, -} - -impl RunningRpcServer { - async fn run(&mut self, listener: TcpListener) -> Result<()> { - let (socket, _) = listener.accept().await?; - drop(listener); - - let mut socket = LinesCodec::new().framed(socket); - loop { - let to_send = tokio::select! { - message = self.message_receiver.recv() => { - if message.is_none() { - break; - } - message - }, - message = socket.next() => { - let message: RpcMessage = match message { - None => break, - Some(Ok(message)) => serde_json::from_str(&message)?, - Some(Err(LinesCodecError::Io(e))) => Err(e)?, - Some(Err(LinesCodecError::MaxLineLengthExceeded)) => unreachable!(), - }; - self.handle_message(message).await? - }, - }; - if let Some(message) = to_send { - let json = serde_json::to_string(&message)?; - match socket.send(json).await { - Ok(()) => {} - Err(LinesCodecError::Io(e)) => Err(e)?, - Err(LinesCodecError::MaxLineLengthExceeded) => { - unreachable!() - } - }; - } - } - Ok(()) - } - - async fn handle_message( - &self, - message: RpcMessage, - ) -> Result> { - if let RpcMessageBody::Call { method, args } = message.body { - let response = match self.handlers.get(method.as_str()) { - Some(handler) => match handler(args).await { - Ok(result) => RpcMessageBody::Respond { response: result }, - Err(e) => RpcMessageBody::Error { - error: e.to_string(), - }, - }, - None => RpcMessageBody::Error { - error: format!("Unknown theseus RPC method {method}"), - }, - }; - Ok(Some(RpcMessage { - id: message.id, - body: response, - })) - } else if let Some(sender) = - self.waiting_responses.lock().unwrap().remove(&message.id) - { - let _ = sender.send(match message.body { - RpcMessageBody::Respond { response } => Ok(response), - RpcMessageBody::Error { error } => { - Err(ErrorKind::RpcError(error).into()) - } - _ => unreachable!(), - }); - Ok(None) - } else { - Ok(None) - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -struct RpcMessage { - id: Uuid, - #[serde(flatten)] - body: RpcMessageBody, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -enum RpcMessageBody { - Call { - method: String, - args: Vec, - }, - Respond { - #[serde(default, skip_serializing_if = "Value::is_null")] - response: Value, - }, - Error { - error: String, - }, -}