From b2c1857edd7ae4075d71de6b41c621a7af8c68ab Mon Sep 17 00:00:00 2001 From: Christopher Date: Mon, 25 Aug 2025 00:52:47 -0400 Subject: [PATCH] refactor: handle SIGINT with tokio signal --- Cargo.lock | 11 --- Cargo.toml | 1 - src/cli.rs | 265 +++++++++++++++++++++++++++++++++------------------- src/main.rs | 47 ++++++---- 4 files changed, 199 insertions(+), 125 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 92eeb7d..4e1976a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -596,7 +596,6 @@ dependencies = [ "serde", "serde_json", "serial_test", - "signal-hook", "syntect", "tempfile", "tokio", @@ -1585,16 +1584,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" -dependencies = [ - "libc", - "signal-hook-registry", -] - [[package]] name = "signal-hook-registry" version = "1.4.6" diff --git a/Cargo.toml b/Cargo.toml index 6285965..c9217f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ syntect = "5.1" regex = "1.0" futures = "0.3" tokio-stream = "0.1" -signal-hook = "0.3" [dev-dependencies] tempfile = "3.0" diff --git a/src/cli.rs b/src/cli.rs index 1b8cbc7..638fcf8 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -2,7 +2,8 @@ use anyhow::Result; use crate::config::Config; use crate::core::{ - create_client, get_provider_for_model, provider::{get_model_info_list, get_display_name_for_model, get_model_id_from_display_name}, + create_client, get_provider_for_model, + provider::{get_display_name_for_model, get_model_id_from_display_name, get_model_info_list}, ChatClient, Session, }; use crate::utils::{Display, InputHandler, SessionAction}; @@ -41,14 +42,16 @@ impl ChatCLI { pub async fn run(&mut self) -> Result<()> { self.display.print_header(); - self.display.print_info("Type your message and press Enter. Commands start with '/'."); + self.display + .print_info("Type your message and press Enter. Commands start with '/'."); self.display.print_info("Type /help for help."); - + let provider = get_provider_for_model(&self.session.model); let display_name = get_display_name_for_model(&self.session.model); - self.display.print_model_info(&display_name, provider.as_str()); + self.display + .print_model_info(&display_name, provider.as_str()); self.display.print_session_info(&self.session.name); - + println!(); loop { @@ -77,6 +80,10 @@ impl ChatCLI { } } + self.save_and_cleanup() + } + + pub fn save_and_cleanup(&mut self) -> Result<()> { self.session.save()?; self.input.cleanup()?; // Use cleanup instead of just save_history Ok(()) @@ -94,19 +101,19 @@ impl ChatCLI { let reasoning_effort = self.session.reasoning_effort.clone(); let enable_extended_thinking = self.session.enable_extended_thinking; let thinking_budget_tokens = self.session.thinking_budget_tokens; - + // Check if we should use streaming before getting client let should_use_streaming = { let client = self.get_client()?; client.supports_streaming() }; - + if should_use_streaming { println!(); // Add padding before AI response print!("{}> ", console::style("🤖").magenta()); use std::io::{self, Write}; io::stdout().flush().ok(); - + let stream_callback = { use crate::core::StreamCallback; Box::new(move |chunk: &str| { @@ -116,7 +123,7 @@ impl ChatCLI { Box::pin(async move {}) as Pin + Send>> }) as StreamCallback }; - + let client = self.get_client()?; match client .chat_completion_stream( @@ -138,7 +145,8 @@ impl ChatCLI { } Err(e) => { println!(); // Add newline after failed streaming - self.display.print_error(&format!("Streaming failed: {}", e)); + self.display + .print_error(&format!("Streaming failed: {}", e)); return Err(e); } } @@ -146,7 +154,7 @@ impl ChatCLI { // Fallback to non-streaming let spinner = self.display.show_spinner("Thinking"); let client = self.get_client()?; - + match client .chat_completion( &model, @@ -218,7 +226,8 @@ impl ChatCLI { self.handle_save_command(&parts)?; } _ => { - self.display.print_error(&format!("Unknown command: {} (see /help)", parts[0])); + self.display + .print_error(&format!("Unknown command: {} (see /help)", parts[0])); } } @@ -227,15 +236,18 @@ impl ChatCLI { async fn model_switcher(&mut self) -> Result<()> { let model_info_list = get_model_info_list(); - let display_names: Vec = model_info_list.iter().map(|info| info.display_name.to_string()).collect(); + let display_names: Vec = model_info_list + .iter() + .map(|info| info.display_name.to_string()) + .collect(); let current_display_name = get_display_name_for_model(&self.session.model); - + let selection = self.input.select_from_list( "Select a model:", &display_names, Some(¤t_display_name), )?; - + match selection { Some(display_name) => { if let Some(model_id) = get_model_id_from_display_name(&display_name) { @@ -260,11 +272,10 @@ impl ChatCLI { self.display.print_info("Model selection cancelled"); } } - + Ok(()) } - fn handle_new_session(&mut self, parts: &[&str]) -> Result<()> { if parts.len() != 2 { self.display.print_error("Usage: /new "); @@ -274,18 +285,16 @@ impl ChatCLI { self.session.save()?; let new_session = Session::new(parts[1].to_string(), self.session.model.clone()); self.session = new_session; - self.display.print_command_result(&format!("New session '{}' started", self.session.name)); - + self.display + .print_command_result(&format!("New session '{}' started", self.session.name)); + Ok(()) } async fn session_manager(&mut self) -> Result<()> { loop { let sessions = Session::list_sessions()?; - let session_names: Vec = sessions - .into_iter() - .map(|(name, _)| name) - .collect(); + let session_names: Vec = sessions.into_iter().map(|(name, _)| name).collect(); if session_names.is_empty() { self.display.print_info("No sessions available"); @@ -304,7 +313,7 @@ impl ChatCLI { self.display.print_info("Already in that session"); return Ok(()); } - + self.session.save()?; match Session::load(&session_name) { Ok(session) => { @@ -318,7 +327,8 @@ impl ChatCLI { return Ok(()); } Err(e) => { - self.display.print_error(&format!("Failed to load session: {}", e)); + self.display + .print_error(&format!("Failed to load session: {}", e)); // Don't return, allow user to try again or cancel } } @@ -326,8 +336,11 @@ impl ChatCLI { SessionAction::Delete(session_name) => { match Session::delete_session(&session_name) { Ok(()) => { - self.display.print_command_result(&format!("Session '{}' deleted", session_name)); - + self.display.print_command_result(&format!( + "Session '{}' deleted", + session_name + )); + // If we deleted the current session, we need to handle this specially if session_name == self.session.name { // Try to switch to another session or create a default one @@ -336,18 +349,23 @@ impl ChatCLI { .into_iter() .map(|(name, _)| name) .collect(); - + if remaining_names.is_empty() { // No sessions left, create a default one - self.session = Session::new("default".to_string(), self.session.model.clone()); - self.display.print_command_result("Created new default session"); + self.session = Session::new( + "default".to_string(), + self.session.model.clone(), + ); + self.display + .print_command_result("Created new default session"); return Ok(()); } else { // Switch to the first available session match Session::load(&remaining_names[0]) { Ok(session) => { self.session = session; - let display_name = get_display_name_for_model(&self.session.model); + let display_name = + get_display_name_for_model(&self.session.model); self.display.print_command_result(&format!( "Switched to session '{}' (model={})", self.session.name, display_name @@ -356,10 +374,18 @@ impl ChatCLI { return Ok(()); } Err(e) => { - self.display.print_error(&format!("Failed to load fallback session: {}", e)); + self.display.print_error(&format!( + "Failed to load fallback session: {}", + e + )); // Create a new default session as fallback - self.session = Session::new("default".to_string(), self.session.model.clone()); - self.display.print_command_result("Created new default session"); + self.session = Session::new( + "default".to_string(), + self.session.model.clone(), + ); + self.display.print_command_result( + "Created new default session", + ); return Ok(()); } } @@ -368,7 +394,8 @@ impl ChatCLI { // Continue to show updated session list if we didn't delete current session } Err(e) => { - self.display.print_error(&format!("Failed to delete session: {}", e)); + self.display + .print_error(&format!("Failed to delete session: {}", e)); // Continue to allow retry } } @@ -383,7 +410,8 @@ impl ChatCLI { )); } Err(e) => { - self.display.print_error(&format!("Failed to set default session: {}", e)); + self.display + .print_error(&format!("Failed to set default session: {}", e)); } } // Continue to show session list @@ -395,32 +423,48 @@ impl ChatCLI { } } - async fn tools_manager(&mut self) -> Result<()> { loop { // Show current tool status self.display.print_info("Tool Management:"); - - let web_status = if self.session.enable_web_search { "✓ enabled" } else { "✗ disabled" }; - let reasoning_status = if self.session.enable_reasoning_summary { "✓ enabled" } else { "✗ disabled" }; - let extended_thinking_status = if self.session.enable_extended_thinking { "✓ enabled" } else { "✗ disabled" }; - + + let web_status = if self.session.enable_web_search { + "✓ enabled" + } else { + "✗ disabled" + }; + let reasoning_status = if self.session.enable_reasoning_summary { + "✓ enabled" + } else { + "✗ disabled" + }; + let extended_thinking_status = if self.session.enable_extended_thinking { + "✓ enabled" + } else { + "✗ disabled" + }; + println!(" Web Search: {}", web_status); println!(" Reasoning Summaries: {}", reasoning_status); println!(" Reasoning Effort: {}", self.session.reasoning_effort); println!(" Extended Thinking: {}", extended_thinking_status); - println!(" Thinking Budget: {} tokens", self.session.thinking_budget_tokens); - + println!( + " Thinking Budget: {} tokens", + self.session.thinking_budget_tokens + ); + // Check model compatibility let model = self.session.model.clone(); let provider = get_provider_for_model(&model); let reasoning_enabled = self.session.enable_reasoning_summary; - + // Show compatibility warnings based on provider match provider { crate::core::provider::Provider::Anthropic => { if reasoning_enabled { - self.display.print_warning("Reasoning summaries are not supported by Anthropic models"); + self.display.print_warning( + "Reasoning summaries are not supported by Anthropic models", + ); } if self.session.enable_extended_thinking { // Extended thinking is supported by Anthropic models @@ -429,38 +473,47 @@ impl ChatCLI { } crate::core::provider::Provider::OpenAI => { if self.session.enable_extended_thinking { - self.display.print_warning("Extended thinking is not supported by OpenAI models"); + self.display + .print_warning("Extended thinking is not supported by OpenAI models"); } // OpenAI models generally support other features } } - + // Tool management options let options = vec![ "Toggle Web Search", - "Toggle Reasoning Summaries", + "Toggle Reasoning Summaries", "Set Reasoning Effort", "Toggle Extended Thinking", "Set Thinking Budget", - "Done" + "Done", ]; - - let selection = self.input.select_from_list( - "Select an option:", - &options, - None, - )?; - + + let selection = self + .input + .select_from_list("Select an option:", &options, None)?; + match selection.as_deref() { Some("Toggle Web Search") => { self.session.enable_web_search = !self.session.enable_web_search; - let state = if self.session.enable_web_search { "enabled" } else { "disabled" }; - self.display.print_command_result(&format!("Web search {}", state)); + let state = if self.session.enable_web_search { + "enabled" + } else { + "disabled" + }; + self.display + .print_command_result(&format!("Web search {}", state)); } Some("Toggle Reasoning Summaries") => { self.session.enable_reasoning_summary = !self.session.enable_reasoning_summary; - let state = if self.session.enable_reasoning_summary { "enabled" } else { "disabled" }; - self.display.print_command_result(&format!("Reasoning summaries {}", state)); + let state = if self.session.enable_reasoning_summary { + "enabled" + } else { + "disabled" + }; + self.display + .print_command_result(&format!("Reasoning summaries {}", state)); } Some("Set Reasoning Effort") => { let effort_options = vec!["low", "medium", "high"]; @@ -470,22 +523,32 @@ impl ChatCLI { Some(&self.session.reasoning_effort), )? { self.session.reasoning_effort = effort.to_string(); - self.display.print_command_result(&format!("Reasoning effort set to {}", effort)); - + self.display + .print_command_result(&format!("Reasoning effort set to {}", effort)); + if !self.session.model.starts_with("gpt-5") { - self.display.print_warning("Reasoning effort is only supported by GPT-5 models"); + self.display.print_warning( + "Reasoning effort is only supported by GPT-5 models", + ); } } } Some("Toggle Extended Thinking") => { self.session.enable_extended_thinking = !self.session.enable_extended_thinking; - let state = if self.session.enable_extended_thinking { "enabled" } else { "disabled" }; - self.display.print_command_result(&format!("Extended thinking {}", state)); - + let state = if self.session.enable_extended_thinking { + "enabled" + } else { + "disabled" + }; + self.display + .print_command_result(&format!("Extended thinking {}", state)); + let provider = get_provider_for_model(&self.session.model); match provider { crate::core::provider::Provider::OpenAI => { - self.display.print_warning("Extended thinking is not supported by OpenAI models"); + self.display.print_warning( + "Extended thinking is not supported by OpenAI models", + ); } crate::core::provider::Provider::Anthropic => { // Supported @@ -502,12 +565,17 @@ impl ChatCLI { )? { if let Ok(budget) = budget_str.parse::() { self.session.thinking_budget_tokens = budget; - self.display.print_command_result(&format!("Thinking budget set to {} tokens", budget)); - + self.display.print_command_result(&format!( + "Thinking budget set to {} tokens", + budget + )); + let provider = get_provider_for_model(&self.session.model); match provider { crate::core::provider::Provider::OpenAI => { - self.display.print_warning("Extended thinking is not supported by OpenAI models"); + self.display.print_warning( + "Extended thinking is not supported by OpenAI models", + ); } crate::core::provider::Provider::Anthropic => { if budget < 1024 { @@ -523,18 +591,18 @@ impl ChatCLI { } _ => {} } - + self.session.save()?; // Save changes after each modification println!(); // Add spacing } - + Ok(()) } fn handle_history_command(&mut self, parts: &[&str]) -> Result<()> { let mut filter_role: Option<&str> = None; let mut limit: Option = None; - + // Parse parameters for &part in parts.iter().skip(1) { match part { @@ -543,37 +611,41 @@ impl ChatCLI { if let Ok(num) = part.parse::() { limit = Some(num); } else { - self.display.print_error(&format!("Invalid parameter: {}", part)); - self.display.print_info("Usage: /history [user|assistant] [number]"); + self.display + .print_error(&format!("Invalid parameter: {}", part)); + self.display + .print_info("Usage: /history [user|assistant] [number]"); return Ok(()); } } } } - + // Filter messages (skip system prompt at index 0) - let mut messages: Vec<(usize, &crate::core::Message)> = self.session.messages + let mut messages: Vec<(usize, &crate::core::Message)> = self + .session + .messages .iter() .enumerate() .skip(1) // Skip system prompt .collect(); - + // Apply role filter if let Some(role) = filter_role { messages.retain(|(_, msg)| msg.role == role); } - + // Apply limit if let Some(limit_count) = limit { let start_index = messages.len().saturating_sub(limit_count); messages = messages[start_index..].to_vec(); } - + if messages.is_empty() { self.display.print_info("No messages to display"); return Ok(()); } - + // Format and display self.display.print_conversation_history(&messages); Ok(()) @@ -590,7 +662,7 @@ impl ChatCLI { let valid_formats = ["markdown", "md", "json", "txt"]; if !valid_formats.contains(&format.as_str()) { self.display.print_error(&format!( - "Invalid format '{}'. Supported formats: markdown, json, txt", + "Invalid format '{}'. Supported formats: markdown, json, txt", format )); self.display.print_info("Usage: /export [format]"); @@ -602,19 +674,20 @@ impl ChatCLI { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(); - + let extension = match format.as_str() { "markdown" | "md" => "md", "json" => "json", "txt" => "txt", _ => "md", }; - + // Create exports directory if it doesn't exist let exports_dir = std::path::Path::new("exports"); if !exports_dir.exists() { if let Err(e) = std::fs::create_dir(exports_dir) { - self.display.print_error(&format!("Failed to create exports directory: {}", e)); + self.display + .print_error(&format!("Failed to create exports directory: {}", e)); return Ok(()); } } @@ -622,10 +695,13 @@ impl ChatCLI { let filename = format!("{}_{}.{}", self.session.name, now, extension); let file_path = exports_dir.join(&filename); - match self.session.export(&format, file_path.to_str().unwrap_or(&filename)) { + match self + .session + .export(&format, file_path.to_str().unwrap_or(&filename)) + { Ok(()) => { self.display.print_command_result(&format!( - "Conversation exported to '{}'", + "Conversation exported to '{}'", file_path.display() )); } @@ -640,7 +716,8 @@ impl ChatCLI { fn handle_save_command(&mut self, parts: &[&str]) -> Result<()> { if parts.len() != 2 { self.display.print_error("Usage: /save "); - self.display.print_info("Example: /save my_important_conversation"); + self.display + .print_info("Example: /save my_important_conversation"); return Ok(()); } @@ -656,7 +733,7 @@ impl ChatCLI { if let Ok(sessions) = Session::list_sessions() { if sessions.iter().any(|(name, _)| name == new_session_name) { self.display.print_error(&format!( - "Session '{}' already exists. Choose a different name.", + "Session '{}' already exists. Choose a different name.", new_session_name )); return Ok(()); @@ -670,7 +747,7 @@ impl ChatCLI { match self.session.save_as(new_session_name) { Ok(()) => { self.display.print_command_result(&format!( - "Current session saved as '{}' ({} messages copied)", + "Current session saved as '{}' ({} messages copied)", new_session_name, self.session.messages.len().saturating_sub(1) // Exclude system prompt )); @@ -680,11 +757,11 @@ impl ChatCLI { )); } Err(e) => { - self.display.print_error(&format!("Failed to save session: {}", e)); + self.display + .print_error(&format!("Failed to save session: {}", e)); } } Ok(()) } - -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 20cd447..99c1540 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,9 +5,7 @@ mod utils; use anyhow::{Context, Result}; use clap::Parser; -use signal_hook::{consts::SIGINT, iterator::Signals}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use tokio::signal::unix::{signal, SignalKind}; use crate::cli::ChatCLI; use crate::config::Config; @@ -19,7 +17,11 @@ use crate::utils::Display; #[command(about = "A lightweight command-line interface for chatting with AI models")] #[command(version)] struct Args { - #[arg(short, long, help = "Session name (defaults to configured default session)")] + #[arg( + short, + long, + help = "Session name (defaults to configured default session)" + )] session: Option, #[arg(short, long, help = "Model name to use (overrides saved value)")] @@ -34,16 +36,8 @@ async fn main() -> Result<()> { let args = Args::parse(); let display = Display::new(); - // Set up signal handling for proper cleanup - let term = Arc::new(AtomicBool::new(false)); - let term_clone = term.clone(); - - std::thread::spawn(move || { - let mut signals = Signals::new(&[SIGINT]).unwrap(); - for _ in signals.forever() { - term_clone.store(true, Ordering::Relaxed); - } - }); + // Set up signal handling using Tokio's signal support + let mut sigint = signal(SignalKind::interrupt())?; // Handle config creation if args.create_config { @@ -53,14 +47,16 @@ async fn main() -> Result<()> { // Load configuration let config = Config::load().context("Failed to load configuration")?; - + // Validate environment variables let env_vars = Config::validate_env_variables().context("Environment validation failed")?; // Load or create session // Use configured default session if none specified - let session_name = args.session.unwrap_or_else(|| config.defaults.default_session.clone()); - + let session_name = args + .session + .unwrap_or_else(|| config.defaults.default_session.clone()); + let session = match Session::load(&session_name) { Ok(mut session) => { if let Some(model) = args.model { @@ -99,9 +95,22 @@ async fn main() -> Result<()> { // Print configuration info config.print_config_info(); - // Run the CLI + // Run the CLI and handle SIGINT gracefully let mut cli = ChatCLI::new(session, config).context("Failed to initialize CLI")?; - cli.run().await.context("CLI error")?; + + let mut cli_run = Box::pin(cli.run()); + + tokio::select! { + res = &mut cli_run => { + res.context("CLI error")?; + } + _ = sigint.recv() => { + // Drop the CLI future before using `cli` again + drop(cli_run); + display.print_info("Received interrupt signal. Exiting..."); + cli.save_and_cleanup()?; + } + } Ok(()) }