1use serde::de::DeserializeOwned;
17
/// Model vendor families recognised by [`ProviderFamily::from_model_name`].
///
/// The classification is a name-based heuristic; [`ProviderFamily::Unknown`]
/// is returned when no rule matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderFamily {
    /// OpenAI models (GPT family and the o-series).
    OpenAI,
    /// Anthropic Claude models.
    Anthropic,
    /// Google Gemini models.
    Gemini,
    /// Models routed to Groq (including `llama*` / `mixtral*` names).
    Groq,
    /// Cohere Command models.
    Cohere,
    /// xAI Grok models.
    XAI,
    /// DeepSeek models.
    DeepSeek,
    /// Models served through Ollama.
    Ollama,
    /// No classification rule matched.
    Unknown,
}

impl ProviderFamily {
    /// Guesses the provider family from a model identifier.
    ///
    /// Matching is case-insensitive. Rules are checked in this order and the
    /// first match wins:
    ///
    /// * OpenAI: starts with `gpt-` or `chatgpt-`, is an `o1`/`o3`/`o4` model
    ///   (bare like `o3`, or suffixed like `o4-mini`), or contains `openai`;
    /// * Anthropic: starts with `claude` or contains `anthropic`;
    /// * Gemini: starts with `gemini` or contains `google`;
    /// * Groq: contains `groq`, or starts with `llama` / `mixtral`;
    /// * Cohere: starts with `command` or contains `cohere`;
    /// * xAI: starts with `grok` or contains `xai`;
    /// * DeepSeek: starts with `deepseek`;
    /// * Ollama: contains `ollama`;
    /// * otherwise [`ProviderFamily::Unknown`].
    pub fn from_model_name(model: &str) -> Self {
        // o-series reasoning models: `o1`, `o3`, `o4`, either bare or followed
        // by a `-variant` suffix. Bare names must match too, not only `o1-...`.
        fn is_openai_o_series(name: &str) -> bool {
            ["o1", "o3", "o4"].iter().any(|&series| {
                name.strip_prefix(series)
                    .map_or(false, |rest| rest.is_empty() || rest.starts_with('-'))
            })
        }

        let lower = model.to_lowercase();
        let name = lower.as_str();

        if name.starts_with("gpt-")
            || name.starts_with("chatgpt-")
            || is_openai_o_series(name)
            || name.contains("openai")
        {
            ProviderFamily::OpenAI
        } else if name.starts_with("claude") || name.contains("anthropic") {
            ProviderFamily::Anthropic
        } else if name.starts_with("gemini") || name.contains("google") {
            ProviderFamily::Gemini
        } else if name.contains("groq") || name.starts_with("llama") || name.starts_with("mixtral")
        {
            ProviderFamily::Groq
        } else if name.starts_with("command") || name.contains("cohere") {
            ProviderFamily::Cohere
        } else if name.starts_with("grok") || name.contains("xai") {
            ProviderFamily::XAI
        } else if name.starts_with("deepseek") {
            ProviderFamily::DeepSeek
        } else if name.contains("ollama") {
            ProviderFamily::Ollama
        } else {
            ProviderFamily::Unknown
        }
    }
}

impl std::fmt::Display for ProviderFamily {
    /// Writes the lowercase provider identifier (e.g. `openai`, `xai`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ProviderFamily::OpenAI => "openai",
            ProviderFamily::Anthropic => "anthropic",
            ProviderFamily::Gemini => "gemini",
            ProviderFamily::Groq => "groq",
            ProviderFamily::Cohere => "cohere",
            ProviderFamily::XAI => "xai",
            ProviderFamily::DeepSeek => "deepseek",
            ProviderFamily::Ollama => "ollama",
            ProviderFamily::Unknown => "unknown",
        })
    }
}
87
/// Which strategy [`extract_json`] used to locate the JSON payload.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ExtractionMethod {
    /// Body of a fenced block tagged `json`.
    FencedJson,
    /// Body of the first fenced block (any or no language tag), accepted
    /// because it starts with `{` or `[`.
    GenericFence,
    /// The whole trimmed response, which starts with `{` or `[`.
    DirectJson,
    /// A balanced `{...}` object found inside surrounding prose.
    EmbeddedJson,
}

