// winbrew_core/network/download.rs

1use std::fs;
2use std::io::{BufWriter, Read, Write};
3use std::path::Path;
4use std::time::Duration;
5
6use super::{BoxError, DownloadError, Result};
7
/// Blocking HTTP client used by the download helpers.
///
/// This is a plain alias for `reqwest::blocking::Client`. Build one with
/// `build_client` so the download user agent and timeouts are applied.
pub type Client = reqwest::blocking::Client;
10
/// Request timeout, in seconds, passed to `ClientBuilder::timeout` by `build_client`.
const DOWNLOAD_REQUEST_TIMEOUT_SECS: u64 = 5 * 60;
/// Connect timeout, in seconds, passed to `ClientBuilder::connect_timeout` by `build_client`.
const DOWNLOAD_CONNECT_TIMEOUT_SECS: u64 = 30;
/// Size of the buffer each response-body read fills (256 KiB).
const DOWNLOAD_READ_BUFFER_SIZE: usize = 256 << 10;
/// Capacity of the `BufWriter` wrapping the temp file (1 MiB).
const DOWNLOAD_WRITE_BUFFER_SIZE: usize = 1 << 20;
/// Minimum number of bytes between intermediate `on_progress` callbacks (1 MiB).
const PROGRESS_REPORT_INTERVAL: u64 = 1 << 20;
16
17/// Builds the shared blocking HTTP client for downloads.
18///
19/// The client applies the caller-provided user agent plus the shared request and
20/// connect timeouts used across the download pipeline.
21pub fn build_client(user_agent: &str) -> Result<Client> {
22    Client::builder()
23        .user_agent(user_agent)
24        .timeout(Duration::from_secs(DOWNLOAD_REQUEST_TIMEOUT_SECS))
25        .connect_timeout(Duration::from_secs(DOWNLOAD_CONNECT_TIMEOUT_SECS))
26        .build()
27        .map_err(DownloadError::build_client)
28}
29
30/// Streams `url` into `temp_path`.
31///
32/// `on_start` receives the server-reported content length when available.
33/// `on_progress` receives byte deltas since the previous progress callback, which
34/// matches the existing `inc(...)`-style callers in the UI layer.
35/// `on_chunk` sees each raw chunk before it is written to disk.
36/// If the server reports a content length, the helper verifies the streamed byte
37/// count matches it before committing the temp file.
38/// The temporary file is removed automatically if streaming or finalization fails.
39pub fn download_url_to_temp_file<FStart, FProgress, FChunk>(
40    client: &Client,
41    url: &str,
42    temp_path: &Path,
43    label: impl std::fmt::Display,
44    on_start: FStart,
45    mut on_progress: FProgress,
46    mut on_chunk: FChunk,
47) -> Result<()>
48where
49    FStart: FnOnce(Option<u64>),
50    FProgress: FnMut(u64),
51    FChunk: FnMut(&[u8]) -> std::result::Result<(), BoxError>,
52{
53    let label = label.to_string();
54    let temp_file_guard = TempFileGuard::new(temp_path);
55
56    let mut response = client
57        .get(url)
58        .send()
59        .map_err(|err| DownloadError::request(&label, url, err))?
60        .error_for_status()
61        .map_err(|err| DownloadError::request_failed(&label, err))?;
62
63    let content_length = response.content_length();
64
65    let file = fs::File::create(temp_path)
66        .map_err(|err| DownloadError::create_temp_file(&label, temp_path.to_path_buf(), err))?;
67
68    if let Some(total_size) = content_length {
69        file.set_len(total_size)
70            .map_err(|err| DownloadError::preallocate(&label, err))?;
71        // `set_len` reserves the size, but the cursor still starts at byte 0.
72    }
73
74    on_start(content_length);
75
76    let mut writer = BufWriter::with_capacity(DOWNLOAD_WRITE_BUFFER_SIZE, file);
77    let mut buffer = [0u8; DOWNLOAD_READ_BUFFER_SIZE];
78    let mut downloaded: u64 = 0;
79    let mut last_reported: u64 = 0;
80
81    loop {
82        let read = response
83            .read(&mut buffer)
84            .map_err(|err| DownloadError::read(&label, err))?;
85        if read == 0 {
86            break;
87        }
88
89        let chunk = &buffer[..read];
90        on_chunk(chunk)?;
91        writer
92            .write_all(chunk)
93            .map_err(|err| DownloadError::write(&label, err))?;
94
95        downloaded += read as u64;
96        if downloaded - last_reported >= PROGRESS_REPORT_INTERVAL {
97            on_progress(downloaded - last_reported);
98            last_reported = downloaded;
99        }
100    }
101
102    if last_reported != downloaded {
103        on_progress(downloaded - last_reported);
104    }
105
106    validate_download_size(&label, content_length, downloaded)?;
107
108    let file = writer
109        .into_inner()
110        .map_err(|err| err.into_error())
111        .map_err(|err| DownloadError::finalize_buffer(&label, err))?;
112
113    // This is the durability boundary for the temp file; callers only rename it.
114    file.sync_all()
115        .map_err(|err| DownloadError::sync(&label, err))?;
116
117    temp_file_guard.commit();
118
119    Ok(())
120}
121
122/// Returns the last URL path segment or `download.bin` when the URL does not expose one.
123pub fn installer_filename(url: &str) -> String {
124    last_path_segment(url).unwrap_or_else(|| "download.bin".to_string())
125}
126
127/// Returns `true` when the URL path ends in `.zip`, ignoring query and fragment parts.
128pub fn is_zip_path(url: &str) -> bool {
129    last_path_segment(url).is_some_and(|segment| {
130        segment
131            .rsplit_once('.')
132            .is_some_and(|(_, ext)| ext.eq_ignore_ascii_case("zip"))
133    })
134}
135
136/// Returns `true` when the URL path ends in `.7z`, ignoring query and fragment parts.
137pub fn is_7z_path(url: &str) -> bool {
138    last_path_segment(url).is_some_and(|segment| {
139        segment
140            .rsplit_once('.')
141            .is_some_and(|(_, ext)| ext.eq_ignore_ascii_case("7z"))
142    })
143}
144
145fn last_path_segment(url: &str) -> Option<String> {
146    let parsed = url::Url::parse(url).ok()?;
147
148    parsed
149        .path_segments()?
150        .next_back()
151        .filter(|segment| !segment.is_empty())
152        .map(str::to_string)
153}
154
155/// Validates the streamed byte count against `Content-Length` when the server
156/// reports one.
157///
158/// The check is exact: both short downloads and extra bytes are treated as
159/// errors. If the server does not report a length, the check is skipped.
160fn validate_download_size(label: &str, expected: Option<u64>, actual: u64) -> Result<()> {
161    if let Some(expected) = expected
162        && actual != expected
163    {
164        return Err(DownloadError::size_mismatch(label, expected, actual));
165    }
166
167    Ok(())
168}
169
/// Deletes a partially written temp file unless the download is committed.
///
/// Removal is best-effort. The result of `fs::remove_file` is ignored because
/// it runs from `Drop`, where there is no caller to report to.
struct TempFileGuard<'a> {
    /// `Some(path)` while the guard is armed; `None` once committed.
    pending: Option<&'a Path>,
}

