You've already forked AstralRinth
Merge commit '7fa442fb28a2b9156690ff147206275163e7aec8' into beta
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
//! Functions for fetching information from the Internet
|
||||
use super::io::{self, IOError};
|
||||
use crate::ErrorKind;
|
||||
use crate::LAUNCHER_USER_AGENT;
|
||||
use crate::event::LoadingBarId;
|
||||
use crate::event::emit::emit_loading;
|
||||
use bytes::Bytes;
|
||||
@@ -19,11 +21,8 @@ pub struct FetchSemaphore(pub Semaphore);
|
||||
|
||||
pub static REQWEST_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
let header = reqwest::header::HeaderValue::from_str(&format!(
|
||||
"modrinth/theseus/{} (support@modrinth.com)",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
))
|
||||
.unwrap();
|
||||
let header =
|
||||
reqwest::header::HeaderValue::from_str(LAUNCHER_USER_AGENT).unwrap();
|
||||
headers.insert(reqwest::header::USER_AGENT, header);
|
||||
reqwest::Client::builder()
|
||||
.tcp_keepalive(Some(time::Duration::from_secs(10)))
|
||||
@@ -108,32 +107,31 @@ pub async fn fetch_advanced(
|
||||
|
||||
let result = req.send().await;
|
||||
match result {
|
||||
Ok(x) => {
|
||||
if x.status().is_server_error() {
|
||||
if attempt <= FETCH_ATTEMPTS {
|
||||
continue;
|
||||
} else {
|
||||
return Err(crate::Error::from(
|
||||
crate::ErrorKind::OtherError(
|
||||
"Server error when fetching content"
|
||||
.to_string(),
|
||||
),
|
||||
));
|
||||
Ok(resp) => {
|
||||
if resp.status().is_server_error() && attempt <= FETCH_ATTEMPTS
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if resp.status().is_client_error()
|
||||
|| resp.status().is_server_error()
|
||||
{
|
||||
let backup_error = resp.error_for_status_ref().unwrap_err();
|
||||
if let Ok(error) = resp.json().await {
|
||||
return Err(ErrorKind::LabrinthError(error).into());
|
||||
}
|
||||
return Err(backup_error.into());
|
||||
}
|
||||
|
||||
let bytes = if let Some((bar, total)) = &loading_bar {
|
||||
let length = x.content_length();
|
||||
let length = resp.content_length();
|
||||
if let Some(total_size) = length {
|
||||
use futures::StreamExt;
|
||||
let mut stream = x.bytes_stream();
|
||||
let mut stream = resp.bytes_stream();
|
||||
let mut bytes = Vec::new();
|
||||
while let Some(item) = stream.next().await {
|
||||
let chunk = item.or(Err(
|
||||
crate::error::ErrorKind::NoValueFor(
|
||||
"fetch bytes".to_string(),
|
||||
),
|
||||
))?;
|
||||
let chunk = item.or(Err(ErrorKind::NoValueFor(
|
||||
"fetch bytes".to_string(),
|
||||
)))?;
|
||||
bytes.append(&mut chunk.to_vec());
|
||||
emit_loading(
|
||||
bar,
|
||||
@@ -145,10 +143,10 @@ pub async fn fetch_advanced(
|
||||
|
||||
Ok(bytes::Bytes::from(bytes))
|
||||
} else {
|
||||
x.bytes().await
|
||||
resp.bytes().await
|
||||
}
|
||||
} else {
|
||||
x.bytes().await
|
||||
resp.bytes().await
|
||||
};
|
||||
|
||||
if let Ok(bytes) = bytes {
|
||||
@@ -158,7 +156,7 @@ pub async fn fetch_advanced(
|
||||
if attempt <= FETCH_ATTEMPTS {
|
||||
continue;
|
||||
} else {
|
||||
return Err(crate::ErrorKind::HashError(
|
||||
return Err(ErrorKind::HashError(
|
||||
sha1.to_string(),
|
||||
hash,
|
||||
)
|
||||
@@ -194,10 +192,9 @@ pub async fn fetch_mirrors(
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Sqlite> + Copy,
|
||||
) -> crate::Result<Bytes> {
|
||||
if mirrors.is_empty() {
|
||||
return Err(crate::ErrorKind::InputError(
|
||||
"No mirrors provided!".to_string(),
|
||||
)
|
||||
.into());
|
||||
return Err(
|
||||
ErrorKind::InputError("No mirrors provided!".to_string()).into()
|
||||
);
|
||||
}
|
||||
|
||||
for (index, mirror) in mirrors.iter().enumerate() {
|
||||
@@ -276,8 +273,8 @@ pub async fn write(
|
||||
}
|
||||
|
||||
pub async fn copy(
|
||||
src: impl AsRef<std::path::Path>,
|
||||
dest: impl AsRef<std::path::Path>,
|
||||
src: impl AsRef<Path>,
|
||||
dest: impl AsRef<Path>,
|
||||
semaphore: &IoSemaphore,
|
||||
) -> crate::Result<()> {
|
||||
let src: &Path = src.as_ref();
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
// IO error
|
||||
// A wrapper around the tokio IO functions that adds the path to the error message, instead of the uninformative std::io::Error.
|
||||
|
||||
use std::{io::Write, path::Path};
|
||||
use std::{
|
||||
io::{ErrorKind, Write},
|
||||
path::Path,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
use tokio::task::spawn_blocking;
|
||||
|
||||
@@ -32,6 +35,13 @@ impl IOError {
|
||||
path: path.to_string_lossy().to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
IOError::IOPathError { source, .. } => source.kind(),
|
||||
IOError::IOError(source) => source.kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn canonicalize(
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
pub mod fetch;
|
||||
pub mod io;
|
||||
pub mod jre;
|
||||
pub mod network;
|
||||
pub mod platform;
|
||||
pub mod utils; // [AR] Feature
|
||||
pub mod protocol_version;
|
||||
pub mod rpc;
|
||||
pub mod server_ping;
|
||||
|
||||
93
packages/app-lib/src/util/network.rs
Normal file
93
packages/app-lib/src/util/network.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use crate::Result;
|
||||
use std::io;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
pub async fn tcp_listen_any_loopback() -> io::Result<TcpListener> {
|
||||
// IPv4 is tried first for the best compatibility and performance with most systems.
|
||||
// IPv6 is also tried in case IPv4 is not available. Resolving "localhost" is avoided
|
||||
// to prevent failures deriving from improper name resolution setup. Any available
|
||||
// ephemeral port is used to prevent conflicts with other services. This is all as per
|
||||
// RFC 8252's recommendations
|
||||
const ANY_LOOPBACK_SOCKET: &[SocketAddr] = &[
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0),
|
||||
];
|
||||
|
||||
TcpListener::bind(ANY_LOOPBACK_SOCKET).await
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
pub async fn is_network_metered() -> Result<bool> {
|
||||
use windows::Networking::Connectivity::{
|
||||
NetworkCostType, NetworkInformation,
|
||||
};
|
||||
|
||||
let cost_type = NetworkInformation::GetInternetConnectionProfile()?
|
||||
.GetConnectionCost()?
|
||||
.NetworkCostType()?;
|
||||
Ok(matches!(
|
||||
cost_type,
|
||||
NetworkCostType::Fixed | NetworkCostType::Variable
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub async fn is_network_metered() -> Result<bool> {
|
||||
use crate::ErrorKind;
|
||||
use cidre::dispatch::Queue;
|
||||
use cidre::nw::PathMonitor;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::future::FutureExt;
|
||||
|
||||
let (sender, mut receiver) = mpsc::channel(1);
|
||||
|
||||
let queue = Queue::new();
|
||||
let mut monitor = PathMonitor::new();
|
||||
monitor.set_queue(&queue);
|
||||
monitor.set_update_handler(move |path| {
|
||||
let _ = sender.try_send(path.is_constrained() || path.is_expensive());
|
||||
});
|
||||
|
||||
monitor.start();
|
||||
let result = receiver
|
||||
.recv()
|
||||
.timeout(Duration::from_millis(100))
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
monitor.cancel();
|
||||
|
||||
result.ok_or_else(|| {
|
||||
ErrorKind::OtherError(
|
||||
"NWPathMonitor didn't provide an NWPath in time".to_string(),
|
||||
)
|
||||
.into()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub async fn is_network_metered() -> Result<bool> {
|
||||
// Thanks to https://github.com/Hakanbaban53/rclone-manager for showing how to do this
|
||||
use zbus::{Connection, Proxy};
|
||||
|
||||
let connection = Connection::system().await?;
|
||||
let proxy = Proxy::new(
|
||||
&connection,
|
||||
"org.freedesktop.NetworkManager",
|
||||
"/org/freedesktop/NetworkManager",
|
||||
"org.freedesktop.NetworkManager",
|
||||
)
|
||||
.await?;
|
||||
let metered = proxy.get_property("Metered").await?;
|
||||
Ok(matches!(metered, 1 | 3))
|
||||
}
|
||||
|
||||
#[cfg(not(any(windows, target_os = "macos", target_os = "linux")))]
|
||||
pub async fn is_network_metered() -> Result<bool> {
|
||||
tracing::warn!(
|
||||
"is_network_metered called on unsupported platform. Assuming unmetered."
|
||||
);
|
||||
Ok(false)
|
||||
}
|
||||
@@ -1,65 +1,6 @@
|
||||
//! Platform-related code
|
||||
use daedalus::minecraft::{Os, OsRule};
|
||||
|
||||
// OS detection
|
||||
pub trait OsExt {
|
||||
/// Get the OS of the current system
|
||||
fn native() -> Self;
|
||||
|
||||
/// Gets the OS + Arch of the current system
|
||||
fn native_arch(java_arch: &str) -> Self;
|
||||
|
||||
/// Gets the OS from an OS + Arch
|
||||
fn get_os(&self) -> Self;
|
||||
}
|
||||
|
||||
impl OsExt for Os {
|
||||
fn native() -> Self {
|
||||
match std::env::consts::OS {
|
||||
"windows" => Self::Windows,
|
||||
"macos" => Self::Osx,
|
||||
"linux" => Self::Linux,
|
||||
_ => Self::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
fn native_arch(java_arch: &str) -> Self {
|
||||
if std::env::consts::OS == "windows" {
|
||||
if java_arch == "aarch64" {
|
||||
Os::WindowsArm64
|
||||
} else {
|
||||
Os::Windows
|
||||
}
|
||||
} else if std::env::consts::OS == "linux" {
|
||||
if java_arch == "aarch64" {
|
||||
Os::LinuxArm64
|
||||
} else if java_arch == "arm" {
|
||||
Os::LinuxArm32
|
||||
} else {
|
||||
Os::Linux
|
||||
}
|
||||
} else if std::env::consts::OS == "macos" {
|
||||
if java_arch == "aarch64" {
|
||||
Os::OsxArm64
|
||||
} else {
|
||||
Os::Osx
|
||||
}
|
||||
} else {
|
||||
Os::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
fn get_os(&self) -> Self {
|
||||
match self {
|
||||
Os::OsxArm64 => Os::Osx,
|
||||
Os::LinuxArm32 => Os::Linux,
|
||||
Os::LinuxArm64 => Os::Linux,
|
||||
Os::WindowsArm64 => Os::Windows,
|
||||
_ => self.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Bit width
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
pub const ARCH_WIDTH: &str = "64";
|
||||
|
||||
258
packages/app-lib/src/util/rpc.rs
Normal file
258
packages/app-lib/src/util/rpc.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
use crate::prelude::tcp_listen_any_loopback;
|
||||
use crate::{ErrorKind, Result};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::task::AbortHandle;
|
||||
use tokio_util::codec::{Decoder, LinesCodec, LinesCodecError};
|
||||
use uuid::Uuid;
|
||||
|
||||
type HandlerFuture = Pin<Box<dyn Send + Future<Output = Result<Value>>>>;
|
||||
type HandlerMethod = Box<dyn Send + Sync + Fn(Vec<Value>) -> HandlerFuture>;
|
||||
type HandlerMap = HashMap<&'static str, HandlerMethod>;
|
||||
type WaitingResponsesMap =
|
||||
Arc<Mutex<HashMap<Uuid, oneshot::Sender<Result<Value>>>>>;
|
||||
|
||||
pub struct RpcServerBuilder {
|
||||
handlers: HandlerMap,
|
||||
}
|
||||
|
||||
impl RpcServerBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// We'll use this function in the future. Please remove this #[allow] when we do.
|
||||
#[allow(dead_code)]
|
||||
pub fn handler(
|
||||
mut self,
|
||||
function_name: &'static str,
|
||||
handler: HandlerMethod,
|
||||
) -> Self {
|
||||
self.handlers.insert(function_name, Box::new(handler));
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn launch(self) -> Result<RpcServer> {
|
||||
let socket = tcp_listen_any_loopback().await?;
|
||||
let address = socket.local_addr()?;
|
||||
let (message_sender, message_receiver) = mpsc::unbounded_channel();
|
||||
let waiting_responses = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
let join_handle = {
|
||||
let waiting_responses = waiting_responses.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut server = RunningRpcServer {
|
||||
message_receiver,
|
||||
handlers: self.handlers,
|
||||
waiting_responses: waiting_responses.clone(),
|
||||
};
|
||||
if let Err(e) = server.run(socket).await {
|
||||
tracing::error!("Failed to run RPC server: {e}");
|
||||
}
|
||||
waiting_responses.lock().unwrap().clear();
|
||||
})
|
||||
};
|
||||
Ok(RpcServer {
|
||||
address,
|
||||
message_sender,
|
||||
waiting_responses,
|
||||
abort_handle: join_handle.abort_handle(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RpcServer {
|
||||
address: SocketAddr,
|
||||
message_sender: mpsc::UnboundedSender<RpcMessage>,
|
||||
waiting_responses: WaitingResponsesMap,
|
||||
abort_handle: AbortHandle,
|
||||
}
|
||||
|
||||
impl RpcServer {
|
||||
pub fn address(&self) -> SocketAddr {
|
||||
self.address
|
||||
}
|
||||
|
||||
pub async fn call_method<R: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
) -> Result<R> {
|
||||
self.call_method_any(method, vec![]).await
|
||||
}
|
||||
|
||||
pub async fn call_method_2<R: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
arg1: impl Serialize,
|
||||
arg2: impl Serialize,
|
||||
) -> Result<R> {
|
||||
self.call_method_any(
|
||||
method,
|
||||
vec![serde_json::to_value(arg1)?, serde_json::to_value(arg2)?],
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn call_method_any<R: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
args: Vec<Value>,
|
||||
) -> Result<R> {
|
||||
if self.message_sender.is_closed() {
|
||||
return Err(ErrorKind::RpcError(
|
||||
"RPC connection closed".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let (send, recv) = oneshot::channel();
|
||||
self.waiting_responses.lock().unwrap().insert(id, send);
|
||||
|
||||
let message = RpcMessage {
|
||||
id,
|
||||
body: RpcMessageBody::Call {
|
||||
method: method.to_owned(),
|
||||
args,
|
||||
},
|
||||
};
|
||||
if self.message_sender.send(message).is_err() {
|
||||
self.waiting_responses.lock().unwrap().remove(&id);
|
||||
return Err(ErrorKind::RpcError(
|
||||
"RPC connection closed while sending".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
tracing::debug!("Waiting on result for {id}");
|
||||
let Ok(result) = recv.await else {
|
||||
self.waiting_responses.lock().unwrap().remove(&id);
|
||||
return Err(ErrorKind::RpcError(
|
||||
"RPC connection closed while waiting for response".to_string(),
|
||||
)
|
||||
.into());
|
||||
};
|
||||
result.and_then(|x| Ok(serde_json::from_value(x)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RpcServer {
|
||||
fn drop(&mut self) {
|
||||
self.abort_handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
struct RunningRpcServer {
|
||||
message_receiver: mpsc::UnboundedReceiver<RpcMessage>,
|
||||
handlers: HandlerMap,
|
||||
waiting_responses: WaitingResponsesMap,
|
||||
}
|
||||
|
||||
impl RunningRpcServer {
|
||||
async fn run(&mut self, listener: TcpListener) -> Result<()> {
|
||||
let (socket, _) = listener.accept().await?;
|
||||
drop(listener);
|
||||
|
||||
let mut socket = LinesCodec::new().framed(socket);
|
||||
loop {
|
||||
let to_send = tokio::select! {
|
||||
message = self.message_receiver.recv() => {
|
||||
if message.is_none() {
|
||||
break;
|
||||
}
|
||||
message
|
||||
},
|
||||
message = socket.next() => {
|
||||
let message: RpcMessage = match message {
|
||||
None => break,
|
||||
Some(Ok(message)) => serde_json::from_str(&message)?,
|
||||
Some(Err(LinesCodecError::Io(e))) => Err(e)?,
|
||||
Some(Err(LinesCodecError::MaxLineLengthExceeded)) => unreachable!(),
|
||||
};
|
||||
self.handle_message(message).await?
|
||||
},
|
||||
};
|
||||
if let Some(message) = to_send {
|
||||
let json = serde_json::to_string(&message)?;
|
||||
match socket.send(json).await {
|
||||
Ok(()) => {}
|
||||
Err(LinesCodecError::Io(e)) => Err(e)?,
|
||||
Err(LinesCodecError::MaxLineLengthExceeded) => {
|
||||
unreachable!()
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_message(
|
||||
&self,
|
||||
message: RpcMessage,
|
||||
) -> Result<Option<RpcMessage>> {
|
||||
if let RpcMessageBody::Call { method, args } = message.body {
|
||||
let response = match self.handlers.get(method.as_str()) {
|
||||
Some(handler) => match handler(args).await {
|
||||
Ok(result) => RpcMessageBody::Respond { response: result },
|
||||
Err(e) => RpcMessageBody::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
},
|
||||
None => RpcMessageBody::Error {
|
||||
error: format!("Unknown theseus RPC method {method}"),
|
||||
},
|
||||
};
|
||||
Ok(Some(RpcMessage {
|
||||
id: message.id,
|
||||
body: response,
|
||||
}))
|
||||
} else if let Some(sender) =
|
||||
self.waiting_responses.lock().unwrap().remove(&message.id)
|
||||
{
|
||||
let _ = sender.send(match message.body {
|
||||
RpcMessageBody::Respond { response } => Ok(response),
|
||||
RpcMessageBody::Error { error } => {
|
||||
Err(ErrorKind::RpcError(error).into())
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct RpcMessage {
|
||||
id: Uuid,
|
||||
#[serde(flatten)]
|
||||
body: RpcMessageBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum RpcMessageBody {
|
||||
Call {
|
||||
method: String,
|
||||
args: Vec<Value>,
|
||||
},
|
||||
Respond {
|
||||
#[serde(default, skip_serializing_if = "Value::is_null")]
|
||||
response: Value,
|
||||
},
|
||||
Error {
|
||||
error: String,
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user