impl std::fmt::Display for ExtractionMethod {
    /// Writes the snake_case name of the strategy (e.g. `fenced_json`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::FencedJson => "fenced_json",
            Self::GenericFence => "generic_fence",
            Self::DirectJson => "direct_json",
            Self::EmbeddedJson => "embedded_json",
        };
        f.write_str(label)
    }
}

/// Successful result of [`extract_json`].
#[derive(Clone, Debug)]
pub struct NormalizedOutput {
    /// Text expected to hold JSON. It is sliced out of the response, not
    /// parsed, so it may still be invalid JSON.
    pub json_body: String,
    /// Strategy that produced `json_body`.
    pub method: ExtractionMethod,
}
120
/// Error produced when a model response cannot be turned into JSON, or when
/// the extracted JSON does not deserialize into the requested type.
#[derive(Clone, Debug)]
pub struct NormalizationError {
    /// Human-readable explanation of the failure.
    pub reason: String,
    /// Size in bytes of the input being normalized, for diagnostics.
    pub input_len: usize,
}

impl std::fmt::Display for NormalizationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { reason, input_len } = self;
        write!(
            f,
            "normalization failed (input {} bytes): {}",
            input_len, reason
        )
    }
}

// No underlying cause is stored, so the default `source()` (None) applies.
impl std::error::Error for NormalizationError {}
141
142pub fn extract_json(raw: &str) -> Result<NormalizedOutput, NormalizationError> {
153 let trimmed = raw.trim();
154
155 if trimmed.is_empty() {
156 return Err(NormalizationError {
157 reason: "empty input".to_string(),
158 input_len: 0,
159 });
160 }
161
162 if let Some(body) = extract_fenced_json(trimmed) {
164 return Ok(NormalizedOutput {
165 json_body: body,
166 method: ExtractionMethod::FencedJson,
167 });
168 }
169
170 if let Some(body) = extract_generic_fence_json(trimmed) {
172 return Ok(NormalizedOutput {
173 json_body: body,
174 method: ExtractionMethod::GenericFence,
175 });
176 }
177
178 if trimmed.starts_with('{') || trimmed.starts_with('[') {
180 return Ok(NormalizedOutput {
181 json_body: trimmed.to_string(),
182 method: ExtractionMethod::DirectJson,
183 });
184 }
185
186 if let Some(body) = extract_embedded_json(trimmed) {
188 return Ok(NormalizedOutput {
189 json_body: body,
190 method: ExtractionMethod::EmbeddedJson,
191 });
192 }
193
194 Err(NormalizationError {
195 reason: "no JSON object or array found in response".to_string(),
196 input_len: raw.len(),
197 })
198}
199
200pub fn extract_and_deserialize<T: DeserializeOwned>(
202 raw: &str,
203) -> Result<(T, ExtractionMethod), NormalizationError> {
204 let output = extract_json(raw)?;
205 match serde_json::from_str::<T>(&output.json_body) {
206 Ok(value) => Ok((value, output.method)),
207 Err(e) => Err(NormalizationError {
208 reason: format!(
209 "JSON extracted via {} but deserialization failed: {}",
210 output.method, e
211 ),
212 input_len: raw.len(),
213 }),
214 }
215}
216
/// Extracts the body of the first fenced block tagged `json` in `input`.
///
/// The body starts right after the `json` fence marker (one newline directly
/// after it is skipped) and runs to the closing fence chosen by
/// `find_closing_fence`. Returns `None` when there is no such marker, no
/// closing fence, or the body is blank.
fn extract_fenced_json(input: &str) -> Option<String> {
    const MARKER: &str = "```json";
    let start_idx = input.find(MARKER)?;
    let remaining = &input[start_idx + MARKER.len()..];
    // Skip the newline that normally ends the opening fence line.
    let remaining = remaining.strip_prefix('\n').unwrap_or(remaining);

    let end_offset = find_closing_fence(remaining)?;
    let body = remaining[..end_offset].trim();
    if body.is_empty() {
        return None;
    }
    Some(body.to_string())
}

