Skip to main content

perspt_agent/
agent.rs

1//! Agent Trait and Implementations
2//!
3//! Defines the interface for all agent implementations and provides
4//! LLM-integrated implementations for Architect, Actuator, and Verifier roles.
5
6use crate::types::{AgentContext, AgentMessage, ModelTier, SRBNNode};
7use anyhow::Result;
8use async_trait::async_trait;
9use perspt_core::llm_provider::GenAIProvider;
10use perspt_core::types::{PromptEvidence, PromptIntent};
11use std::fs;
12use std::path::Path;
13use std::sync::Arc;
14
15/// The Agent trait defines the interface for SRBN agents.
16///
17/// Each agent role (Architect, Actuator, Verifier, Speculator) implements
18/// this trait to provide specialized behavior.
19#[async_trait]
20pub trait Agent: Send + Sync {
21    /// Process a task and return a message
22    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage>;
23
24    /// Get the agent's display name
25    fn name(&self) -> &str;
26
27    /// Check if this agent can handle the given node
28    fn can_handle(&self, node: &SRBNNode) -> bool;
29
30    /// Get the model name used by this agent (for logging)
31    fn model(&self) -> &str;
32
33    /// Build the prompt for this agent (for logging)
34    fn build_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String;
35}
36
37/// Architect agent - handles planning and DAG construction
38pub struct ArchitectAgent {
39    model: String,
40    provider: Arc<GenAIProvider>,
41}
42
43impl ArchitectAgent {
44    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
45        Self {
46            model: model.unwrap_or_else(|| ModelTier::Architect.default_model().to_string()),
47            provider,
48        }
49    }
50
51    pub fn build_planning_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
52        let project_context = format!(
53            "Context Files: {:?}\nOutput Targets: {:?}",
54            node.context_files, node.output_targets
55        );
56        let ev = PromptEvidence {
57            user_goal: Some(node.goal.clone()),
58            project_summary: Some(project_context),
59            working_dir: Some(ctx.working_dir.display().to_string()),
60            active_plugins: ctx.active_plugins.clone(),
61            ..Default::default()
62        };
63        crate::prompt_compiler::compile(PromptIntent::ArchitectExisting, &ev).text
64    }
65}
66
67#[async_trait]
68impl Agent for ArchitectAgent {
69    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
70        log::info!(
71            "[Architect] Processing node: {} with model {}",
72            node.node_id,
73            self.model
74        );
75
76        let prompt = self.build_planning_prompt(node, ctx);
77
78        let response = self
79            .provider
80            .generate_response_simple(&self.model, &prompt)
81            .await?
82            .text;
83
84        Ok(AgentMessage::new(ModelTier::Architect, response))
85    }
86
87    fn name(&self) -> &str {
88        "Architect"
89    }
90
91    fn can_handle(&self, node: &SRBNNode) -> bool {
92        matches!(node.tier, ModelTier::Architect)
93    }
94
95    fn model(&self) -> &str {
96        &self.model
97    }
98
99    fn build_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
100        self.build_planning_prompt(node, ctx)
101    }
102}
103
104/// Actuator agent - handles code generation
105pub struct ActuatorAgent {
106    model: String,
107    provider: Arc<GenAIProvider>,
108}
109
110impl ActuatorAgent {
111    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
112        Self {
113            model: model.unwrap_or_else(|| ModelTier::Actuator.default_model().to_string()),
114            provider,
115        }
116    }
117
118    pub fn build_coding_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
119        let contract = &node.contract;
120        let allowed_output_paths: Vec<String> = node
121            .output_targets
122            .iter()
123            .map(|path| path.to_string_lossy().to_string())
124            .collect();
125        let workspace_import_hints = Self::workspace_import_hints(&ctx.working_dir);
126
127        // Determine target file from output_targets or generate default
128        let _target_file = node
129            .output_targets
130            .first()
131            .map(|p| p.to_string_lossy().to_string())
132            .unwrap_or_else(|| "main.py".to_string());
133
134        // PSP-5: Determine output format based on execution mode and plugin
135        let is_project_mode = ctx.execution_mode == perspt_core::types::ExecutionMode::Project;
136        let has_multiple_outputs = node.output_targets.len() > 1;
137
138        let ev = PromptEvidence {
139            node_goal: Some(node.goal.clone()),
140            output_files: allowed_output_paths.clone(),
141            context_files: node
142                .context_files
143                .iter()
144                .map(|p| p.to_string_lossy().to_string())
145                .collect(),
146            interface_signature: Some(contract.interface_signature.clone()),
147            invariants: Some(format!("{:?}", contract.invariants)),
148            forbidden_patterns: Some(format!("{:?}", contract.forbidden_patterns)),
149            working_dir: Some(format!("{:?}", ctx.working_dir)),
150            workspace_import_hints: Some(format!("{:?}", workspace_import_hints)),
151            ..Default::default()
152        };
153        let intent = if is_project_mode || has_multiple_outputs {
154            PromptIntent::ActuatorMultiOutput
155        } else {
156            PromptIntent::ActuatorSingleOutput
157        };
158        crate::prompt_compiler::compile(intent, &ev).text
159    }
160
161    fn workspace_import_hints(working_dir: &Path) -> Vec<String> {
162        let mut hints = Vec::new();
163
164        // Rust: detect workspace members OR single-crate name
165        let rust_hints = Self::detect_rust_workspace_crates(working_dir);
166        if !rust_hints.is_empty() {
167            hints.extend(rust_hints);
168        }
169
170        if let Some(package_name) = Self::detect_python_package_name(working_dir) {
171            hints.push(format!(
172                "Python package import root: {}. Tests and entry points must import `{}` and never `src.{}`.",
173                package_name, package_name, package_name
174            ));
175        }
176
177        hints
178    }
179
180    /// Detect Rust crate names for import hints.
181    ///
182    /// Handles both:
183    /// - Single-crate projects: `[package]` with a `name`
184    /// - Workspace projects: `[workspace]` with `members`, enumerating each member's crate name
185    fn detect_rust_workspace_crates(working_dir: &Path) -> Vec<String> {
186        let cargo_toml = match fs::read_to_string(working_dir.join("Cargo.toml")) {
187            Ok(content) => content,
188            Err(_) => return Vec::new(),
189        };
190
191        // Check if this is a workspace manifest
192        let mut in_workspace = false;
193        let mut in_package = false;
194        let mut members: Vec<String> = Vec::new();
195        let mut single_crate_name: Option<String> = None;
196        let mut is_workspace = false;
197
198        for raw_line in cargo_toml.lines() {
199            let line = raw_line.trim();
200            if line.starts_with('[') {
201                in_workspace = line == "[workspace]";
202                in_package = line == "[package]";
203                if in_workspace {
204                    is_workspace = true;
205                }
206                continue;
207            }
208
209            // Parse [package] name for single-crate projects
210            if in_package && line.starts_with("name") {
211                if let Some((_, value)) = line.split_once('=') {
212                    single_crate_name = Some(value.trim().trim_matches('"').to_string());
213                }
214            }
215
216            // Parse [workspace] members
217            if in_workspace && line.starts_with("members") {
218                if let Some((_, value)) = line.split_once('=') {
219                    let raw = value.trim();
220                    // Parse inline array: members = ["crates/foo", "crates/bar"]
221                    if raw.starts_with('[') {
222                        let inner = raw.trim_start_matches('[').trim_end_matches(']');
223                        for item in inner.split(',') {
224                            let member = item.trim().trim_matches('"').trim_matches('\'');
225                            if !member.is_empty() {
226                                members.push(member.to_string());
227                            }
228                        }
229                    }
230                }
231            }
232        }
233
234        if is_workspace && !members.is_empty() {
235            // Enumerate each member crate's name
236            let mut hints = Vec::new();
237            let mut crate_names = Vec::new();
238
239            for member in &members {
240                let member_cargo = working_dir.join(member).join("Cargo.toml");
241                if let Ok(content) = fs::read_to_string(&member_cargo) {
242                    let mut in_pkg = false;
243                    for raw_line in content.lines() {
244                        let line = raw_line.trim();
245                        if line.starts_with('[') {
246                            in_pkg = line == "[package]";
247                            continue;
248                        }
249                        if in_pkg && line.starts_with("name") {
250                            if let Some((_, value)) = line.split_once('=') {
251                                let name = value.trim().trim_matches('"').to_string();
252                                crate_names.push(name);
253                            }
254                            break;
255                        }
256                    }
257                }
258            }
259
260            if !crate_names.is_empty() {
261                hints.push(format!(
262                    "Rust workspace with {} crate(s): {}. \
263                     Cross-crate imports use `use <crate_name>::...;`. \
264                     Add dependencies between workspace crates via `<name>.workspace = true` \
265                     or `<name> = {{ path = \"../other\" }}`.",
266                    crate_names.len(),
267                    crate_names.join(", ")
268                ));
269            }
270
271            hints
272        } else if let Some(name) = single_crate_name {
273            vec![format!(
274                "Rust crate name: {}. Integration tests and external modules must import via `{}`.",
275                name, name
276            )]
277        } else {
278            Vec::new()
279        }
280    }
281
282    fn detect_python_package_name(working_dir: &Path) -> Option<String> {
283        let src_dir = working_dir.join("src");
284        if let Ok(entries) = fs::read_dir(&src_dir) {
285            for entry in entries.flatten() {
286                if entry.file_type().ok()?.is_dir() {
287                    let name = entry.file_name().to_string_lossy().to_string();
288                    if !name.starts_with('.') {
289                        return Some(name);
290                    }
291                }
292            }
293        }
294
295        let pyproject = fs::read_to_string(working_dir.join("pyproject.toml")).ok()?;
296        let mut in_project = false;
297        for raw_line in pyproject.lines() {
298            let line = raw_line.trim();
299            if line.starts_with('[') {
300                in_project = line == "[project]";
301                continue;
302            }
303
304            if in_project && line.starts_with("name") {
305                let (_, value) = line.split_once('=')?;
306                return Some(value.trim().trim_matches('"').replace('-', "_"));
307            }
308        }
309
310        None
311    }
312}
313
314#[async_trait]
315impl Agent for ActuatorAgent {
316    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
317        log::info!(
318            "[Actuator] Processing node: {} with model {}",
319            node.node_id,
320            self.model
321        );
322
323        let prompt = self.build_coding_prompt(node, ctx);
324
325        let response = self
326            .provider
327            .generate_response_simple(&self.model, &prompt)
328            .await?
329            .text;
330
331        Ok(AgentMessage::new(ModelTier::Actuator, response))
332    }
333
334    fn name(&self) -> &str {
335        "Actuator"
336    }
337
338    fn can_handle(&self, node: &SRBNNode) -> bool {
339        matches!(node.tier, ModelTier::Actuator)
340    }
341
342    fn model(&self) -> &str {
343        &self.model
344    }
345
346    fn build_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
347        self.build_coding_prompt(node, ctx)
348    }
349}
350
351/// Verifier agent - handles stability verification and contract checking
352pub struct VerifierAgent {
353    model: String,
354    provider: Arc<GenAIProvider>,
355}
356
357impl VerifierAgent {
358    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
359        Self {
360            model: model.unwrap_or_else(|| ModelTier::Verifier.default_model().to_string()),
361            provider,
362        }
363    }
364
365    pub fn build_verification_prompt(&self, node: &SRBNNode, implementation: &str) -> String {
366        let contract = &node.contract;
367        let ev = PromptEvidence {
368            interface_signature: Some(contract.interface_signature.clone()),
369            invariants: Some(format!("{:?}", contract.invariants)),
370            forbidden_patterns: Some(format!("{:?}", contract.forbidden_patterns)),
371            weighted_tests: Some(format!("{:?}", contract.weighted_tests)),
372            existing_file_contents: vec![(String::new(), implementation.to_string())],
373            ..Default::default()
374        };
375        crate::prompt_compiler::compile(PromptIntent::VerifierAnalysis, &ev).text
376    }
377}
378
379#[async_trait]
380impl Agent for VerifierAgent {
381    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
382        log::info!(
383            "[Verifier] Processing node: {} with model {}",
384            node.node_id,
385            self.model
386        );
387
388        // In a real implementation, we would get the actual implementation from the context
389        let implementation = ctx
390            .history
391            .last()
392            .map(|m| m.content.as_str())
393            .unwrap_or("No implementation provided");
394
395        let prompt = self.build_verification_prompt(node, implementation);
396
397        let response = self
398            .provider
399            .generate_response_simple(&self.model, &prompt)
400            .await?
401            .text;
402
403        Ok(AgentMessage::new(ModelTier::Verifier, response))
404    }
405
406    fn name(&self) -> &str {
407        "Verifier"
408    }
409
410    fn can_handle(&self, node: &SRBNNode) -> bool {
411        matches!(node.tier, ModelTier::Verifier)
412    }
413
414    fn model(&self) -> &str {
415        &self.model
416    }
417
418    fn build_prompt(&self, node: &SRBNNode, _ctx: &AgentContext) -> String {
419        // Verifier needs implementation context, use a placeholder
420        self.build_verification_prompt(node, "<implementation>")
421    }
422}
423
424/// Speculator agent - handles fast lookahead for exploration
425pub struct SpeculatorAgent {
426    model: String,
427    provider: Arc<GenAIProvider>,
428}
429
430impl SpeculatorAgent {
431    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
432        Self {
433            model: model.unwrap_or_else(|| ModelTier::Speculator.default_model().to_string()),
434            provider,
435        }
436    }
437}
438
439#[async_trait]
440impl Agent for SpeculatorAgent {
441    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
442        log::info!(
443            "[Speculator] Processing node: {} with model {}",
444            node.node_id,
445            self.model
446        );
447
448        let prompt = self.build_prompt(node, ctx);
449
450        let response = self
451            .provider
452            .generate_response_simple(&self.model, &prompt)
453            .await?
454            .text;
455
456        Ok(AgentMessage::new(ModelTier::Speculator, response))
457    }
458
459    fn name(&self) -> &str {
460        "Speculator"
461    }
462
463    fn can_handle(&self, node: &SRBNNode) -> bool {
464        matches!(node.tier, ModelTier::Speculator)
465    }
466
467    fn model(&self) -> &str {
468        &self.model
469    }
470
471    fn build_prompt(&self, node: &SRBNNode, _ctx: &AgentContext) -> String {
472        let ev = PromptEvidence {
473            node_goal: Some(node.goal.clone()),
474            ..Default::default()
475        };
476        crate::prompt_compiler::compile(PromptIntent::SpeculatorBasic, &ev).text
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds the Actuator coding prompt for a node with a single output
    /// target, using `root` as the working directory.
    fn coding_prompt_for(root: &Path, output_target: &str) -> String {
        let provider = Arc::new(GenAIProvider::new().unwrap());
        let agent = ActuatorAgent::new(provider, Some("test-model".into()));
        let mut node = SRBNNode::new("n1".into(), "goal".into(), ModelTier::Actuator);
        node.output_targets.push(output_target.into());
        let ctx = AgentContext {
            working_dir: root.to_path_buf(),
            ..Default::default()
        };
        agent.build_coding_prompt(&node, &ctx)
    }

    #[test]
    fn build_coding_prompt_includes_rust_crate_hint() {
        let dir = tempdir().unwrap();
        let manifest = "[package]\nname = \"validator_lib\"\nversion = \"0.1.0\"\n";
        fs::write(dir.path().join("Cargo.toml"), manifest).unwrap();

        let prompt = coding_prompt_for(dir.path(), "tests/integration.rs");
        assert!(prompt.contains("Rust crate name: validator_lib"), "{prompt}");
    }

    #[test]
    fn build_coding_prompt_includes_python_package_hint() {
        let dir = tempdir().unwrap();
        fs::create_dir_all(dir.path().join("src/psp5_python_verify")).unwrap();
        let pyproject = "[project]\nname = \"psp5-python-verify\"\nversion = \"0.1.0\"\n";
        fs::write(dir.path().join("pyproject.toml"), pyproject).unwrap();

        let prompt = coding_prompt_for(dir.path(), "tests/test_main.py");
        assert!(
            prompt.contains("Python package import root: psp5_python_verify"),
            "{prompt}"
        );
    }
}