Merge pull request #1 from leachy14/codex/refactor-signal-handling-in-main.rs

refactor: handle SIGINT with tokio signal
This commit is contained in:
Christopher 2025-08-25 00:53:01 -04:00 committed by GitHub
commit a5e225bf60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 199 additions and 125 deletions

11
Cargo.lock generated
View File

@ -596,7 +596,6 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"serial_test", "serial_test",
"signal-hook",
"syntect", "syntect",
"tempfile", "tempfile",
"tokio", "tokio",
@ -1585,16 +1584,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.6" version = "1.4.6"

View File

@ -23,7 +23,6 @@ syntect = "5.1"
regex = "1.0" regex = "1.0"
futures = "0.3" futures = "0.3"
tokio-stream = "0.1" tokio-stream = "0.1"
signal-hook = "0.3"
[dev-dependencies] [dev-dependencies]
tempfile = "3.0" tempfile = "3.0"

View File

@ -2,7 +2,8 @@ use anyhow::Result;
use crate::config::Config; use crate::config::Config;
use crate::core::{ use crate::core::{
create_client, get_provider_for_model, provider::{get_model_info_list, get_display_name_for_model, get_model_id_from_display_name}, create_client, get_provider_for_model,
provider::{get_display_name_for_model, get_model_id_from_display_name, get_model_info_list},
ChatClient, Session, ChatClient, Session,
}; };
use crate::utils::{Display, InputHandler, SessionAction}; use crate::utils::{Display, InputHandler, SessionAction};
@ -41,14 +42,16 @@ impl ChatCLI {
pub async fn run(&mut self) -> Result<()> { pub async fn run(&mut self) -> Result<()> {
self.display.print_header(); self.display.print_header();
self.display.print_info("Type your message and press Enter. Commands start with '/'."); self.display
.print_info("Type your message and press Enter. Commands start with '/'.");
self.display.print_info("Type /help for help."); self.display.print_info("Type /help for help.");
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
let display_name = get_display_name_for_model(&self.session.model); let display_name = get_display_name_for_model(&self.session.model);
self.display.print_model_info(&display_name, provider.as_str()); self.display
.print_model_info(&display_name, provider.as_str());
self.display.print_session_info(&self.session.name); self.display.print_session_info(&self.session.name);
println!(); println!();
loop { loop {
@ -77,6 +80,10 @@ impl ChatCLI {
} }
} }
self.save_and_cleanup()
}
pub fn save_and_cleanup(&mut self) -> Result<()> {
self.session.save()?; self.session.save()?;
self.input.cleanup()?; // Use cleanup instead of just save_history self.input.cleanup()?; // Use cleanup instead of just save_history
Ok(()) Ok(())
@ -94,19 +101,19 @@ impl ChatCLI {
let reasoning_effort = self.session.reasoning_effort.clone(); let reasoning_effort = self.session.reasoning_effort.clone();
let enable_extended_thinking = self.session.enable_extended_thinking; let enable_extended_thinking = self.session.enable_extended_thinking;
let thinking_budget_tokens = self.session.thinking_budget_tokens; let thinking_budget_tokens = self.session.thinking_budget_tokens;
// Check if we should use streaming before getting client // Check if we should use streaming before getting client
let should_use_streaming = { let should_use_streaming = {
let client = self.get_client()?; let client = self.get_client()?;
client.supports_streaming() client.supports_streaming()
}; };
if should_use_streaming { if should_use_streaming {
println!(); // Add padding before AI response println!(); // Add padding before AI response
print!("{}> ", console::style("🤖").magenta()); print!("{}> ", console::style("🤖").magenta());
use std::io::{self, Write}; use std::io::{self, Write};
io::stdout().flush().ok(); io::stdout().flush().ok();
let stream_callback = { let stream_callback = {
use crate::core::StreamCallback; use crate::core::StreamCallback;
Box::new(move |chunk: &str| { Box::new(move |chunk: &str| {
@ -116,7 +123,7 @@ impl ChatCLI {
Box::pin(async move {}) as Pin<Box<dyn Future<Output = ()> + Send>> Box::pin(async move {}) as Pin<Box<dyn Future<Output = ()> + Send>>
}) as StreamCallback }) as StreamCallback
}; };
let client = self.get_client()?; let client = self.get_client()?;
match client match client
.chat_completion_stream( .chat_completion_stream(
@ -138,7 +145,8 @@ impl ChatCLI {
} }
Err(e) => { Err(e) => {
println!(); // Add newline after failed streaming println!(); // Add newline after failed streaming
self.display.print_error(&format!("Streaming failed: {}", e)); self.display
.print_error(&format!("Streaming failed: {}", e));
return Err(e); return Err(e);
} }
} }
@ -146,7 +154,7 @@ impl ChatCLI {
// Fallback to non-streaming // Fallback to non-streaming
let spinner = self.display.show_spinner("Thinking"); let spinner = self.display.show_spinner("Thinking");
let client = self.get_client()?; let client = self.get_client()?;
match client match client
.chat_completion( .chat_completion(
&model, &model,
@ -218,7 +226,8 @@ impl ChatCLI {
self.handle_save_command(&parts)?; self.handle_save_command(&parts)?;
} }
_ => { _ => {
self.display.print_error(&format!("Unknown command: {} (see /help)", parts[0])); self.display
.print_error(&format!("Unknown command: {} (see /help)", parts[0]));
} }
} }
@ -227,15 +236,18 @@ impl ChatCLI {
async fn model_switcher(&mut self) -> Result<()> { async fn model_switcher(&mut self) -> Result<()> {
let model_info_list = get_model_info_list(); let model_info_list = get_model_info_list();
let display_names: Vec<String> = model_info_list.iter().map(|info| info.display_name.to_string()).collect(); let display_names: Vec<String> = model_info_list
.iter()
.map(|info| info.display_name.to_string())
.collect();
let current_display_name = get_display_name_for_model(&self.session.model); let current_display_name = get_display_name_for_model(&self.session.model);
let selection = self.input.select_from_list( let selection = self.input.select_from_list(
"Select a model:", "Select a model:",
&display_names, &display_names,
Some(&current_display_name), Some(&current_display_name),
)?; )?;
match selection { match selection {
Some(display_name) => { Some(display_name) => {
if let Some(model_id) = get_model_id_from_display_name(&display_name) { if let Some(model_id) = get_model_id_from_display_name(&display_name) {
@ -260,11 +272,10 @@ impl ChatCLI {
self.display.print_info("Model selection cancelled"); self.display.print_info("Model selection cancelled");
} }
} }
Ok(()) Ok(())
} }
fn handle_new_session(&mut self, parts: &[&str]) -> Result<()> { fn handle_new_session(&mut self, parts: &[&str]) -> Result<()> {
if parts.len() != 2 { if parts.len() != 2 {
self.display.print_error("Usage: /new <session_name>"); self.display.print_error("Usage: /new <session_name>");
@ -274,18 +285,16 @@ impl ChatCLI {
self.session.save()?; self.session.save()?;
let new_session = Session::new(parts[1].to_string(), self.session.model.clone()); let new_session = Session::new(parts[1].to_string(), self.session.model.clone());
self.session = new_session; self.session = new_session;
self.display.print_command_result(&format!("New session '{}' started", self.session.name)); self.display
.print_command_result(&format!("New session '{}' started", self.session.name));
Ok(()) Ok(())
} }
async fn session_manager(&mut self) -> Result<()> { async fn session_manager(&mut self) -> Result<()> {
loop { loop {
let sessions = Session::list_sessions()?; let sessions = Session::list_sessions()?;
let session_names: Vec<String> = sessions let session_names: Vec<String> = sessions.into_iter().map(|(name, _)| name).collect();
.into_iter()
.map(|(name, _)| name)
.collect();
if session_names.is_empty() { if session_names.is_empty() {
self.display.print_info("No sessions available"); self.display.print_info("No sessions available");
@ -304,7 +313,7 @@ impl ChatCLI {
self.display.print_info("Already in that session"); self.display.print_info("Already in that session");
return Ok(()); return Ok(());
} }
self.session.save()?; self.session.save()?;
match Session::load(&session_name) { match Session::load(&session_name) {
Ok(session) => { Ok(session) => {
@ -318,7 +327,8 @@ impl ChatCLI {
return Ok(()); return Ok(());
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to load session: {}", e)); self.display
.print_error(&format!("Failed to load session: {}", e));
// Don't return, allow user to try again or cancel // Don't return, allow user to try again or cancel
} }
} }
@ -326,8 +336,11 @@ impl ChatCLI {
SessionAction::Delete(session_name) => { SessionAction::Delete(session_name) => {
match Session::delete_session(&session_name) { match Session::delete_session(&session_name) {
Ok(()) => { Ok(()) => {
self.display.print_command_result(&format!("Session '{}' deleted", session_name)); self.display.print_command_result(&format!(
"Session '{}' deleted",
session_name
));
// If we deleted the current session, we need to handle this specially // If we deleted the current session, we need to handle this specially
if session_name == self.session.name { if session_name == self.session.name {
// Try to switch to another session or create a default one // Try to switch to another session or create a default one
@ -336,18 +349,23 @@ impl ChatCLI {
.into_iter() .into_iter()
.map(|(name, _)| name) .map(|(name, _)| name)
.collect(); .collect();
if remaining_names.is_empty() { if remaining_names.is_empty() {
// No sessions left, create a default one // No sessions left, create a default one
self.session = Session::new("default".to_string(), self.session.model.clone()); self.session = Session::new(
self.display.print_command_result("Created new default session"); "default".to_string(),
self.session.model.clone(),
);
self.display
.print_command_result("Created new default session");
return Ok(()); return Ok(());
} else { } else {
// Switch to the first available session // Switch to the first available session
match Session::load(&remaining_names[0]) { match Session::load(&remaining_names[0]) {
Ok(session) => { Ok(session) => {
self.session = session; self.session = session;
let display_name = get_display_name_for_model(&self.session.model); let display_name =
get_display_name_for_model(&self.session.model);
self.display.print_command_result(&format!( self.display.print_command_result(&format!(
"Switched to session '{}' (model={})", "Switched to session '{}' (model={})",
self.session.name, display_name self.session.name, display_name
@ -356,10 +374,18 @@ impl ChatCLI {
return Ok(()); return Ok(());
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to load fallback session: {}", e)); self.display.print_error(&format!(
"Failed to load fallback session: {}",
e
));
// Create a new default session as fallback // Create a new default session as fallback
self.session = Session::new("default".to_string(), self.session.model.clone()); self.session = Session::new(
self.display.print_command_result("Created new default session"); "default".to_string(),
self.session.model.clone(),
);
self.display.print_command_result(
"Created new default session",
);
return Ok(()); return Ok(());
} }
} }
@ -368,7 +394,8 @@ impl ChatCLI {
// Continue to show updated session list if we didn't delete current session // Continue to show updated session list if we didn't delete current session
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to delete session: {}", e)); self.display
.print_error(&format!("Failed to delete session: {}", e));
// Continue to allow retry // Continue to allow retry
} }
} }
@ -383,7 +410,8 @@ impl ChatCLI {
)); ));
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to set default session: {}", e)); self.display
.print_error(&format!("Failed to set default session: {}", e));
} }
} }
// Continue to show session list // Continue to show session list
@ -395,32 +423,48 @@ impl ChatCLI {
} }
} }
async fn tools_manager(&mut self) -> Result<()> { async fn tools_manager(&mut self) -> Result<()> {
loop { loop {
// Show current tool status // Show current tool status
self.display.print_info("Tool Management:"); self.display.print_info("Tool Management:");
let web_status = if self.session.enable_web_search { "✓ enabled" } else { "✗ disabled" }; let web_status = if self.session.enable_web_search {
let reasoning_status = if self.session.enable_reasoning_summary { "✓ enabled" } else { "✗ disabled" }; "✓ enabled"
let extended_thinking_status = if self.session.enable_extended_thinking { "✓ enabled" } else { "✗ disabled" }; } else {
"✗ disabled"
};
let reasoning_status = if self.session.enable_reasoning_summary {
"✓ enabled"
} else {
"✗ disabled"
};
let extended_thinking_status = if self.session.enable_extended_thinking {
"✓ enabled"
} else {
"✗ disabled"
};
println!(" Web Search: {}", web_status); println!(" Web Search: {}", web_status);
println!(" Reasoning Summaries: {}", reasoning_status); println!(" Reasoning Summaries: {}", reasoning_status);
println!(" Reasoning Effort: {}", self.session.reasoning_effort); println!(" Reasoning Effort: {}", self.session.reasoning_effort);
println!(" Extended Thinking: {}", extended_thinking_status); println!(" Extended Thinking: {}", extended_thinking_status);
println!(" Thinking Budget: {} tokens", self.session.thinking_budget_tokens); println!(
" Thinking Budget: {} tokens",
self.session.thinking_budget_tokens
);
// Check model compatibility // Check model compatibility
let model = self.session.model.clone(); let model = self.session.model.clone();
let provider = get_provider_for_model(&model); let provider = get_provider_for_model(&model);
let reasoning_enabled = self.session.enable_reasoning_summary; let reasoning_enabled = self.session.enable_reasoning_summary;
// Show compatibility warnings based on provider // Show compatibility warnings based on provider
match provider { match provider {
crate::core::provider::Provider::Anthropic => { crate::core::provider::Provider::Anthropic => {
if reasoning_enabled { if reasoning_enabled {
self.display.print_warning("Reasoning summaries are not supported by Anthropic models"); self.display.print_warning(
"Reasoning summaries are not supported by Anthropic models",
);
} }
if self.session.enable_extended_thinking { if self.session.enable_extended_thinking {
// Extended thinking is supported by Anthropic models // Extended thinking is supported by Anthropic models
@ -429,38 +473,47 @@ impl ChatCLI {
} }
crate::core::provider::Provider::OpenAI => { crate::core::provider::Provider::OpenAI => {
if self.session.enable_extended_thinking { if self.session.enable_extended_thinking {
self.display.print_warning("Extended thinking is not supported by OpenAI models"); self.display
.print_warning("Extended thinking is not supported by OpenAI models");
} }
// OpenAI models generally support other features // OpenAI models generally support other features
} }
} }
// Tool management options // Tool management options
let options = vec![ let options = vec![
"Toggle Web Search", "Toggle Web Search",
"Toggle Reasoning Summaries", "Toggle Reasoning Summaries",
"Set Reasoning Effort", "Set Reasoning Effort",
"Toggle Extended Thinking", "Toggle Extended Thinking",
"Set Thinking Budget", "Set Thinking Budget",
"Done" "Done",
]; ];
let selection = self.input.select_from_list( let selection = self
"Select an option:", .input
&options, .select_from_list("Select an option:", &options, None)?;
None,
)?;
match selection.as_deref() { match selection.as_deref() {
Some("Toggle Web Search") => { Some("Toggle Web Search") => {
self.session.enable_web_search = !self.session.enable_web_search; self.session.enable_web_search = !self.session.enable_web_search;
let state = if self.session.enable_web_search { "enabled" } else { "disabled" }; let state = if self.session.enable_web_search {
self.display.print_command_result(&format!("Web search {}", state)); "enabled"
} else {
"disabled"
};
self.display
.print_command_result(&format!("Web search {}", state));
} }
Some("Toggle Reasoning Summaries") => { Some("Toggle Reasoning Summaries") => {
self.session.enable_reasoning_summary = !self.session.enable_reasoning_summary; self.session.enable_reasoning_summary = !self.session.enable_reasoning_summary;
let state = if self.session.enable_reasoning_summary { "enabled" } else { "disabled" }; let state = if self.session.enable_reasoning_summary {
self.display.print_command_result(&format!("Reasoning summaries {}", state)); "enabled"
} else {
"disabled"
};
self.display
.print_command_result(&format!("Reasoning summaries {}", state));
} }
Some("Set Reasoning Effort") => { Some("Set Reasoning Effort") => {
let effort_options = vec!["low", "medium", "high"]; let effort_options = vec!["low", "medium", "high"];
@ -470,22 +523,32 @@ impl ChatCLI {
Some(&self.session.reasoning_effort), Some(&self.session.reasoning_effort),
)? { )? {
self.session.reasoning_effort = effort.to_string(); self.session.reasoning_effort = effort.to_string();
self.display.print_command_result(&format!("Reasoning effort set to {}", effort)); self.display
.print_command_result(&format!("Reasoning effort set to {}", effort));
if !self.session.model.starts_with("gpt-5") { if !self.session.model.starts_with("gpt-5") {
self.display.print_warning("Reasoning effort is only supported by GPT-5 models"); self.display.print_warning(
"Reasoning effort is only supported by GPT-5 models",
);
} }
} }
} }
Some("Toggle Extended Thinking") => { Some("Toggle Extended Thinking") => {
self.session.enable_extended_thinking = !self.session.enable_extended_thinking; self.session.enable_extended_thinking = !self.session.enable_extended_thinking;
let state = if self.session.enable_extended_thinking { "enabled" } else { "disabled" }; let state = if self.session.enable_extended_thinking {
self.display.print_command_result(&format!("Extended thinking {}", state)); "enabled"
} else {
"disabled"
};
self.display
.print_command_result(&format!("Extended thinking {}", state));
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
match provider { match provider {
crate::core::provider::Provider::OpenAI => { crate::core::provider::Provider::OpenAI => {
self.display.print_warning("Extended thinking is not supported by OpenAI models"); self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
} }
crate::core::provider::Provider::Anthropic => { crate::core::provider::Provider::Anthropic => {
// Supported // Supported
@ -502,12 +565,17 @@ impl ChatCLI {
)? { )? {
if let Ok(budget) = budget_str.parse::<u32>() { if let Ok(budget) = budget_str.parse::<u32>() {
self.session.thinking_budget_tokens = budget; self.session.thinking_budget_tokens = budget;
self.display.print_command_result(&format!("Thinking budget set to {} tokens", budget)); self.display.print_command_result(&format!(
"Thinking budget set to {} tokens",
budget
));
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
match provider { match provider {
crate::core::provider::Provider::OpenAI => { crate::core::provider::Provider::OpenAI => {
self.display.print_warning("Extended thinking is not supported by OpenAI models"); self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
} }
crate::core::provider::Provider::Anthropic => { crate::core::provider::Provider::Anthropic => {
if budget < 1024 { if budget < 1024 {
@ -523,18 +591,18 @@ impl ChatCLI {
} }
_ => {} _ => {}
} }
self.session.save()?; // Save changes after each modification self.session.save()?; // Save changes after each modification
println!(); // Add spacing println!(); // Add spacing
} }
Ok(()) Ok(())
} }
fn handle_history_command(&mut self, parts: &[&str]) -> Result<()> { fn handle_history_command(&mut self, parts: &[&str]) -> Result<()> {
let mut filter_role: Option<&str> = None; let mut filter_role: Option<&str> = None;
let mut limit: Option<usize> = None; let mut limit: Option<usize> = None;
// Parse parameters // Parse parameters
for &part in parts.iter().skip(1) { for &part in parts.iter().skip(1) {
match part { match part {
@ -543,37 +611,41 @@ impl ChatCLI {
if let Ok(num) = part.parse::<usize>() { if let Ok(num) = part.parse::<usize>() {
limit = Some(num); limit = Some(num);
} else { } else {
self.display.print_error(&format!("Invalid parameter: {}", part)); self.display
self.display.print_info("Usage: /history [user|assistant] [number]"); .print_error(&format!("Invalid parameter: {}", part));
self.display
.print_info("Usage: /history [user|assistant] [number]");
return Ok(()); return Ok(());
} }
} }
} }
} }
// Filter messages (skip system prompt at index 0) // Filter messages (skip system prompt at index 0)
let mut messages: Vec<(usize, &crate::core::Message)> = self.session.messages let mut messages: Vec<(usize, &crate::core::Message)> = self
.session
.messages
.iter() .iter()
.enumerate() .enumerate()
.skip(1) // Skip system prompt .skip(1) // Skip system prompt
.collect(); .collect();
// Apply role filter // Apply role filter
if let Some(role) = filter_role { if let Some(role) = filter_role {
messages.retain(|(_, msg)| msg.role == role); messages.retain(|(_, msg)| msg.role == role);
} }
// Apply limit // Apply limit
if let Some(limit_count) = limit { if let Some(limit_count) = limit {
let start_index = messages.len().saturating_sub(limit_count); let start_index = messages.len().saturating_sub(limit_count);
messages = messages[start_index..].to_vec(); messages = messages[start_index..].to_vec();
} }
if messages.is_empty() { if messages.is_empty() {
self.display.print_info("No messages to display"); self.display.print_info("No messages to display");
return Ok(()); return Ok(());
} }
// Format and display // Format and display
self.display.print_conversation_history(&messages); self.display.print_conversation_history(&messages);
Ok(()) Ok(())
@ -590,7 +662,7 @@ impl ChatCLI {
let valid_formats = ["markdown", "md", "json", "txt"]; let valid_formats = ["markdown", "md", "json", "txt"];
if !valid_formats.contains(&format.as_str()) { if !valid_formats.contains(&format.as_str()) {
self.display.print_error(&format!( self.display.print_error(&format!(
"Invalid format '{}'. Supported formats: markdown, json, txt", "Invalid format '{}'. Supported formats: markdown, json, txt",
format format
)); ));
self.display.print_info("Usage: /export [format]"); self.display.print_info("Usage: /export [format]");
@ -602,19 +674,20 @@ impl ChatCLI {
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
.as_secs(); .as_secs();
let extension = match format.as_str() { let extension = match format.as_str() {
"markdown" | "md" => "md", "markdown" | "md" => "md",
"json" => "json", "json" => "json",
"txt" => "txt", "txt" => "txt",
_ => "md", _ => "md",
}; };
// Create exports directory if it doesn't exist // Create exports directory if it doesn't exist
let exports_dir = std::path::Path::new("exports"); let exports_dir = std::path::Path::new("exports");
if !exports_dir.exists() { if !exports_dir.exists() {
if let Err(e) = std::fs::create_dir(exports_dir) { if let Err(e) = std::fs::create_dir(exports_dir) {
self.display.print_error(&format!("Failed to create exports directory: {}", e)); self.display
.print_error(&format!("Failed to create exports directory: {}", e));
return Ok(()); return Ok(());
} }
} }
@ -622,10 +695,13 @@ impl ChatCLI {
let filename = format!("{}_{}.{}", self.session.name, now, extension); let filename = format!("{}_{}.{}", self.session.name, now, extension);
let file_path = exports_dir.join(&filename); let file_path = exports_dir.join(&filename);
match self.session.export(&format, file_path.to_str().unwrap_or(&filename)) { match self
.session
.export(&format, file_path.to_str().unwrap_or(&filename))
{
Ok(()) => { Ok(()) => {
self.display.print_command_result(&format!( self.display.print_command_result(&format!(
"Conversation exported to '{}'", "Conversation exported to '{}'",
file_path.display() file_path.display()
)); ));
} }
@ -640,7 +716,8 @@ impl ChatCLI {
fn handle_save_command(&mut self, parts: &[&str]) -> Result<()> { fn handle_save_command(&mut self, parts: &[&str]) -> Result<()> {
if parts.len() != 2 { if parts.len() != 2 {
self.display.print_error("Usage: /save <new_session_name>"); self.display.print_error("Usage: /save <new_session_name>");
self.display.print_info("Example: /save my_important_conversation"); self.display
.print_info("Example: /save my_important_conversation");
return Ok(()); return Ok(());
} }
@ -656,7 +733,7 @@ impl ChatCLI {
if let Ok(sessions) = Session::list_sessions() { if let Ok(sessions) = Session::list_sessions() {
if sessions.iter().any(|(name, _)| name == new_session_name) { if sessions.iter().any(|(name, _)| name == new_session_name) {
self.display.print_error(&format!( self.display.print_error(&format!(
"Session '{}' already exists. Choose a different name.", "Session '{}' already exists. Choose a different name.",
new_session_name new_session_name
)); ));
return Ok(()); return Ok(());
@ -670,7 +747,7 @@ impl ChatCLI {
match self.session.save_as(new_session_name) { match self.session.save_as(new_session_name) {
Ok(()) => { Ok(()) => {
self.display.print_command_result(&format!( self.display.print_command_result(&format!(
"Current session saved as '{}' ({} messages copied)", "Current session saved as '{}' ({} messages copied)",
new_session_name, new_session_name,
self.session.messages.len().saturating_sub(1) // Exclude system prompt self.session.messages.len().saturating_sub(1) // Exclude system prompt
)); ));
@ -680,11 +757,11 @@ impl ChatCLI {
)); ));
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to save session: {}", e)); self.display
.print_error(&format!("Failed to save session: {}", e));
} }
} }
Ok(()) Ok(())
} }
}
}

View File

@ -5,9 +5,7 @@ mod utils;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use clap::Parser; use clap::Parser;
use signal_hook::{consts::SIGINT, iterator::Signals}; use tokio::signal::unix::{signal, SignalKind};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crate::cli::ChatCLI; use crate::cli::ChatCLI;
use crate::config::Config; use crate::config::Config;
@ -19,7 +17,11 @@ use crate::utils::Display;
#[command(about = "A lightweight command-line interface for chatting with AI models")] #[command(about = "A lightweight command-line interface for chatting with AI models")]
#[command(version)] #[command(version)]
struct Args { struct Args {
#[arg(short, long, help = "Session name (defaults to configured default session)")] #[arg(
short,
long,
help = "Session name (defaults to configured default session)"
)]
session: Option<String>, session: Option<String>,
#[arg(short, long, help = "Model name to use (overrides saved value)")] #[arg(short, long, help = "Model name to use (overrides saved value)")]
@ -34,16 +36,8 @@ async fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
let display = Display::new(); let display = Display::new();
// Set up signal handling for proper cleanup // Set up signal handling using Tokio's signal support
let term = Arc::new(AtomicBool::new(false)); let mut sigint = signal(SignalKind::interrupt())?;
let term_clone = term.clone();
std::thread::spawn(move || {
let mut signals = Signals::new(&[SIGINT]).unwrap();
for _ in signals.forever() {
term_clone.store(true, Ordering::Relaxed);
}
});
// Handle config creation // Handle config creation
if args.create_config { if args.create_config {
@ -53,14 +47,16 @@ async fn main() -> Result<()> {
// Load configuration // Load configuration
let config = Config::load().context("Failed to load configuration")?; let config = Config::load().context("Failed to load configuration")?;
// Validate environment variables // Validate environment variables
let env_vars = Config::validate_env_variables().context("Environment validation failed")?; let env_vars = Config::validate_env_variables().context("Environment validation failed")?;
// Load or create session // Load or create session
// Use configured default session if none specified // Use configured default session if none specified
let session_name = args.session.unwrap_or_else(|| config.defaults.default_session.clone()); let session_name = args
.session
.unwrap_or_else(|| config.defaults.default_session.clone());
let session = match Session::load(&session_name) { let session = match Session::load(&session_name) {
Ok(mut session) => { Ok(mut session) => {
if let Some(model) = args.model { if let Some(model) = args.model {
@ -99,9 +95,22 @@ async fn main() -> Result<()> {
// Print configuration info // Print configuration info
config.print_config_info(); config.print_config_info();
// Run the CLI // Run the CLI and handle SIGINT gracefully
let mut cli = ChatCLI::new(session, config).context("Failed to initialize CLI")?; let mut cli = ChatCLI::new(session, config).context("Failed to initialize CLI")?;
cli.run().await.context("CLI error")?;
let mut cli_run = Box::pin(cli.run());
tokio::select! {
res = &mut cli_run => {
res.context("CLI error")?;
}
_ = sigint.recv() => {
// Drop the CLI future before using `cli` again
drop(cli_run);
display.print_info("Received interrupt signal. Exiting...");
cli.save_and_cleanup()?;
}
}
Ok(()) Ok(())
} }