/// Extracts the body of the first fenced block in `input`, whatever its
/// language tag, provided the body starts with `{` or `[`.
///
/// The rest of the opening fence line (the info string / language hint) is
/// skipped. Returns `None` when there is no fence, no closing fence, or the
/// body does not look like JSON.
fn extract_generic_fence_json(input: &str) -> Option<String> {
    const FENCE: &str = "```";
    let start_idx = input.find(FENCE)?;
    let remaining = &input[start_idx + FENCE.len()..];
    // Drop the info-string line; with no newline, the body starts right away.
    let body_start = remaining.find('\n').map_or(0, |n| n + 1);
    let remaining = &remaining[body_start..];

    let end_offset = find_closing_fence(remaining)?;
    let body = remaining[..end_offset].trim();
    if body.starts_with('{') || body.starts_with('[') {
        Some(body.to_string())
    } else {
        None
    }
}

/// Returns the byte offset in `body` of the backtick fence closing the block.
///
/// A run of three backticks counts as the closing fence when it is the first
/// non-whitespace text on its line, or when nothing but backticks and
/// whitespace follows it on that line. Valid JSON cannot hold a raw newline
/// inside a string, so such a run never lies inside a JSON string value.
/// This lets string values that embed Markdown code fences (for example,
/// file contents) survive intact. When no run qualifies, the first
/// occurrence anywhere is used, so single-line responses still work.
fn find_closing_fence(body: &str) -> Option<usize> {
    const FENCE: &str = "```";
    let mut line_offset = 0;
    for line in body.split_inclusive('\n') {
        // `trim_end` drops the line terminator (including `\r\n`).
        let content = line.trim_end();
        for (pos, _) in content.match_indices(FENCE) {
            let at_line_start = content[..pos].trim().is_empty();
            let ends_line = content[pos..].trim_start_matches('`').trim().is_empty();
            if at_line_start || ends_line {
                return Some(line_offset + pos);
            }
        }
        line_offset += line.len();
    }
    body.find(FENCE)
}
260
/// Returns the first balanced `{ ... }` object embedded in `input`.
///
/// Scanning starts at the first `{`. Braces inside double-quoted strings are
/// ignored, and inside a string a backslash escapes the following character.
/// Returns `None` if there is no `{`, or if the object starting at the first
/// `{` never balances (later `{` are not tried). Only objects are
/// recognised, not top-level arrays, and the span is not validated as JSON.
fn extract_embedded_json(input: &str) -> Option<String> {
    let open = input.find('{')?;
    let (mut depth, mut in_string, mut escaped) = (0i32, false, false);

    // Every structural character is ASCII, so scanning bytes sees the same
    // braces/quotes/backslashes as scanning chars; UTF-8 continuation bytes
    // (>= 0x80) never match and fall through to the no-op arm.
    let close = input.as_bytes()[open..]
        .iter()
        .enumerate()
        .find_map(|(offset, &byte)| {
            if escaped {
                escaped = false;
                return None;
            }
            match byte {
                b'\\' if in_string => escaped = true,
                b'"' => in_string = !in_string,
                b'{' if !in_string => depth += 1,
                b'}' if !in_string => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(open + offset);
                    }
                }
                _ => {}
            }
            None
        })?;

    Some(input[open..=close].to_owned())
}
301
/// One file section extracted from a Markdown-style model response by
/// [`extract_file_markers`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FileMarker {
    /// Path from the `File:` / `Diff:` heading, or `None` for fenced content
    /// that appeared before any heading.
    pub path: Option<String>,
    /// Section body with surrounding whitespace trimmed.
    pub content: String,
    /// `true` when the section heading was `Diff:` rather than `File:`.
    pub is_diff: bool,
}
320
321pub fn extract_file_markers(raw: &str) -> Vec<FileMarker> {
337 use crate::path::normalize_artifact_path;
338
339 let mut markers = Vec::new();
340 let mut current_path: Option<String> = None;
341 let mut current_is_diff = false;
342 let mut current_content = String::new();
343 let mut in_fence = false;
344 let mut fence_content = String::new();
345 let mut had_heading = false;
346
347 for line in raw.lines() {
348 let trimmed = line.trim();
349
350 if let Some(heading) = parse_file_heading(trimmed) {
352 if had_heading || !current_content.trim().is_empty() {
354 flush_marker(
355 &mut markers,
356 ¤t_path,
357 ¤t_content,
358 current_is_diff,
359 );
360 }
361
362 let (path_raw, is_diff) = heading;
363 current_path = normalize_artifact_path(&path_raw).ok().or(Some(path_raw));
364 current_is_diff = is_diff;
365 current_content.clear();
366 had_heading = true;
367 continue;
368 }
369
370 if trimmed.starts_with("```") {
372 if in_fence {
373 in_fence = false;
375 if !fence_content.trim().is_empty() {
376 if !current_content.is_empty() {
377 current_content.push('\n');
378 }
379 current_content.push_str(fence_content.trim());
380 }
381 fence_content.clear();
382 } else {
383 in_fence = true;
385 fence_content.clear();
386 }
387 continue;
388 }
389
390 if in_fence {
391 if !fence_content.is_empty() {
392 fence_content.push('\n');
393 }
394 fence_content.push_str(line);
395 } else if had_heading {
396 if !current_content.is_empty() {
398 current_content.push('\n');
399 }
400 current_content.push_str(line);
401 }
402 }
403
404 if in_fence && !fence_content.trim().is_empty() {
406 if !current_content.is_empty() {
407 current_content.push('\n');
408 }
409 current_content.push_str(fence_content.trim());
410 }
411
412 if had_heading || !current_content.trim().is_empty() {
414 flush_marker(
415 &mut markers,
416 ¤t_path,
417 ¤t_content,
418 current_is_diff,
419 );
420 }
421
422 markers
423}
424
/// Parses a `File: <path>` or `Diff: <path>` heading line.
///
/// Leading `#` characters are removed first, then surrounding whitespace.
/// Backticks, double quotes and single quotes wrapping the path are then
/// stripped, one kind at a time in that order. Returns `(path, is_diff)`,
/// or `None` when the line is not a heading or the path is empty.
fn parse_file_heading(line: &str) -> Option<(String, bool)> {
    let text = line.trim_start_matches('#').trim();

    let (rest, is_diff) = match text.strip_prefix("File:") {
        Some(rest) => (rest, false),
        None => (text.strip_prefix("Diff:")?, true),
    };

    // Sequential stripping: the order matters when wrappers are mixed.
    let path = ['`', '"', '\'']
        .iter()
        .fold(rest.trim(), |acc, &quote| acc.trim_matches(quote));

    if path.is_empty() {
        None
    } else {
        Some((path.to_string(), is_diff))
    }
}
452
453fn flush_marker(
454 markers: &mut Vec<FileMarker>,
455 path: &Option<String>,
456 content: &str,
457 is_diff: bool,
458) {
459 let trimmed = content.trim();
460 if trimmed.is_empty() && path.is_none() {
461 return;
462 }
463 markers.push(FileMarker {
464 path: path.clone(),
465 content: trimmed.to_string(),
466 is_diff,
467 });
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls [`extract_json`], failing the test with the error text on `Err`.
    fn extract_ok(raw: &str) -> NormalizedOutput {
        extract_json(raw).unwrap_or_else(|err| panic!("extract_json failed: {}", err))
    }

    /// Collects each marker's `path` as a borrowed `Option<&str>`.
    fn paths(markers: &[FileMarker]) -> Vec<Option<&str>> {
        markers.iter().map(|marker| marker.path.as_deref()).collect()
    }

    // ---- extract_json: strategy selection -------------------------------

    #[test]
    fn test_direct_json_object() {
        let input = r#"{"tasks": [{"id": "1"}]}"#;
        let output = extract_ok(input);
        assert_eq!(output.method, ExtractionMethod::DirectJson);
        assert_eq!(output.json_body, input);
    }

    #[test]
    fn test_direct_json_array() {
        let output = extract_ok(r#"[{"id": 1}]"#);
        assert_eq!(output.method, ExtractionMethod::DirectJson);
    }

    #[test]
    fn test_fenced_json() {
        let output = extract_ok("Here is the plan:\n```json\n{\"tasks\": []}\n```\nDone.");
        assert_eq!(output.method, ExtractionMethod::FencedJson);
        assert_eq!(output.json_body, "{\"tasks\": []}");
    }

    #[test]
    fn test_generic_fence_with_json() {
        let output = extract_ok("Result:\n```\n{\"artifacts\": []}\n```");
        assert_eq!(output.method, ExtractionMethod::GenericFence);
        assert_eq!(output.json_body, "{\"artifacts\": []}");
    }

    #[test]
    fn test_generic_fence_with_language_hint() {
        // A fenced block of non-JSON code must never be reported as a
        // generic-fence hit; a later strategy may still match (or none may).
        if let Ok(output) = extract_json("```rust\nfn main() {}\n```") {
            assert_ne!(output.method, ExtractionMethod::GenericFence);
        }
    }

    #[test]
    fn test_embedded_json_with_wrapper_text() {
        let input = "Sure! Here is the bundle:\n{\"artifacts\": [{\"path\": \"main.rs\", \"operation\": \"write\", \"content\": \"fn main() {}\"}]}\nLet me know if you need changes.";
        let output = extract_ok(input);
        assert_eq!(output.method, ExtractionMethod::EmbeddedJson);
        let body = output.json_body.as_str();
        assert!(body.starts_with('{') && body.ends_with('}'), "unexpected body: {}", body);
    }

    #[test]
    fn test_embedded_json_with_nested_braces() {
        let output = extract_ok("Plan: {\"a\": {\"b\": {\"c\": 1}}} end");
        assert_eq!(output.method, ExtractionMethod::EmbeddedJson);
        assert_eq!(output.json_body, "{\"a\": {\"b\": {\"c\": 1}}}");
    }

    #[test]
    fn test_embedded_json_with_strings_containing_braces() {
        let output = extract_ok(r#"Output: {"msg": "hello { world }"} done"#);
        assert_eq!(output.method, ExtractionMethod::EmbeddedJson);
        assert_eq!(output.json_body, r#"{"msg": "hello { world }"}"#);
    }

    #[test]
    fn test_empty_input() {
        assert!(extract_json("").is_err());
    }

    #[test]
    fn test_no_json_at_all() {
        assert!(extract_json("This is just a plain text response with no JSON.").is_err());
    }

    #[test]
    fn test_fenced_json_takes_priority_over_embedded() {
        let output = extract_ok("Preamble {\"stray\": 1}\n```json\n{\"real\": 2}\n```");
        assert_eq!(output.method, ExtractionMethod::FencedJson);
        assert_eq!(output.json_body, "{\"real\": 2}");
    }

    // ---- extract_and_deserialize ----------------------------------------

    #[test]
    fn test_extract_and_deserialize_ok() {
        #[derive(serde::Deserialize)]
        struct Simple {
            value: i32,
        }

        let (parsed, method): (Simple, _) =
            extract_and_deserialize("```json\n{\"value\": 42}\n```")
                .expect("fenced JSON should deserialize");
        assert_eq!(parsed.value, 42);
        assert_eq!(method, ExtractionMethod::FencedJson);
    }

    #[test]
    fn test_extract_and_deserialize_bad_schema() {
        #[derive(Debug, serde::Deserialize)]
        struct Strict {
            // Never read: it exists only so `{"other": 1}` fails to deserialize.
            #[allow(dead_code)]
            required_field: String,
        }

        let err = extract_and_deserialize::<Strict>("{\"other\": 1}")
            .err()
            .expect("a missing required field must be an error");
        assert!(err.reason.contains("deserialization failed"));
    }

    // ---- ProviderFamily --------------------------------------------------

    #[test]
    fn test_provider_family_classification() {
        let cases = [
            ("gpt-4o", ProviderFamily::OpenAI),
            ("claude-opus-4-20250514", ProviderFamily::Anthropic),
            ("gemini-2.5-pro", ProviderFamily::Gemini),
            ("deepseek-r1", ProviderFamily::DeepSeek),
            ("my-custom-model", ProviderFamily::Unknown),
        ];
        for &(model, expected) in &cases {
            assert_eq!(ProviderFamily::from_model_name(model), expected, "model: {}", model);
        }
    }

    // ---- realistic responses ---------------------------------------------

    #[test]
    fn test_extract_json_with_nested_code_fence() {
        let response = r#"
Here is the plan I've created for you:

```json
{
  "steps": [
    {"id": "s1", "action": "create_file", "path": "src/lib.rs"},
    {"id": "s2", "action": "run_tests", "path": "."}
  ],
  "description": "Create and verify a new library"
}
```

Let me know if you'd like any changes.
"#;
        let output = extract_ok(response);
        assert_eq!(output.method, ExtractionMethod::FencedJson);
        for needle in &["create_file", "run_tests"] {
            assert!(output.json_body.contains(needle), "missing {}", needle);
        }
    }

    #[test]
    fn test_extract_and_deserialize_realistic_plan() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Step {
            id: String,
            action: String,
        }
        #[derive(Debug, serde::Deserialize)]
        struct Plan {
            steps: Vec<Step>,
        }

        let response = r#"Sure! ```json
{"steps": [{"id": "1", "action": "lint"}, {"id": "2", "action": "test"}]}
```"#;

        let (plan, method): (Plan, _) =
            extract_and_deserialize(response).expect("plan should deserialize");
        assert_eq!(method, ExtractionMethod::FencedJson);
        let actions: Vec<&str> = plan.steps.iter().map(|step| step.action.as_str()).collect();
        assert_eq!(actions, ["lint", "test"]);
    }

    // ---- extract_file_markers --------------------------------------------

    #[test]
    fn test_file_markers_basic() {
        let raw = "\
### File: src/main.rs
```rust
fn main() {}
```
### File: src/lib.rs
```rust
pub fn hello() {}
```
";
        let markers = extract_file_markers(raw);
        assert_eq!(paths(&markers), [Some("src/main.rs"), Some("src/lib.rs")]);
        assert_eq!(markers[0].content, "fn main() {}");
        assert!(!markers[0].is_diff);
        assert_eq!(markers[1].content, "pub fn hello() {}");
    }

    #[test]
    fn test_file_markers_with_diff() {
        let raw = "\
### Diff: src/lib.rs
```diff
- old line
+ new line
```
";
        let markers = extract_file_markers(raw);
        assert_eq!(paths(&markers), [Some("src/lib.rs")]);
        assert!(markers[0].is_diff);
        assert!(markers[0].content.contains("- old line"));
    }

    #[test]
    fn test_file_markers_no_heading_prefix() {
        let raw = "\
File: src/main.rs
```
fn main() {}
```
";
        let markers = extract_file_markers(raw);
        assert_eq!(paths(&markers), [Some("src/main.rs")]);
    }

    #[test]
    fn test_file_markers_backtick_wrapped_path() {
        let raw = "\
### File: `src/main.rs`
```rust
fn main() {}
```
";
        let markers = extract_file_markers(raw);
        assert_eq!(paths(&markers), [Some("src/main.rs")]);
    }

    #[test]
    fn test_file_markers_empty_input() {
        assert!(extract_file_markers("").is_empty());
    }

    #[test]
    fn test_file_markers_no_headings_returns_empty() {
        assert!(extract_file_markers("Just some text with no file markers.").is_empty());
    }

    #[test]
    fn test_file_markers_multiple_heading_levels() {
        let raw = "\
## File: src/a.rs
content a
### File: src/b.rs
content b
";
        let markers = extract_file_markers(raw);
        assert_eq!(paths(&markers), [Some("src/a.rs"), Some("src/b.rs")]);
    }

    #[test]
    fn test_file_markers_path_normalization() {
        let raw = "\
### File: ./src/../src/main.rs
```
fn main() {}
```
";
        let markers = extract_file_markers(raw);
        assert_eq!(paths(&markers), [Some("src/main.rs")]);
    }
}