Skip to main content

perspt_core/
llm_provider.rs

//! # LLM Provider Module
//!
//! Thread-safe LLM provider abstraction for multi-agent use.
//! Wraps `genai::Client` in an `Arc` and keeps shared metrics in an `Arc<RwLock<_>>`.
5
6use anyhow::{Context, Result};
7use futures::StreamExt;
8use genai::adapter::AdapterKind;
9use genai::chat::{ChatMessage, ChatRequest, ChatStreamEvent};
10use genai::Client;
11use std::sync::Arc;
12use std::time::Instant;
13use tokio::sync::{mpsc, RwLock};
14
/// End-of-transmission sentinel for streamed responses.
///
/// `GenAIProvider::generate_response_stream_to_channel` sends this string as the last
/// message on its channel once a stream has been fully consumed. It is not sent when the
/// stream itself fails.
pub const EOT_SIGNAL: &str = "<|EOT|>";
17
/// Response from a non-streaming LLM call, carrying text and token usage.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// The first text segment of the model's reply.
    pub text: String,
    /// Prompt (input) token count from the provider's usage report; `None` if not reported.
    pub tokens_in: Option<i32>,
    /// Completion (output) token count from the provider's usage report; `None` if not
    /// reported.
    pub tokens_out: Option<i32>,
}
25
/// Mutable metrics shared by every clone of a `GenAIProvider`.
///
/// These are plain counters. Nothing in this module reads them to throttle requests.
#[derive(Default)]
struct SharedState {
    /// Running token total.
    ///
    /// Non-streaming calls add the provider-reported prompt + completion tokens.
    /// Streaming calls add a rough estimate (content characters / 4).
    total_tokens_used: usize,
    /// Number of generate calls issued: one per call, not one per retry attempt.
    request_count: usize,
}
32
/// Cloneable, thread-safe LLM provider built on a `genai::Client`.
///
/// Cloning is cheap because both fields are `Arc`s. Every clone shares the same
/// underlying client and the same request/token counters, so it can be handed to
/// multiple agents. This type does not enforce rate limits itself; the shared state
/// only records metrics.
#[derive(Clone)]
pub struct GenAIProvider {
    /// The underlying genai client, shared by all clones.
    client: Arc<Client>,
    /// Request and token counters shared by all clones, behind a tokio `RwLock`.
    shared: Arc<RwLock<SharedState>>,
}
44
45impl GenAIProvider {
46    /// Creates a new GenAI provider with automatic configuration.
47    pub fn new() -> Result<Self> {
48        let client = Client::default();
49        Ok(Self {
50            client: Arc::new(client),
51            shared: Arc::new(RwLock::new(SharedState::default())),
52        })
53    }
54
55    /// Creates a new GenAI provider with explicit configuration.
56    pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
57        // Set environment variable if API key is provided
58        if let (Some(provider), Some(key)) = (provider_type, api_key) {
59            let env_var = match provider {
60                "openai" => "OPENAI_API_KEY",
61                "anthropic" => "ANTHROPIC_API_KEY",
62                "gemini" => "GEMINI_API_KEY",
63                "groq" => "GROQ_API_KEY",
64                "cohere" => "COHERE_API_KEY",
65                "xai" => "XAI_API_KEY",
66                "deepseek" => "DEEPSEEK_API_KEY",
67                "ollama" => {
68                    log::info!("Ollama provider detected - no API key required for local setup");
69                    return Self::new();
70                }
71                _ => {
72                    log::warn!("Unknown provider type for API key: {provider}");
73                    return Self::new();
74                }
75            };
76
77            log::info!("Setting {env_var} environment variable for genai client");
78            std::env::set_var(env_var, key);
79        }
80
81        Self::new()
82    }
83
84    /// Get total tokens used across all requests
85    pub async fn get_total_tokens_used(&self) -> usize {
86        self.shared.read().await.total_tokens_used
87    }
88
89    /// Get total request count
90    pub async fn get_request_count(&self) -> usize {
91        self.shared.read().await.request_count
92    }
93
94    /// Increment request counter (for metrics)
95    async fn increment_request(&self) {
96        let mut state = self.shared.write().await;
97        state.request_count += 1;
98    }
99
100    /// Add tokens to the total count
101    pub async fn add_tokens(&self, count: usize) {
102        let mut state = self.shared.write().await;
103        state.total_tokens_used += count;
104    }
105
106    /// Retrieves all available models for a specific provider.
107    pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
108        let adapter_kind = str_to_adapter_kind(provider)?;
109
110        let models = self
111            .client
112            .all_model_names(adapter_kind)
113            .await
114            .context(format!("Failed to get models for provider: {provider}"))?;
115
116        Ok(models)
117    }
118
119    /// Generates a simple text response without streaming.
120    /// Includes exponential backoff retry for rate limits and transient errors.
121    pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<LlmResponse> {
122        self.generate_response_with_retry(model, prompt, 3).await
123    }
124
125    /// Generates a response with configurable retry count and exponential backoff.
126    pub async fn generate_response_with_retry(
127        &self,
128        model: &str,
129        prompt: &str,
130        max_retries: usize,
131    ) -> Result<LlmResponse> {
132        self.increment_request().await;
133
134        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
135
136        log::debug!(
137            "Sending chat request to model: {model} with prompt length: {} chars",
138            prompt.len()
139        );
140
141        let start_time = Instant::now();
142        let mut last_error: Option<anyhow::Error> = None;
143        let mut retry_count = 0;
144
145        while retry_count <= max_retries {
146            if retry_count > 0 {
147                // Exponential backoff: 1s, 2s, 4s, 8s, ... (capped at 16s)
148                let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
149                log::warn!(
150                    "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
151                    retry_count,
152                    max_retries,
153                    model,
154                    delay_secs,
155                    last_error.as_ref().map(|e| e.to_string())
156                );
157                println!(
158                    "   ⏳ Rate limited, retrying in {}s (attempt {}/{})",
159                    delay_secs, retry_count, max_retries
160                );
161                tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
162            }
163
164            match self.client.exec_chat(model, chat_req.clone(), None).await {
165                Ok(chat_res) => {
166                    let tokens_in = chat_res.usage.prompt_tokens;
167                    let tokens_out = chat_res.usage.completion_tokens;
168                    let content = chat_res
169                        .first_text()
170                        .context("No text content in response")?;
171                    log::debug!(
172                        "Received response with {} characters in {}ms (tokens: in={:?}, out={:?})",
173                        content.len(),
174                        start_time.elapsed().as_millis(),
175                        tokens_in,
176                        tokens_out,
177                    );
178
179                    // Update shared token counter with real values when available
180                    let total = tokens_in.unwrap_or(0) + tokens_out.unwrap_or(0);
181                    if total > 0 {
182                        self.add_tokens(total as usize).await;
183                    }
184
185                    return Ok(LlmResponse {
186                        text: content.to_string(),
187                        tokens_in,
188                        tokens_out,
189                    });
190                }
191                Err(e) => {
192                    let err_str = e.to_string();
193
194                    // Check if it's a retryable error (rate limit, server error, network)
195                    let is_retryable = err_str.contains("429")
196                        || err_str.contains("rate limit")
197                        || err_str.contains("Rate limit")
198                        || err_str.contains("RESOURCE_EXHAUSTED")
199                        || err_str.contains("500")
200                        || err_str.contains("502")
201                        || err_str.contains("503")
202                        || err_str.contains("504")
203                        || err_str.contains("timeout")
204                        || err_str.contains("connection");
205
206                    if is_retryable && retry_count < max_retries {
207                        log::warn!("Retryable error for model {}: {}", model, err_str);
208                        last_error = Some(anyhow::anyhow!("{}", err_str));
209                        retry_count += 1;
210                        continue;
211                    } else {
212                        return Err(anyhow::anyhow!(
213                            "Failed to execute chat request for model {}: {}",
214                            model,
215                            err_str
216                        ));
217                    }
218                }
219            }
220        }
221
222        // Should not reach here, but handle gracefully
223        Err(last_error
224            .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
225    }
226
227    /// Generates a streaming response and sends chunks via mpsc channel.
228    pub async fn generate_response_stream_to_channel(
229        &self,
230        model: &str,
231        prompt: &str,
232        tx: mpsc::UnboundedSender<String>,
233    ) -> Result<()> {
234        self.increment_request().await;
235
236        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
237
238        log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");
239
240        let chat_res_stream = self
241            .client
242            .exec_chat_stream(model, chat_req, None)
243            .await
244            .context(format!(
245                "Failed to execute streaming chat request for model: {model}"
246            ))?;
247
248        let mut stream = chat_res_stream.stream;
249        let mut chunk_count = 0;
250        let mut total_content_length = 0;
251        let mut stream_ended_explicitly = false;
252        let start_time = Instant::now();
253
254        log::info!(
255            "=== STREAM START === Model: {}, Prompt length: {} chars",
256            model,
257            prompt.len()
258        );
259
260        while let Some(chunk_result) = stream.next().await {
261            let elapsed = start_time.elapsed();
262
263            match chunk_result {
264                Ok(ChatStreamEvent::Start) => {
265                    log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
266                }
267                Ok(ChatStreamEvent::Chunk(chunk)) => {
268                    chunk_count += 1;
269                    total_content_length += chunk.content.len();
270
271                    if chunk_count % 10 == 0 || chunk.content.len() > 100 {
272                        log::info!(
273                            "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
274                            chunk_count,
275                            chunk.content.len(),
276                            total_content_length,
277                            elapsed
278                        );
279                    }
280
281                    if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
282                        log::error!(
283                            "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
284                        );
285                        break;
286                    }
287                }
288                Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
289                    log::info!(
290                        "REASONING CHUNK: {} chars at {:?}",
291                        chunk.content.len(),
292                        elapsed
293                    );
294                }
295                Ok(ChatStreamEvent::End(_)) => {
296                    log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
297                    stream_ended_explicitly = true;
298                    break;
299                }
300                Ok(ChatStreamEvent::ToolCallChunk(_)) => {
301                    log::debug!("Tool call chunk received (ignored)");
302                }
303                Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
304                    log::debug!("Thought signature chunk received (ignored)");
305                }
306                Err(e) => {
307                    log::error!(
308                        "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
309                    );
310                    let error_msg = format!("Stream error: {e}");
311                    let _ = tx.send(error_msg);
312                    return Err(e.into());
313                }
314            }
315        }
316
317        let final_elapsed = start_time.elapsed();
318        if !stream_ended_explicitly {
319            log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
320        }
321
322        log::info!(
323            "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
324        );
325
326        // Add approximate token count
327        self.add_tokens(total_content_length / 4).await; // Rough estimate
328
329        if tx.send(EOT_SIGNAL.to_string()).is_err() {
330            log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
331            return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
332        }
333
334        log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
335        Ok(())
336    }
337
338    /// Get a list of supported providers
339    pub fn get_supported_providers() -> Vec<&'static str> {
340        vec![
341            "openai",
342            "anthropic",
343            "gemini",
344            "groq",
345            "cohere",
346            "ollama",
347            "xai",
348            "deepseek",
349        ]
350    }
351
352    /// Get all available providers
353    pub async fn get_available_providers(&self) -> Result<Vec<String>> {
354        Ok(Self::get_supported_providers()
355            .iter()
356            .map(|s| s.to_string())
357            .collect())
358    }
359
360    /// Test if a model is available and working
361    pub async fn test_model(&self, model: &str) -> Result<bool> {
362        match self.generate_response_simple(model, "Hello").await {
363            Ok(_) => {
364                log::info!("Model {model} is available and working");
365                Ok(true)
366            }
367            Err(e) => {
368                log::warn!("Model {model} test failed: {e}");
369                Ok(false)
370            }
371        }
372    }
373
374    /// Validate and get the best available model for a provider
375    pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
376        if self.test_model(model).await? {
377            return Ok(model.to_string());
378        }
379
380        if let Some(provider) = provider_type {
381            if let Ok(models) = self.get_available_models(provider).await {
382                if !models.is_empty() {
383                    log::info!("Model {} not available, using {} instead", model, models[0]);
384                    return Ok(models[0].clone());
385                }
386            }
387        }
388
389        log::warn!("Could not validate model {model}, proceeding anyway");
390        Ok(model.to_string())
391    }
392}
393
394/// Convert a provider string to genai AdapterKind
395fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
396    match provider.to_lowercase().as_str() {
397        "openai" => Ok(AdapterKind::OpenAI),
398        "anthropic" => Ok(AdapterKind::Anthropic),
399        "gemini" | "google" => Ok(AdapterKind::Gemini),
400        "groq" => Ok(AdapterKind::Groq),
401        "cohere" => Ok(AdapterKind::Cohere),
402        "ollama" => Ok(AdapterKind::Ollama),
403        "xai" => Ok(AdapterKind::Xai),
404        "deepseek" => Ok(AdapterKind::DeepSeek),
405        _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_str_to_adapter_kind() {
        for name in [
            "openai", "anthropic", "gemini", "google", "groq", "cohere", "ollama", "xai",
            "deepseek",
        ] {
            assert!(str_to_adapter_kind(name).is_ok(), "expected `{name}` to be supported");
        }
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[test]
    fn test_str_to_adapter_kind_is_case_insensitive() {
        assert!(matches!(str_to_adapter_kind("OpenAI"), Ok(AdapterKind::OpenAI)));
        assert!(matches!(str_to_adapter_kind("GOOGLE"), Ok(AdapterKind::Gemini)));
    }

    /// Every advertised provider must be usable with `get_available_models`.
    #[test]
    fn test_supported_providers_map_to_adapters() {
        for provider in GenAIProvider::get_supported_providers() {
            assert!(
                str_to_adapter_kind(provider).is_ok(),
                "supported provider `{provider}` has no adapter mapping"
            );
        }
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    /// Clones must share the same counters, not copies of them.
    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let clone = provider.clone();

        clone.add_tokens(42).await;
        provider.increment_request().await;

        assert_eq!(provider.get_total_tokens_used().await, 42);
        assert_eq!(clone.get_request_count().await, 1);
    }
}