impl<'a> TempFileGuard<'a> {
    /// Arms a guard that removes `path` when dropped.
    fn new(path: &'a Path) -> Self {
        Self {
            pending: Some(path),
        }
    }

    /// Disarms the guard so the file is kept. This consumes the guard.
    fn commit(mut self) {
        self.pending = None;
    }
}

impl Drop for TempFileGuard<'_> {
    fn drop(&mut self) {
        if let Some(path) = self.pending.take() {
            let _ = fs::remove_file(path);
        }
    }
}
195
#[cfg(test)]
mod tests {
    use super::{installer_filename, is_7z_path, is_zip_path, validate_download_size};

    #[test]
    fn installer_filename_uses_last_segment() {
        assert_eq!(
            installer_filename("https://example.invalid/a/b/tool.zip"),
            "tool.zip"
        );
    }

    #[test]
    fn installer_filename_ignores_query_and_fragment() {
        assert_eq!(
            installer_filename("https://example.invalid/tool.exe?token=123#xyz"),
            "tool.exe"
        );
    }

    #[test]
    fn installer_filename_falls_back_when_last_segment_is_empty() {
        assert_eq!(
            installer_filename("https://example.invalid/downloads/"),
            "download.bin"
        );
    }

    #[test]
    fn installer_filename_falls_back_for_unparseable_url() {
        assert_eq!(installer_filename("not a url"), "download.bin");
    }

    #[test]
    fn is_zip_path_ignores_query_string() {
        assert!(is_zip_path("https://example.invalid/tool.zip?token=abc"));
        assert!(!is_zip_path("https://example.invalid/tool.exe?token=abc"));
    }

    #[test]
    fn is_zip_path_ignores_fragment() {
        assert!(!is_zip_path("https://example.invalid/tool.exe#.zip"));
    }

    #[test]
    fn is_zip_path_rejects_empty_last_segment() {
        assert!(!is_zip_path("https://example.invalid/downloads/"));
    }

    #[test]
    fn is_zip_path_rejects_segment_without_extension() {
        assert!(!is_zip_path("https://example.invalid/zip"));
    }

    #[test]
    fn is_zip_path_is_case_insensitive() {
        assert!(is_zip_path("https://example.invalid/tool.ZIP"));
        assert!(is_zip_path("https://example.invalid/tool.Zip"));
    }

    #[test]
    fn is_7z_path_ignores_query_string() {
        assert!(is_7z_path("https://example.invalid/tool.7z?token=abc"));
        assert!(!is_7z_path("https://example.invalid/tool.exe?token=abc"));
    }

    #[test]
    fn is_7z_path_rejects_empty_last_segment() {
        assert!(!is_7z_path("https://example.invalid/downloads/"));
    }

    #[test]
    fn is_7z_path_rejects_other_archives() {
        assert!(!is_7z_path("https://example.invalid/tool.zip"));
    }

    #[test]
    fn is_7z_path_is_case_insensitive() {
        assert!(is_7z_path("https://example.invalid/tool.7Z"));
        assert!(is_7z_path("https://example.invalid/tool.7z"));
    }

    #[test]
    fn validate_download_size_accepts_matching_length() {
        assert!(validate_download_size("installer", Some(42), 42).is_ok());
    }

    #[test]
    fn validate_download_size_skips_check_without_content_length() {
        assert!(validate_download_size("installer", None, 0).is_ok());
        assert!(validate_download_size("installer", None, 999).is_ok());
    }

    #[test]
    fn validate_download_size_rejects_length_mismatch() {
        let error = validate_download_size("installer", Some(42), 41)
            .expect_err("expected size mismatch error");

        assert!(
            error
                .to_string()
                .contains("installer size mismatch: expected 42, got 41")
        );
    }

    #[test]
    fn validate_download_size_rejects_extra_bytes() {
        // The check is exact: receiving more bytes than announced is an error too.
        assert!(validate_download_size("installer", Some(42), 43).is_err());
    }
}