1use anyhow::{Context, Result};
6use starlark::environment::{FrozenModule, Globals, GlobalsBuilder, Module};
7use starlark::eval::Evaluator;
8use starlark::starlark_module;
9use starlark::syntax::{AstModule, Dialect};
10use starlark::values::none::NoneType;
11use std::path::{Path, PathBuf};
12
/// Outcome of checking a command against the policy rules.
///
/// The `String` payload of [`PolicyDecision::Prompt`] and [`PolicyDecision::Deny`] carries a
/// human-readable explanation of why the command was flagged.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PolicyDecision {
    /// The command may run without asking.
    Allow,
    /// The command needs explicit user confirmation; the payload explains why.
    Prompt(String),
    /// The command must not run; the payload explains why.
    Deny(String),
}
23
24pub struct PolicyEngine {
26 policies: Vec<FrozenModule>,
28 policy_dir: PathBuf,
30}
31
32impl PolicyEngine {
33 pub fn new() -> Result<Self> {
35 let policy_dir = Self::default_policy_dir();
36 let mut engine = Self {
37 policies: Vec::new(),
38 policy_dir: policy_dir.clone(),
39 };
40
41 if policy_dir.exists() {
43 engine.load_policies()?;
44 } else {
45 log::info!(
46 "Policy directory {:?} does not exist, using defaults",
47 policy_dir
48 );
49 }
50
51 Ok(engine)
52 }
53
54 pub fn default_policy_dir() -> PathBuf {
56 perspt_core::paths::resolve_policy_dir()
58 .or_else(perspt_core::paths::policy_dir)
59 .unwrap_or_else(|| PathBuf::from(".").join(".perspt").join("rules"))
60 }
61
62 pub fn load_policies(&mut self) -> Result<()> {
64 if !self.policy_dir.exists() {
65 return Ok(());
66 }
67
68 for entry in std::fs::read_dir(&self.policy_dir)? {
69 let entry = entry?;
70 let path = entry.path();
71
72 if path.extension().is_some_and(|ext| ext == "star") {
73 match self.load_policy_file(&path) {
74 Ok(module) => {
75 self.policies.push(module);
76 log::info!("Loaded policy: {:?}", path);
77 }
78 Err(e) => {
79 log::warn!("Failed to load policy {:?}: {}", path, e);
80 }
81 }
82 }
83 }
84
85 log::info!("Loaded {} policies", self.policies.len());
86 Ok(())
87 }
88
89 fn load_policy_file(&self, path: &Path) -> Result<FrozenModule> {
91 let content = std::fs::read_to_string(path)
92 .context(format!("Failed to read policy file: {:?}", path))?;
93
94 let ast = AstModule::parse(path.to_string_lossy().as_ref(), content, &Dialect::Standard)
95 .map_err(|e| anyhow::anyhow!("Parse error: {}", e))?;
96
97 let globals = Self::create_globals();
98 let module = Module::new();
99
100 {
101 let mut eval = Evaluator::new(&module);
102 eval.eval_module(ast, &globals)
103 .map_err(|e| anyhow::anyhow!("Eval error: {}", e))?;
104 }
105
106 Ok(module.freeze()?)
107 }
108
109 fn create_globals() -> Globals {
111 #[starlark_module]
112 fn policy_builtins(builder: &mut GlobalsBuilder) {
113 fn matches_pattern(command: &str, pattern: &str) -> anyhow::Result<bool> {
115 Ok(command.contains(pattern))
116 }
117
118 fn log_policy(message: &str) -> anyhow::Result<NoneType> {
120 log::info!("[Policy] {}", message);
121 Ok(NoneType)
122 }
123 }
124
125 GlobalsBuilder::standard().with(policy_builtins).build()
126 }
127
128 pub fn evaluate(&self, command: &str) -> PolicyDecision {
130 if self.policies.is_empty() {
132 return self.default_policy(command);
133 }
134
135 self.default_policy(command)
138 }
139
140 fn default_policy(&self, command: &str) -> PolicyDecision {
142 let dangerous_patterns = ["rm -rf", "sudo", "chmod 777", "> /dev/", "mkfs", "dd if="];
144
145 for pattern in &dangerous_patterns {
146 if command.contains(pattern) {
147 return PolicyDecision::Deny(format!(
148 "Command contains dangerous pattern: {}",
149 pattern
150 ));
151 }
152 }
153
154 let network_patterns = ["curl", "wget", "nc ", "ssh ", "scp "];
156 for pattern in &network_patterns {
157 if command.contains(pattern) {
158 return PolicyDecision::Prompt(format!(
159 "Command requires network access: {}",
160 command
161 ));
162 }
163 }
164
165 if command.contains("git push") || command.contains("git force") {
167 return PolicyDecision::Prompt("Git push operation requires confirmation".to_string());
168 }
169
170 PolicyDecision::Allow
171 }
172
173 pub fn is_safe(&self, command: &str) -> bool {
175 matches!(self.evaluate(command), PolicyDecision::Allow)
176 }
177}
178
179impl Default for PolicyEngine {
180 fn default() -> Self {
181 Self::new().unwrap_or_else(|_| Self {
182 policies: Vec::new(),
183 policy_dir: PathBuf::from("."),
184 })
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_policy_allows_safe_commands() {
        let engine = PolicyEngine::default();
        assert!(matches!(
            engine.evaluate("cargo build"),
            PolicyDecision::Allow
        ));
        assert!(matches!(engine.evaluate("ls -la"), PolicyDecision::Allow));
    }

    #[test]
    fn test_default_policy_denies_dangerous() {
        let engine = PolicyEngine::default();
        assert!(matches!(
            engine.evaluate("rm -rf /"),
            PolicyDecision::Deny(_)
        ));
        assert!(matches!(
            engine.evaluate("sudo rm file"),
            PolicyDecision::Deny(_)
        ));
    }

    #[test]
    fn test_default_policy_prompts_network() {
        let engine = PolicyEngine::default();
        assert!(matches!(
            engine.evaluate("curl https://example.com"),
            PolicyDecision::Prompt(_)
        ));
    }

    #[test]
    fn test_default_policy_prompts_git_push() {
        let engine = PolicyEngine::default();
        assert!(matches!(
            engine.evaluate("git push origin main"),
            PolicyDecision::Prompt(_)
        ));
    }

    #[test]
    fn test_deny_reports_matched_pattern() {
        let engine = PolicyEngine::default();
        assert_eq!(
            engine.evaluate("sudo ls"),
            PolicyDecision::Deny("Command contains dangerous pattern: sudo".to_string())
        );
    }

    #[test]
    fn test_is_safe_only_for_allow() {
        let engine = PolicyEngine::default();
        assert!(engine.is_safe("cargo test"));
        // Deny and Prompt both count as unsafe.
        assert!(!engine.is_safe("rm -rf target"));
        assert!(!engine.is_safe("wget https://example.com"));
    }

    #[test]
    fn test_load_policies_missing_dir_is_noop() {
        let mut engine = PolicyEngine {
            policies: Vec::new(),
            policy_dir: std::path::PathBuf::from("/nonexistent/perspt/policy/dir/for/tests"),
        };
        assert!(engine.load_policies().is_ok());
        assert!(engine.policies.is_empty());
    }
}