use std::error::Error; use std::fmt::{Debug, Display, Formatter}; use std::io::ErrorKind; use std::path::{Path, PathBuf}; use futures::{stream, Stream, StreamExt}; use log::debug; use reqwest::{Client, IntoUrl, Method, RequestBuilder}; use sha1_smol::{Digest, Sha1}; use tokio::fs; use tokio::fs::File; use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use crate::launcher::constants::USER_AGENT; pub trait Download: Debug + Display { // return Ok(None) to skip downloading this file fn get_url(&self) -> impl IntoUrl; async fn prepare(&mut self, req: RequestBuilder) -> Result, Box>; async fn handle_chunk(&mut self, chunk: &[u8]) -> Result<(), Box>; async fn finish(&mut self) -> Result<(), Box>; } pub struct MultiDownloader { jobs: Vec, client: Client, nconcurrent: usize } #[derive(Debug, Clone, Copy)] pub enum Phase { Prepare, Send, Receive, HandleChunk, Finish } impl Display for Phase { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { /* an error occurred while (present participle) ... */ Self::Prepare => f.write_str("preparing the request"), Self::Send => f.write_str("sending the request"), Self::Receive => f.write_str("receiving response data"), Self::HandleChunk => f.write_str("handling response data"), Self::Finish => f.write_str("finishing the request"), } } } pub struct PhaseDownloadError<'j, T: Download> { phase: Phase, inner: Box, job: &'j T } impl<'j, T: Download> Debug for PhaseDownloadError<'j, T> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("PhaseDownloadError") .field("phase", &self.phase) .field("inner", &self.inner) .field("job", &self.job) .finish() } } impl<'j, T: Download> Display for PhaseDownloadError<'j, T> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "error while {} ({}): {}", self.phase, self.job, self.inner) } } impl<'j, T: Download> Error for PhaseDownloadError<'j, T> { fn source(&self) -> Option<&(dyn Error + 'static)> { Some(&*self.inner) } } impl<'j, T: Download> PhaseDownloadError<'j, T> { fn new(phase: Phase, inner: Box, job: &'j T) -> Self { PhaseDownloadError { phase, inner, job } } } impl MultiDownloader { pub fn new(jobs: impl IntoIterator) -> MultiDownloader { Self::with_concurrent(jobs, 8) } pub fn with_concurrent(jobs: impl IntoIterator, n: usize) -> MultiDownloader { assert!(n > 0); MultiDownloader { jobs: jobs.into_iter().collect(), client: Client::new(), nconcurrent: n } } pub async fn perform(&mut self) -> impl Stream>> { stream::iter(self.jobs.iter_mut()).map(|job| { let client = &self.client; macro_rules! map_err { ($result:expr, $phase:expr, $job:expr) => { match $result { Ok(v) => v, Err(e) => return Err(PhaseDownloadError::new($phase, e.into(), $job)) } } } async move { let Some(rq) = map_err!( job.prepare(client.request(Method::GET, job.get_url()) .header(reqwest::header::USER_AGENT, USER_AGENT)).await, Phase::Prepare, job) else { return Ok(()) }; let mut data = map_err!(map_err!(rq.send().await, Phase::Send, job).error_for_status(), Phase::Send, job).bytes_stream(); while let Some(bytes) = data.next().await { let bytes = map_err!(bytes, Phase::Receive, job); map_err!(job.handle_chunk(bytes.as_ref()).await, Phase::HandleChunk, job); } job.finish().await.map_err(|e| PhaseDownloadError::new(Phase::Finish, e.into(), job))?; Ok(()) } }).buffer_unordered(self.nconcurrent) } } #[derive(Debug)] pub enum IntegrityError { SizeMismatch{ expect: usize, actual: usize }, Sha1Mismatch{ expect: Digest, actual: Digest } } impl Display for IntegrityError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { IntegrityError::SizeMismatch{ expect, actual } => write!(f, "size mismatch (expect {expect} bytes, got {actual} bytes)"), IntegrityError::Sha1Mismatch {expect, actual} => write!(f, "sha1 mismatch (expect {expect}, got {actual})") } } } impl Error for IntegrityError {} pub struct VerifiedDownload { url: String, expect_size: Option, expect_sha1: Option, path: PathBuf, file: Option, sha1: Sha1, tally: usize } impl Debug for VerifiedDownload { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("VerifiedDownload") .field("url", &self.url) .field("expect_size", &self.expect_size) .field("expect_sha1", &self.expect_sha1) .field("path", &self.path).finish() } } impl Display for VerifiedDownload { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "downloading {} to {}", self.url, self.path.to_string_lossy()) } } impl VerifiedDownload { pub fn new(url: &str, path: &Path, expect_size: Option, expect_sha1: Option) -> VerifiedDownload { VerifiedDownload { url: url.to_owned(), path: path.to_owned(), expect_size, expect_sha1, file: None, sha1: Sha1::new(), tally: 0 } } pub fn with_size(mut self, expect: usize) -> VerifiedDownload { self.expect_size = Some(expect); self } pub fn with_sha1(mut self, expect: Digest) -> VerifiedDownload { self.expect_sha1.replace(expect); self } pub fn get_path(&self) -> &Path { &self.path } pub async fn make_dirs(&self) -> Result<(), io::Error> { fs::create_dir_all(self.path.parent().expect("download created with no containing directory (?)")).await } async fn open_output(&mut self) -> Result<(), io::Error> { self.file.replace(File::create(&self.path).await?); Ok(()) } } impl Download for VerifiedDownload { fn get_url(&self) -> impl IntoUrl { &self.url } async fn prepare(&mut self, req: RequestBuilder) -> Result, Box> { let mut file = match File::open(&self.path).await { Ok(file) => file, Err(e) => return if e.kind() == ErrorKind::NotFound { // assume the parent folder exists (responsibility of the caller to ensure this) debug!("File {} does not exist, downloading it.", self.path.to_string_lossy()); self.open_output().await?; Ok(Some(req)) } else { debug!("Error opening {}: {}", self.path.to_string_lossy(), e); Err(e.into()) } }; // short-circuit this if self.expect_size.is_none() && self.expect_sha1.is_none() { debug!("No size or sha1 for {}, have to assume it's good.", self.path.to_string_lossy()); return Ok(None); } let mut tally = 0usize; let mut sha1 = Sha1::new(); let mut buf = [0u8; 4096]; loop { let n = match file.read(&mut buf).await { Ok(n) => n, Err(e) => match e.kind() { ErrorKind::Interrupted => continue, _ => { debug!("Error reading {}: {}", self.path.to_string_lossy(), e); return Err(e.into()); } } }; if n == 0 { break; } tally += n; sha1.update(&buf[..n]); } if self.expect_sha1.is_none_or(|d| d == sha1.digest()) && self.expect_size.is_none_or(|s| s == tally) { debug!("Not downloading {}, sha1 and size match.", self.path.to_string_lossy()); return Ok(None); } drop(file); // potentially racy to close the file and reopen it... :/ self.open_output().await?; debug!("Downloading {} because sha1 or size does not match.", self.path.to_string_lossy()); Ok(Some(req)) } async fn handle_chunk(&mut self, chunk: &[u8]) -> Result<(), Box> { self.file.as_mut().unwrap().write_all(chunk).await?; self.tally += chunk.len(); self.sha1.update(chunk); Ok(()) } async fn finish(&mut self) -> Result<(), Box> { let digest = self.sha1.digest(); if let Some(d) = self.expect_sha1 { if d != digest { debug!("Could not download {}: sha1 mismatch (exp {}, got {}).", self.path.to_string_lossy(), d, digest); return Err(IntegrityError::Sha1Mismatch { expect: d, actual: digest }.into()); } } else if let Some(s) = self.expect_size { if s != self.tally { debug!("Could not download {}: size mismatch (exp {}, got {}).", self.path.to_string_lossy(), s, self.tally); return Err(IntegrityError::SizeMismatch { expect: s, actual: self.tally }.into()); } } debug!("Successfully downloaded {} ({} bytes).", self.path.to_string_lossy(), self.tally); // release the file descriptor (don't want to wait until it's dropped automatically because idk when that would be) drop(self.file.take().unwrap()); Ok(()) } }