diff --git a/Cargo.lock b/Cargo.lock index 92eeb7d..426cff6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -590,13 +590,13 @@ dependencies = [ "futures", "indicatif", "mockall", + "once_cell", "regex", "reqwest", "rustyline", "serde", "serde_json", "serial_test", - "signal-hook", "syntect", "tempfile", "tokio", @@ -1585,16 +1585,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" -dependencies = [ - "libc", - "signal-hook-registry", -] - [[package]] name = "signal-hook-registry" version = "1.4.6" diff --git a/Cargo.toml b/Cargo.toml index 6285965..014d305 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,8 @@ syntect = "5.1" regex = "1.0" futures = "0.3" tokio-stream = "0.1" -signal-hook = "0.3" +once_cell = "1.21" + [dev-dependencies] tempfile = "3.0" diff --git a/src/cli.rs b/src/cli.rs index 916c0c2..1f6a559 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -2,7 +2,8 @@ use anyhow::Result; use crate::config::Config; use crate::core::{ - create_client, get_provider_for_model, provider::{get_model_info_list, get_display_name_for_model, get_model_id_from_display_name}, + create_client, get_provider_for_model, + provider::{get_display_name_for_model, get_model_id_from_display_name, get_model_info_list}, ChatClient, Session, }; use crate::utils::{Display, InputHandler, SessionAction}; @@ -41,14 +42,16 @@ impl ChatCLI { pub async fn run(&mut self) -> Result<()> { self.display.print_header(); - self.display.print_info("Type your message and press Enter. Commands start with '/'."); + self.display + .print_info("Type your message and press Enter. Commands start with '/'."); self.display.print_info("Type /help for help."); - + let provider = get_provider_for_model(&self.session.model); let display_name = get_display_name_for_model(&self.session.model); - self.display.print_model_info(&display_name, provider.as_str()); + self.display + .print_model_info(&display_name, provider.as_str()); self.display.print_session_info(&self.session.name); - + println!(); loop { @@ -77,6 +80,10 @@ impl ChatCLI { } } + self.save_and_cleanup() + } + + pub fn save_and_cleanup(&mut self) -> Result<()> { self.session.save()?; self.input.cleanup()?; // Use cleanup instead of just save_history Ok(()) @@ -86,6 +93,16 @@ impl ChatCLI { self.session.add_user_message(message.to_string()); self.session.save()?; + + // Clone data needed for the API call before getting mutable client reference + let model = self.session.model.clone(); + let messages = self.session.messages.clone(); + let enable_web_search = self.session.enable_web_search; + let enable_reasoning_summary = self.session.enable_reasoning_summary; + let reasoning_effort = self.session.reasoning_effort.clone(); + let enable_extended_thinking = self.session.enable_extended_thinking; + let thinking_budget_tokens = self.session.thinking_budget_tokens; + // Check if we should use streaming before getting client let should_use_streaming = { let client = self.get_client()?; @@ -97,7 +114,7 @@ impl ChatCLI { print!("{}> ", console::style("🤖").magenta()); use std::io::{self, Write}; io::stdout().flush().ok(); - + let stream_callback = { use crate::core::StreamCallback; Box::new(move |chunk: &str| { @@ -107,9 +124,11 @@ impl ChatCLI { Box::pin(async move {}) as Pin + Send>> }) as StreamCallback }; + let client = self.get_client()?.clone(); let response = client + .chat_completion_stream( &self.session.model, &self.session.messages, @@ -129,15 +148,18 @@ impl ChatCLI { } Err(e) => { println!(); // Add newline after failed streaming - self.display.print_error(&format!("Streaming failed: {}", e)); + self.display + .print_error(&format!("Streaming failed: {}", e)); return Err(e); } } } else { // Fallback to non-streaming let spinner = self.display.show_spinner("Thinking"); + let client = self.get_client()?.clone(); let response = client + .chat_completion( &self.session.model, &self.session.messages, @@ -208,7 +230,8 @@ impl ChatCLI { self.handle_save_command(&parts)?; } _ => { - self.display.print_error(&format!("Unknown command: {} (see /help)", parts[0])); + self.display + .print_error(&format!("Unknown command: {} (see /help)", parts[0])); } } @@ -217,15 +240,18 @@ impl ChatCLI { async fn model_switcher(&mut self) -> Result<()> { let model_info_list = get_model_info_list(); - let display_names: Vec = model_info_list.iter().map(|info| info.display_name.to_string()).collect(); + let display_names: Vec = model_info_list + .iter() + .map(|info| info.display_name.to_string()) + .collect(); let current_display_name = get_display_name_for_model(&self.session.model); - + let selection = self.input.select_from_list( "Select a model:", &display_names, Some(¤t_display_name), )?; - + match selection { Some(display_name) => { if let Some(model_id) = get_model_id_from_display_name(&display_name) { @@ -250,11 +276,10 @@ impl ChatCLI { self.display.print_info("Model selection cancelled"); } } - + Ok(()) } - fn handle_new_session(&mut self, parts: &[&str]) -> Result<()> { if parts.len() != 2 { self.display.print_error("Usage: /new "); @@ -264,18 +289,16 @@ impl ChatCLI { self.session.save()?; let new_session = Session::new(parts[1].to_string(), self.session.model.clone()); self.session = new_session; - self.display.print_command_result(&format!("New session '{}' started", self.session.name)); - + self.display + .print_command_result(&format!("New session '{}' started", self.session.name)); + Ok(()) } async fn session_manager(&mut self) -> Result<()> { loop { let sessions = Session::list_sessions()?; - let session_names: Vec = sessions - .into_iter() - .map(|(name, _)| name) - .collect(); + let session_names: Vec = sessions.into_iter().map(|(name, _)| name).collect(); if session_names.is_empty() { self.display.print_info("No sessions available"); @@ -294,7 +317,7 @@ impl ChatCLI { self.display.print_info("Already in that session"); return Ok(()); } - + self.session.save()?; match Session::load(&session_name) { Ok(session) => { @@ -308,7 +331,8 @@ impl ChatCLI { return Ok(()); } Err(e) => { - self.display.print_error(&format!("Failed to load session: {}", e)); + self.display + .print_error(&format!("Failed to load session: {}", e)); // Don't return, allow user to try again or cancel } } @@ -316,8 +340,11 @@ impl ChatCLI { SessionAction::Delete(session_name) => { match Session::delete_session(&session_name) { Ok(()) => { - self.display.print_command_result(&format!("Session '{}' deleted", session_name)); - + self.display.print_command_result(&format!( + "Session '{}' deleted", + session_name + )); + // If we deleted the current session, we need to handle this specially if session_name == self.session.name { // Try to switch to another session or create a default one @@ -326,18 +353,23 @@ impl ChatCLI { .into_iter() .map(|(name, _)| name) .collect(); - + if remaining_names.is_empty() { // No sessions left, create a default one - self.session = Session::new("default".to_string(), self.session.model.clone()); - self.display.print_command_result("Created new default session"); + self.session = Session::new( + "default".to_string(), + self.session.model.clone(), + ); + self.display + .print_command_result("Created new default session"); return Ok(()); } else { // Switch to the first available session match Session::load(&remaining_names[0]) { Ok(session) => { self.session = session; - let display_name = get_display_name_for_model(&self.session.model); + let display_name = + get_display_name_for_model(&self.session.model); self.display.print_command_result(&format!( "Switched to session '{}' (model={})", self.session.name, display_name @@ -346,10 +378,18 @@ impl ChatCLI { return Ok(()); } Err(e) => { - self.display.print_error(&format!("Failed to load fallback session: {}", e)); + self.display.print_error(&format!( + "Failed to load fallback session: {}", + e + )); // Create a new default session as fallback - self.session = Session::new("default".to_string(), self.session.model.clone()); - self.display.print_command_result("Created new default session"); + self.session = Session::new( + "default".to_string(), + self.session.model.clone(), + ); + self.display.print_command_result( + "Created new default session", + ); return Ok(()); } } @@ -358,7 +398,8 @@ impl ChatCLI { // Continue to show updated session list if we didn't delete current session } Err(e) => { - self.display.print_error(&format!("Failed to delete session: {}", e)); + self.display + .print_error(&format!("Failed to delete session: {}", e)); // Continue to allow retry } } @@ -373,7 +414,8 @@ impl ChatCLI { )); } Err(e) => { - self.display.print_error(&format!("Failed to set default session: {}", e)); + self.display + .print_error(&format!("Failed to set default session: {}", e)); } } // Continue to show session list @@ -385,32 +427,48 @@ impl ChatCLI { } } - async fn tools_manager(&mut self) -> Result<()> { loop { // Show current tool status self.display.print_info("Tool Management:"); - - let web_status = if self.session.enable_web_search { "✓ enabled" } else { "✗ disabled" }; - let reasoning_status = if self.session.enable_reasoning_summary { "✓ enabled" } else { "✗ disabled" }; - let extended_thinking_status = if self.session.enable_extended_thinking { "✓ enabled" } else { "✗ disabled" }; - + + let web_status = if self.session.enable_web_search { + "✓ enabled" + } else { + "✗ disabled" + }; + let reasoning_status = if self.session.enable_reasoning_summary { + "✓ enabled" + } else { + "✗ disabled" + }; + let extended_thinking_status = if self.session.enable_extended_thinking { + "✓ enabled" + } else { + "✗ disabled" + }; + println!(" Web Search: {}", web_status); println!(" Reasoning Summaries: {}", reasoning_status); println!(" Reasoning Effort: {}", self.session.reasoning_effort); println!(" Extended Thinking: {}", extended_thinking_status); - println!(" Thinking Budget: {} tokens", self.session.thinking_budget_tokens); - + println!( + " Thinking Budget: {} tokens", + self.session.thinking_budget_tokens + ); + // Check model compatibility let model = self.session.model.clone(); let provider = get_provider_for_model(&model); let reasoning_enabled = self.session.enable_reasoning_summary; - + // Show compatibility warnings based on provider match provider { crate::core::provider::Provider::Anthropic => { if reasoning_enabled { - self.display.print_warning("Reasoning summaries are not supported by Anthropic models"); + self.display.print_warning( + "Reasoning summaries are not supported by Anthropic models", + ); } if self.session.enable_extended_thinking { // Extended thinking is supported by Anthropic models @@ -419,38 +477,47 @@ impl ChatCLI { } crate::core::provider::Provider::OpenAI => { if self.session.enable_extended_thinking { - self.display.print_warning("Extended thinking is not supported by OpenAI models"); + self.display + .print_warning("Extended thinking is not supported by OpenAI models"); } // OpenAI models generally support other features } } - + // Tool management options let options = vec![ "Toggle Web Search", - "Toggle Reasoning Summaries", + "Toggle Reasoning Summaries", "Set Reasoning Effort", "Toggle Extended Thinking", "Set Thinking Budget", - "Done" + "Done", ]; - - let selection = self.input.select_from_list( - "Select an option:", - &options, - None, - )?; - + + let selection = self + .input + .select_from_list("Select an option:", &options, None)?; + match selection.as_deref() { Some("Toggle Web Search") => { self.session.enable_web_search = !self.session.enable_web_search; - let state = if self.session.enable_web_search { "enabled" } else { "disabled" }; - self.display.print_command_result(&format!("Web search {}", state)); + let state = if self.session.enable_web_search { + "enabled" + } else { + "disabled" + }; + self.display + .print_command_result(&format!("Web search {}", state)); } Some("Toggle Reasoning Summaries") => { self.session.enable_reasoning_summary = !self.session.enable_reasoning_summary; - let state = if self.session.enable_reasoning_summary { "enabled" } else { "disabled" }; - self.display.print_command_result(&format!("Reasoning summaries {}", state)); + let state = if self.session.enable_reasoning_summary { + "enabled" + } else { + "disabled" + }; + self.display + .print_command_result(&format!("Reasoning summaries {}", state)); } Some("Set Reasoning Effort") => { let effort_options = vec!["low", "medium", "high"]; @@ -460,22 +527,32 @@ impl ChatCLI { Some(&self.session.reasoning_effort), )? { self.session.reasoning_effort = effort.to_string(); - self.display.print_command_result(&format!("Reasoning effort set to {}", effort)); - + self.display + .print_command_result(&format!("Reasoning effort set to {}", effort)); + if !self.session.model.starts_with("gpt-5") { - self.display.print_warning("Reasoning effort is only supported by GPT-5 models"); + self.display.print_warning( + "Reasoning effort is only supported by GPT-5 models", + ); } } } Some("Toggle Extended Thinking") => { self.session.enable_extended_thinking = !self.session.enable_extended_thinking; - let state = if self.session.enable_extended_thinking { "enabled" } else { "disabled" }; - self.display.print_command_result(&format!("Extended thinking {}", state)); - + let state = if self.session.enable_extended_thinking { + "enabled" + } else { + "disabled" + }; + self.display + .print_command_result(&format!("Extended thinking {}", state)); + let provider = get_provider_for_model(&self.session.model); match provider { crate::core::provider::Provider::OpenAI => { - self.display.print_warning("Extended thinking is not supported by OpenAI models"); + self.display.print_warning( + "Extended thinking is not supported by OpenAI models", + ); } crate::core::provider::Provider::Anthropic => { // Supported @@ -492,12 +569,17 @@ impl ChatCLI { )? { if let Ok(budget) = budget_str.parse::() { self.session.thinking_budget_tokens = budget; - self.display.print_command_result(&format!("Thinking budget set to {} tokens", budget)); - + self.display.print_command_result(&format!( + "Thinking budget set to {} tokens", + budget + )); + let provider = get_provider_for_model(&self.session.model); match provider { crate::core::provider::Provider::OpenAI => { - self.display.print_warning("Extended thinking is not supported by OpenAI models"); + self.display.print_warning( + "Extended thinking is not supported by OpenAI models", + ); } crate::core::provider::Provider::Anthropic => { if budget < 1024 { @@ -513,18 +595,18 @@ impl ChatCLI { } _ => {} } - + self.session.save()?; // Save changes after each modification println!(); // Add spacing } - + Ok(()) } fn handle_history_command(&mut self, parts: &[&str]) -> Result<()> { let mut filter_role: Option<&str> = None; let mut limit: Option = None; - + // Parse parameters for &part in parts.iter().skip(1) { match part { @@ -533,37 +615,41 @@ impl ChatCLI { if let Ok(num) = part.parse::() { limit = Some(num); } else { - self.display.print_error(&format!("Invalid parameter: {}", part)); - self.display.print_info("Usage: /history [user|assistant] [number]"); + self.display + .print_error(&format!("Invalid parameter: {}", part)); + self.display + .print_info("Usage: /history [user|assistant] [number]"); return Ok(()); } } } } - + // Filter messages (skip system prompt at index 0) - let mut messages: Vec<(usize, &crate::core::Message)> = self.session.messages + let mut messages: Vec<(usize, &crate::core::Message)> = self + .session + .messages .iter() .enumerate() .skip(1) // Skip system prompt .collect(); - + // Apply role filter if let Some(role) = filter_role { messages.retain(|(_, msg)| msg.role == role); } - + // Apply limit if let Some(limit_count) = limit { let start_index = messages.len().saturating_sub(limit_count); messages = messages[start_index..].to_vec(); } - + if messages.is_empty() { self.display.print_info("No messages to display"); return Ok(()); } - + // Format and display self.display.print_conversation_history(&messages); Ok(()) @@ -580,7 +666,7 @@ impl ChatCLI { let valid_formats = ["markdown", "md", "json", "txt"]; if !valid_formats.contains(&format.as_str()) { self.display.print_error(&format!( - "Invalid format '{}'. Supported formats: markdown, json, txt", + "Invalid format '{}'. Supported formats: markdown, json, txt", format )); self.display.print_info("Usage: /export [format]"); @@ -592,19 +678,20 @@ impl ChatCLI { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(); - + let extension = match format.as_str() { "markdown" | "md" => "md", "json" => "json", "txt" => "txt", _ => "md", }; - + // Create exports directory if it doesn't exist let exports_dir = std::path::Path::new("exports"); if !exports_dir.exists() { if let Err(e) = std::fs::create_dir(exports_dir) { - self.display.print_error(&format!("Failed to create exports directory: {}", e)); + self.display + .print_error(&format!("Failed to create exports directory: {}", e)); return Ok(()); } } @@ -612,10 +699,13 @@ impl ChatCLI { let filename = format!("{}_{}.{}", self.session.name, now, extension); let file_path = exports_dir.join(&filename); - match self.session.export(&format, file_path.to_str().unwrap_or(&filename)) { + match self + .session + .export(&format, file_path.to_str().unwrap_or(&filename)) + { Ok(()) => { self.display.print_command_result(&format!( - "Conversation exported to '{}'", + "Conversation exported to '{}'", file_path.display() )); } @@ -630,7 +720,8 @@ impl ChatCLI { fn handle_save_command(&mut self, parts: &[&str]) -> Result<()> { if parts.len() != 2 { self.display.print_error("Usage: /save "); - self.display.print_info("Example: /save my_important_conversation"); + self.display + .print_info("Example: /save my_important_conversation"); return Ok(()); } @@ -646,7 +737,7 @@ impl ChatCLI { if let Ok(sessions) = Session::list_sessions() { if sessions.iter().any(|(name, _)| name == new_session_name) { self.display.print_error(&format!( - "Session '{}' already exists. Choose a different name.", + "Session '{}' already exists. Choose a different name.", new_session_name )); return Ok(()); @@ -660,7 +751,7 @@ impl ChatCLI { match self.session.save_as(new_session_name) { Ok(()) => { self.display.print_command_result(&format!( - "Current session saved as '{}' ({} messages copied)", + "Current session saved as '{}' ({} messages copied)", new_session_name, self.session.messages.len().saturating_sub(1) // Exclude system prompt )); @@ -670,11 +761,11 @@ impl ChatCLI { )); } Err(e) => { - self.display.print_error(&format!("Failed to save session: {}", e)); + self.display + .print_error(&format!("Failed to save session: {}", e)); } } Ok(()) } - -} \ No newline at end of file +} diff --git a/src/core/session.rs b/src/core/session.rs index 08e14ca..2887232 100644 --- a/src/core/session.rs +++ b/src/core/session.rs @@ -1,5 +1,6 @@ use anyhow::{Context, Result}; use chrono::{DateTime, Utc}; +use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; @@ -12,6 +13,12 @@ ASCII art or concise bullet lists over heavy markup, and wrap code \ snippets in fenced blocks when helpful. Do not emit trailing spaces or \ control characters."; +static CONFIG: OnceCell = OnceCell::new(); + +fn get_config() -> &'static Config { + CONFIG.get_or_init(|| Config::load().unwrap_or_default()) +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: String, @@ -98,7 +105,7 @@ impl Session { } pub fn sessions_dir() -> Result { - let config = Config::load().unwrap_or_default(); + let config = get_config(); let home = dirs::home_dir().context("Could not find home directory")?; let sessions_dir = home.join(&config.session.sessions_dir_name); @@ -111,11 +118,12 @@ impl Session { } pub fn session_path(name: &str) -> Result { - let config = Config::load().unwrap_or_default(); + let config = get_config(); Ok(Self::sessions_dir()?.join(format!("{}.{}", name, config.session.file_extension))) } pub fn save(&self) -> Result<()> { + let _config = get_config(); let data = SessionData { model: self.model.clone(), messages: self.messages.clone(), @@ -143,6 +151,7 @@ impl Session { } pub fn load(name: &str) -> Result { + let _config = get_config(); let path = Self::session_path(name)?; if !path.exists() { @@ -195,7 +204,7 @@ impl Session { /// Truncates conversation history to stay within configured limits fn truncate_history_if_needed(&mut self) { - let config = Config::load().unwrap_or_default(); + let config = get_config(); let max_history = config.limits.max_conversation_history; // Always preserve the system prompt (first message) @@ -252,13 +261,14 @@ impl Session { } pub fn list_sessions() -> Result)>> { + let _config = get_config(); let sessions = Self::list_sessions_lazy(false)?; Ok(sessions.into_iter().map(|s| (s.name, s.last_modified)).collect()) } /// Lists sessions with lazy loading - only loads full data if detailed=true pub fn list_sessions_lazy(detailed: bool) -> Result> { - let config = Config::load().unwrap_or_default(); + let config = get_config(); let sessions_dir = Self::sessions_dir()?; if !sessions_dir.exists() { @@ -351,7 +361,7 @@ impl Session { /// Checks if session needs cleanup based on size and age pub fn needs_cleanup(&self) -> bool { let stats = self.get_stats(); - let config = Config::load().unwrap_or_default(); + let config = get_config(); // Check if conversation is too long if stats.total_messages > config.limits.max_conversation_history * 2 { @@ -368,7 +378,7 @@ impl Session { /// Performs aggressive cleanup for memory optimization pub fn cleanup_for_memory(&mut self) { - let config = Config::load().unwrap_or_default(); + let config = get_config(); let target_messages = config.limits.max_conversation_history / 2; if self.messages.len() > target_messages + 1 { diff --git a/src/main.rs b/src/main.rs index 20cd447..99c1540 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,9 +5,7 @@ mod utils; use anyhow::{Context, Result}; use clap::Parser; -use signal_hook::{consts::SIGINT, iterator::Signals}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use tokio::signal::unix::{signal, SignalKind}; use crate::cli::ChatCLI; use crate::config::Config; @@ -19,7 +17,11 @@ use crate::utils::Display; #[command(about = "A lightweight command-line interface for chatting with AI models")] #[command(version)] struct Args { - #[arg(short, long, help = "Session name (defaults to configured default session)")] + #[arg( + short, + long, + help = "Session name (defaults to configured default session)" + )] session: Option, #[arg(short, long, help = "Model name to use (overrides saved value)")] @@ -34,16 +36,8 @@ async fn main() -> Result<()> { let args = Args::parse(); let display = Display::new(); - // Set up signal handling for proper cleanup - let term = Arc::new(AtomicBool::new(false)); - let term_clone = term.clone(); - - std::thread::spawn(move || { - let mut signals = Signals::new(&[SIGINT]).unwrap(); - for _ in signals.forever() { - term_clone.store(true, Ordering::Relaxed); - } - }); + // Set up signal handling using Tokio's signal support + let mut sigint = signal(SignalKind::interrupt())?; // Handle config creation if args.create_config { @@ -53,14 +47,16 @@ async fn main() -> Result<()> { // Load configuration let config = Config::load().context("Failed to load configuration")?; - + // Validate environment variables let env_vars = Config::validate_env_variables().context("Environment validation failed")?; // Load or create session // Use configured default session if none specified - let session_name = args.session.unwrap_or_else(|| config.defaults.default_session.clone()); - + let session_name = args + .session + .unwrap_or_else(|| config.defaults.default_session.clone()); + let session = match Session::load(&session_name) { Ok(mut session) => { if let Some(model) = args.model { @@ -99,9 +95,22 @@ async fn main() -> Result<()> { // Print configuration info config.print_config_info(); - // Run the CLI + // Run the CLI and handle SIGINT gracefully let mut cli = ChatCLI::new(session, config).context("Failed to initialize CLI")?; - cli.run().await.context("CLI error")?; + + let mut cli_run = Box::pin(cli.run()); + + tokio::select! { + res = &mut cli_run => { + res.context("CLI error")?; + } + _ = sigint.recv() => { + // Drop the CLI future before using `cli` again + drop(cli_run); + display.print_info("Received interrupt signal. Exiting..."); + cli.save_and_cleanup()?; + } + } Ok(()) }