winbrew_database/
command_registry.rs

1use anyhow::{Context, Result};
2use rusqlite::{
3    Connection, Error as SqlError, ErrorCode, OptionalExtension, params, params_from_iter,
4};
5use std::collections::BTreeMap;
6use thiserror::Error;
7
8#[derive(Debug, Error, Clone, PartialEq, Eq)]
9#[error(
10    "command '{command_name}' was claimed by another install while this install was in progress"
11)]
12pub struct CommandRegistryConflictError {
13    pub command_name: String,
14}
15
16pub fn parse_command_names(raw_commands: Option<&str>) -> Result<Vec<String>> {
17    let Some(raw_commands) = raw_commands else {
18        return Ok(Vec::new());
19    };
20
21    let commands: Vec<String> = serde_json::from_str(raw_commands)
22        .with_context(|| "failed to parse exposed commands JSON")?;
23
24    Ok(normalize_command_names(commands))
25}
26
27pub fn find_command_owner(conn: &Connection, command_name: &str) -> Result<Option<String>> {
28    let mut stmt = conn.prepare(
29        "SELECT package_name
30         FROM command_registry
31         WHERE command_name = ?1",
32    )?;
33
34    stmt.query_row(params![command_name], |row| row.get::<_, String>(0))
35        .optional()
36        .context("failed to read command registry")
37}
38
39pub fn find_command_owners(
40    conn: &Connection,
41    command_names: &[String],
42) -> Result<BTreeMap<String, String>> {
43    let command_names = normalize_command_names(command_names.iter().map(|value| value.as_str()));
44
45    if command_names.is_empty() {
46        return Ok(BTreeMap::new());
47    }
48
49    let placeholders = std::iter::repeat_n("?", command_names.len())
50        .collect::<Vec<_>>()
51        .join(", ");
52    let query = format!(
53        "SELECT command_name, package_name
54         FROM command_registry
55         WHERE command_name IN ({placeholders})
56         ORDER BY command_name ASC"
57    );
58
59    let mut stmt = conn.prepare(&query)?;
60    let mut rows = stmt.query(params_from_iter(
61        command_names.iter().map(|value| value.as_str()),
62    ))?;
63
64    let mut owners = BTreeMap::new();
65    while let Some(row) = rows.next()? {
66        let command_name: String = row.get(0)?;
67        let package_name: String = row.get(1)?;
68        owners.insert(command_name, package_name);
69    }
70
71    Ok(owners)
72}
73
74pub fn get_package_command_names(
75    conn: &Connection,
76    package_name: &str,
77) -> Result<Option<Vec<String>>> {
78    let mut stmt = conn.prepare(
79        "SELECT commands_json
80         FROM package_command_lists
81         WHERE package_name = ?1",
82    )?;
83
84    let commands_json = stmt
85        .query_row(params![package_name], |row| row.get::<_, String>(0))
86        .optional()
87        .context("failed to read package command list")?;
88
89    let Some(commands_json) = commands_json else {
90        return Ok(None);
91    };
92
93    Ok(Some(parse_command_names(Some(commands_json.as_str()))?))
94}
95
96pub fn list_commands_for_package(conn: &Connection, package_name: &str) -> Result<Vec<String>> {
97    let mut stmt = conn.prepare(
98        "SELECT command_name
99         FROM command_registry
100         WHERE package_name = ?1
101         ORDER BY command_name ASC",
102    )?;
103
104    stmt.query_map(params![package_name], |row| row.get::<_, String>(0))?
105        .collect::<std::result::Result<Vec<_>, _>>()
106        .context("failed to read package commands")
107}
108
109pub fn sync_package_commands(
110    conn: &Connection,
111    package_name: &str,
112    raw_commands: Option<&str>,
113) -> Result<()> {
114    let commands = parse_command_names(raw_commands)?;
115    let commands_json =
116        serde_json::to_string(&commands).context("failed to serialize package command list")?;
117
118    conn.execute(
119        "INSERT INTO package_command_lists (package_name, commands_json)
120         VALUES (?1, ?2)
121         ON CONFLICT(package_name) DO UPDATE SET
122             commands_json = excluded.commands_json",
123        params![package_name, commands_json],
124    )
125    .context("failed to upsert package command list")?;
126
127    conn.execute(
128        "DELETE FROM command_registry WHERE package_name = ?1",
129        params![package_name],
130    )
131    .context("failed to clear command registry rows")?;
132
133    let mut stmt = conn.prepare(
134        "INSERT INTO command_registry (command_name, package_name)
135         VALUES (?1, ?2)",
136    )?;
137
138    for command_name in commands {
139        match stmt.execute(params![command_name.as_str(), package_name]) {
140            Ok(_) => {}
141            Err(err) if is_unique_conflict(&err) => {
142                return Err(CommandRegistryConflictError { command_name }.into());
143            }
144            Err(err) => return Err(err).context("failed to update command registry"),
145        }
146    }
147
148    Ok(())
149}
150
/// Cleans up a list of command names.
///
/// Each name is trimmed, and names that are empty after trimming are discarded. Names equal
/// under ASCII case-insensitive comparison are collapsed, keeping the first spelling seen.
/// The survivors are returned sorted by their ASCII-lowercased form.
fn normalize_command_names<I, S>(commands: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    commands
        .into_iter()
        .fold(BTreeMap::new(), |mut seen: BTreeMap<String, String>, raw| {
            let name = raw.as_ref().trim();
            if !name.is_empty() {
                // Keyed by lowercase so later case variants do not replace the first spelling.
                seen.entry(name.to_ascii_lowercase())
                    .or_insert_with(|| name.to_string());
            }
            seen
        })
        .into_values()
        .collect()
}
171
172fn is_unique_conflict(err: &SqlError) -> bool {
173    matches!(err, SqlError::SqliteFailure(error, _) if error.code == ErrorCode::ConstraintViolation)
174}