use std::error::Error; use std::fmt::{Display, Formatter}; use std::io::ErrorKind; use std::path::{Component, Path, PathBuf}; use log::{debug, info, warn}; use sha1_smol::{Digest, Sha1}; use tokio::fs::File; use tokio::{fs, io}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[derive(Debug)] pub enum IntegrityError { SizeMismatch{ expect: usize, actual: usize }, Sha1Mismatch{ expect: Digest, actual: Digest } } impl Display for IntegrityError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { IntegrityError::SizeMismatch{ expect, actual } => write!(f, "size mismatch (expect {expect} bytes, got {actual} bytes)"), IntegrityError::Sha1Mismatch {expect, actual} => write!(f, "sha1 mismatch (expect {expect}, got {actual})") } } } impl Error for IntegrityError {} pub fn verify_sha1(expect: Digest, s: &str) -> Result<(), Digest> { let dig = Sha1::from(s).digest(); if dig == expect { return Ok(()); } Err(dig) } #[derive(Debug)] pub enum FileVerifyError { Integrity(PathBuf, IntegrityError), Open(PathBuf, tokio::io::Error), Read(PathBuf, tokio::io::Error), } impl Display for FileVerifyError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { FileVerifyError::Integrity(path, e) => write!(f, "file integrity error {}: {}", path.display(), e), FileVerifyError::Open(path, e) => write!(f, "error opening file {}: {}", path.display(), e), FileVerifyError::Read(path, e) => write!(f, "error reading file {}: {}", path.display(), e) } } } impl Error for FileVerifyError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { FileVerifyError::Integrity(_, e) => Some(e), FileVerifyError::Open(_, e) => Some(e), FileVerifyError::Read(_, e) => Some(e) } } } pub async fn verify_file(path: impl AsRef, expect_size: Option, expect_sha1: Option) -> Result<(), FileVerifyError> { let path = path.as_ref(); if expect_size.is_none() && expect_sha1.is_none() { return match path.metadata() { Ok(_) => { debug!("No size or sha1 for {}, have to assume it's good.", path.display()); Ok(()) }, Err(e) => { Err(FileVerifyError::Open(path.to_path_buf(), e)) } } } let mut file = File::open(path).await.map_err(|e| FileVerifyError::Open(path.to_owned(), e))?; let mut tally = 0usize; let mut st = Sha1::new(); let mut buf = [0u8; 4096]; loop { let n = match file.read(&mut buf).await { Ok(n) => n, Err(e) => match e.kind() { ErrorKind::Interrupted => continue, _ => return Err(FileVerifyError::Read(path.to_owned(), e)) } }; if n == 0 { break; } st.update(&buf[..n]); tally += n; } let dig = st.digest(); if expect_size.is_some_and(|sz| sz != tally) { return Err(FileVerifyError::Integrity(path.to_owned(), IntegrityError::SizeMismatch { expect: expect_size.unwrap(), actual: tally })); } else if expect_sha1.is_some_and(|exp_dig| exp_dig != dig) { return Err(FileVerifyError::Integrity(path.to_owned(), IntegrityError::Sha1Mismatch { expect: expect_sha1.unwrap(), actual: dig })); } Ok(()) } #[derive(Debug)] pub enum EnsureFileError { IO { what: &'static str, error: io::Error }, Download { url: String, error: reqwest::Error }, Integrity(IntegrityError), Offline, MissingURL } impl Display for EnsureFileError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { EnsureFileError::IO { what, error } => write!(f, "i/o error ensuring file ({what}): {error}"), EnsureFileError::Download { url, error } => write!(f, "error downloading file ({url}): {error}"), EnsureFileError::Integrity(e) => write!(f, "integrity error for downloaded file: {e}"), EnsureFileError::Offline => f.write_str("unable to download file while offline"), EnsureFileError::MissingURL => f.write_str("missing url"), } } } impl Error for EnsureFileError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { EnsureFileError::IO { error, .. } => Some(error), EnsureFileError::Download { error, .. } => Some(error), EnsureFileError::Integrity(error) => Some(error), _ => None } } } pub async fn ensure_file(path: impl AsRef, url: Option<&str>, expect_size: Option, expect_sha1: Option, online: bool, force_download: bool) -> Result { let path = path.as_ref(); if !force_download { match verify_file(path, expect_size, expect_sha1).await { Ok(_) => { info!("File {} exists and integrity matches. Skipping.", path.display()); return Ok(false); }, Err(FileVerifyError::Open(_, e)) if e.kind() == ErrorKind::NotFound => (), Err(FileVerifyError::Integrity(_, e)) => info!("File {} on disk failed integrity check: {}", path.display(), e), Err(FileVerifyError::Open(_, e)) | Err(FileVerifyError::Read(_, e)) => return Err(EnsureFileError::IO { what: "verifying fileon disk", error: e }) } } if !online { warn!("Cannot download {} to {} while offline!", url.unwrap_or("(no url)"), path.display()); return Err(EnsureFileError::Offline); } // download the file let Some(url) = url else { return Err(EnsureFileError::MissingURL); }; let mut file = File::create(path).await.map_err(|e| EnsureFileError::IO { what: "save downloaded file (open)", error: e })?; debug!("File {} must be downloaded ({}).", path.display(), url); let mut response = reqwest::get(url).await.map_err(|e| EnsureFileError::Download { url: url.to_owned(), error: e })?; let mut tally = 0usize; let mut sha1 = Sha1::new(); while let Some(chunk) = response.chunk().await.map_err(|e| EnsureFileError::Download { url: url.to_owned(), error: e })? { let slice = chunk.as_ref(); file.write_all(slice).await.map_err(|e| EnsureFileError::IO { what: "save downloaded file (write)", error: e })?; tally += slice.len(); sha1.update(slice); } drop(file); // manually close file let del_file_silent = || async { debug!("Deleting downloaded file {} since its integrity doesn't match :(", path.display()); let _ = fs::remove_file(path).await.map_err(|e| warn!("failed to delete invalid downloaded file: {}", e)); () }; if expect_size.is_some_and(|s| s != tally) { del_file_silent().await; return Err(EnsureFileError::Integrity(IntegrityError::SizeMismatch { expect: expect_size.unwrap(), actual: tally })); } let digest = sha1.digest(); if expect_sha1.is_some_and(|exp_dig| exp_dig != digest) { del_file_silent().await; return Err(EnsureFileError::Integrity(IntegrityError::Sha1Mismatch { expect: expect_sha1.unwrap(), actual: digest })); } info!("File {} downloaded successfully.", path.display()); Ok(true) } pub fn check_path(name: &str) -> Result<&Path, &'static str> { let entry_path: &Path = Path::new(name); let mut depth = 0usize; for component in entry_path.components() { depth = match component { Component::Prefix(_) | Component::RootDir => return Err("root path component in entry"), Component::ParentDir => depth.checked_sub(1) .map_or_else(|| Err("entry path escapes"), |s| Ok(s))?, Component::Normal(_) => depth + 1, _ => depth } } Ok(entry_path) } #[cfg(windows)] pub fn strip_verbatim(path: &Path) -> &Path { let Some(Component::Prefix(p)) = path.components().next() else { return path; }; match p.kind() { Prefix::VerbatimDisk(_) => Path::new(unsafe { OsStr::from_encoded_bytes_unchecked(&path.as_os_str().as_encoded_bytes()[4..]) }), _ => path } } #[cfg(not(windows))] pub fn strip_verbatim(path: &Path) -> &Path { path } pub trait AsJavaPath { fn as_java_path(&self) -> &Path; } impl AsJavaPath for Path { fn as_java_path(&self) -> &Path { strip_verbatim(self) } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; #[test] #[cfg(windows)] fn test_strip_verbatim() { let path = Path::new(r"\\?\C:\Some\Verbatim\Path"); match path.components().next().unwrap() { Component::Prefix(p) => assert!(matches!(p.kind(), Prefix::VerbatimDisk(_)), "(TEST BUG) path does not start with verbatim disk"), _ => panic!("(TEST BUG) path does not start with prefix") } let path2 = path.as_java_path(); match path2.components().next().unwrap() { Component::Prefix(p) => assert!(matches!(p.kind(), Prefix::Disk(_))), _ => panic!("path does not begin with prefix") } } }