1use md5::Md5;
8use sha1::Sha1;
9use sha2::{Digest, Sha256, Sha512};
10use std::fs::File;
11use std::io::{self, Write};
12use std::path::Path;
13use thiserror::Error;
14use winbrew_models::shared::hash::HashAlgorithm;
15
16#[derive(Debug, Error, Clone, PartialEq, Eq)]
18pub enum HashError {
19 #[error("checksum mismatch for installer: expected {expected}, got {actual}")]
20 ChecksumMismatch { expected: String, actual: String },
21
22 #[error("{algorithm} checksums are disabled by default for security")]
23 LegacyChecksumAlgorithm { algorithm: HashAlgorithm },
24}
25
26pub type Result<T> = std::result::Result<T, HashError>;
27
28#[derive(Debug)]
30pub enum Hasher {
31 Md5(Md5),
32 Sha1(Sha1),
33 Sha256(Sha256),
34 Sha512(Sha512),
35}
36
37impl Hasher {
38 pub fn new(algorithm: HashAlgorithm) -> Self {
39 match algorithm {
40 HashAlgorithm::Md5 => Self::Md5(Md5::new()),
41 HashAlgorithm::Sha1 => Self::Sha1(Sha1::new()),
42 HashAlgorithm::Sha256 => Self::Sha256(Sha256::new()),
43 HashAlgorithm::Sha512 => Self::Sha512(Sha512::new()),
44 }
45 }
46
47 pub fn update(&mut self, bytes: &[u8]) {
48 match self {
49 Self::Md5(hasher) => hasher.update(bytes),
50 Self::Sha1(hasher) => hasher.update(bytes),
51 Self::Sha256(hasher) => hasher.update(bytes),
52 Self::Sha512(hasher) => hasher.update(bytes),
53 }
54 }
55
56 pub fn finalize(self) -> Vec<u8> {
57 match self {
58 Self::Md5(hasher) => hasher.finalize().to_vec(),
59 Self::Sha1(hasher) => hasher.finalize().to_vec(),
60 Self::Sha256(hasher) => hasher.finalize().to_vec(),
61 Self::Sha512(hasher) => hasher.finalize().to_vec(),
62 }
63 }
64}
65
/// Detects which [`HashAlgorithm`] a checksum string refers to.
///
/// Thin wrapper that delegates entirely to [`HashAlgorithm::detect`]; the exact
/// detection rules (prefix and/or digest length) live there. The tests in this
/// module exercise prefixed inputs such as `sha256:<64 hex chars>`.
/// NOTE(review): presumably returns `None` for unrecognized input — confirm
/// against `HashAlgorithm::detect`.
pub fn hash_algorithm(value: &str) -> Option<HashAlgorithm> {
    HashAlgorithm::detect(value)
}
70
71pub fn verify_hash(expected_hash: &str, actual_hash: impl AsRef<[u8]>) -> Result<()> {
72 let expected_hash = normalize_hash(expected_hash);
73 if expected_hash.is_empty() {
74 return Ok(());
75 }
76
77 let bytes = actual_hash.as_ref();
78 let mut actual_hash = String::with_capacity(bytes.len() * 2);
79 const HEX_CHARS: &[u8; 16] = b"0123456789abcdef";
80
81 for &byte in bytes {
82 actual_hash.push(HEX_CHARS[(byte >> 4) as usize] as char);
83 actual_hash.push(HEX_CHARS[(byte & 0x0f) as usize] as char);
84 }
85
86 if actual_hash != expected_hash {
87 return Err(HashError::ChecksumMismatch {
88 expected: expected_hash,
89 actual: actual_hash,
90 });
91 }
92
93 Ok(())
94}
95
96pub fn hash_file(path: &Path, algorithm: HashAlgorithm) -> io::Result<Vec<u8>> {
97 let mut file = File::open(path)?;
98 let mut writer = HashWriter::new(Hasher::new(algorithm));
99
100 io::copy(&mut file, &mut writer)?;
101
102 Ok(writer.finish())
103}
104
/// Normalizes a checksum string so it can be compared with a lowercase hex digest.
///
/// The value is trimmed and ASCII-lowercased. Then a single leading algorithm
/// prefix (`md5:`, `sha1:`, `sha256:` or `sha512:`) is removed, along with
/// any whitespace that follows it. Because lowercasing happens before prefix
/// matching, `"SHA256:ABC"` and `"sha256:abc"` both normalize to `"abc"`.
///
/// Blank input, or input consisting only of a prefix, normalizes to `""`.
pub fn normalize_hash(value: &str) -> String {
    const PREFIXES: [&str; 4] = ["md5:", "sha1:", "sha256:", "sha512:"];

    // Lowercase first so the prefix match is case-insensitive like the digest.
    let lowered = value.trim().to_ascii_lowercase();
    let digest = PREFIXES
        .iter()
        .find_map(|prefix| lowered.strip_prefix(*prefix))
        .unwrap_or(lowered.as_str())
        // Tolerate "sha256: abc"; trailing whitespace was already trimmed.
        .trim_start();

    digest.to_owned()
}
114
115struct HashWriter {
116 hasher: Hasher,
117}
118
119impl HashWriter {
120 fn new(hasher: Hasher) -> Self {
121 Self { hasher }
122 }
123
124 fn finish(self) -> Vec<u8> {
125 self.hasher.finalize()
126 }
127}
128
129impl Write for HashWriter {
130 fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
131 self.hasher.update(bytes);
132 Ok(bytes.len())
133 }
134
135 fn flush(&mut self) -> io::Result<()> {
136 Ok(())
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::{
        HashAlgorithm, HashError, Hasher, hash_algorithm, hash_file, normalize_hash, verify_hash,
    };
    use sha2::{Digest, Sha256, Sha512};
    use std::fs;
    use std::io::ErrorKind;
    use tempfile::tempdir;

    #[test]
    fn normalize_hash_strips_prefix_and_whitespace() {
        assert_eq!(normalize_hash(" md5:ABC123 "), "abc123");
        assert_eq!(normalize_hash(" sha256:ABC123 "), "abc123");
        assert_eq!(normalize_hash(" sha1:ABC123 "), "abc123");
        assert_eq!(normalize_hash(" sha512:ABC123 "), "abc123");
        assert_eq!(normalize_hash(" ABC123 "), "abc123");
    }

    #[test]
    fn verify_hash_accepts_matching_hash() {
        let actual = [0x12, 0x34, 0xab, 0xcd];
        assert!(verify_hash("sha256:1234abcd", actual).is_ok());
    }

    #[test]
    fn verify_hash_ignores_hex_digit_case() {
        let actual = [0x12, 0x34, 0xab, 0xcd];
        assert!(verify_hash("sha256:1234ABCD", actual).is_ok());
    }

    #[test]
    fn verify_hash_skips_blank_expectation() {
        // A blank expected hash disables verification entirely. Pin that
        // contract so any change to it is a deliberate decision.
        assert!(verify_hash("", [0x12, 0x34]).is_ok());
        assert!(verify_hash("   ", [0x12, 0x34]).is_ok());
    }

    #[test]
    fn hash_algorithm_detects_supported_algorithms() {
        assert_eq!(
            hash_algorithm("md5:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
            Some(HashAlgorithm::Md5)
        );
        assert_eq!(
            hash_algorithm("sha1:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
            Some(HashAlgorithm::Sha1)
        );
        assert_eq!(
            hash_algorithm(
                "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
            ),
            Some(HashAlgorithm::Sha256)
        );
        assert_eq!(
            hash_algorithm(
                "sha512:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
            ),
            Some(HashAlgorithm::Sha512)
        );
    }

    #[test]
    fn verify_hash_rejects_mismatch() {
        let actual = [0x12, 0x34, 0xab, 0xcd];
        assert!(verify_hash("sha256:11111111", actual).is_err());
    }

    #[test]
    fn verify_hash_reports_normalized_values_on_mismatch() {
        let actual = [0x12, 0x34, 0xab, 0xcd];
        assert_eq!(
            verify_hash(" sha256:DEADBEEF ", actual),
            Err(HashError::ChecksumMismatch {
                expected: "deadbeef".to_string(),
                actual: "1234abcd".to_string(),
            })
        );
    }

    #[test]
    fn hasher_streams_sha512_chunks() {
        let mut hasher = Hasher::new(HashAlgorithm::Sha512);
        hasher.update(b"ab");
        hasher.update(b"c");

        assert_eq!(hasher.finalize(), Sha512::digest(b"abc").to_vec());
    }

    #[test]
    fn hasher_supports_md5_and_sha1() {
        // Standard "abc" test vectors (RFC 1321 / FIPS 180).
        let mut md5 = Hasher::new(HashAlgorithm::Md5);
        md5.update(b"abc");
        assert!(verify_hash("md5:900150983cd24fb0d6963f7d28e17f72", md5.finalize()).is_ok());

        let mut sha1 = Hasher::new(HashAlgorithm::Sha1);
        sha1.update(b"a");
        sha1.update(b"bc");
        assert!(
            verify_hash("sha1:a9993e364706816aba3e25717850c26c9cd0d89d", sha1.finalize()).is_ok()
        );
    }

    #[test]
    fn hash_file_streams_contents() {
        let temp_dir = tempdir().expect("temp dir");
        let path = temp_dir.path().join("payload.bin");

        fs::write(&path, b"abc").expect("write payload");

        let digest = hash_file(&path, HashAlgorithm::Sha256).expect("hash file");

        assert_eq!(digest, Sha256::digest(b"abc").to_vec());
    }

    #[test]
    fn hash_file_reports_missing_file() {
        let temp_dir = tempdir().expect("temp dir");
        let path = temp_dir.path().join("missing.bin");

        let err = hash_file(&path, HashAlgorithm::Sha256).expect_err("missing file must fail");

        assert_eq!(err.kind(), ErrorKind::NotFound);
    }
}