Merge branch 'master' into codex/refactor-handle_user_message-to-pass-references

This commit is contained in:
Christopher 2025-08-25 00:56:15 -04:00 committed by GitHub
commit 222c1c2182
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 228 additions and 127 deletions

12
Cargo.lock generated
View File

@ -590,13 +590,13 @@ dependencies = [
"futures", "futures",
"indicatif", "indicatif",
"mockall", "mockall",
"once_cell",
"regex", "regex",
"reqwest", "reqwest",
"rustyline", "rustyline",
"serde", "serde",
"serde_json", "serde_json",
"serial_test", "serial_test",
"signal-hook",
"syntect", "syntect",
"tempfile", "tempfile",
"tokio", "tokio",
@ -1585,16 +1585,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.6" version = "1.4.6"

View File

@ -23,7 +23,8 @@ syntect = "5.1"
regex = "1.0" regex = "1.0"
futures = "0.3" futures = "0.3"
tokio-stream = "0.1" tokio-stream = "0.1"
signal-hook = "0.3" once_cell = "1.21"
[dev-dependencies] [dev-dependencies]
tempfile = "3.0" tempfile = "3.0"

View File

@ -2,7 +2,8 @@ use anyhow::Result;
use crate::config::Config; use crate::config::Config;
use crate::core::{ use crate::core::{
create_client, get_provider_for_model, provider::{get_model_info_list, get_display_name_for_model, get_model_id_from_display_name}, create_client, get_provider_for_model,
provider::{get_display_name_for_model, get_model_id_from_display_name, get_model_info_list},
ChatClient, Session, ChatClient, Session,
}; };
use crate::utils::{Display, InputHandler, SessionAction}; use crate::utils::{Display, InputHandler, SessionAction};
@ -41,14 +42,16 @@ impl ChatCLI {
pub async fn run(&mut self) -> Result<()> { pub async fn run(&mut self) -> Result<()> {
self.display.print_header(); self.display.print_header();
self.display.print_info("Type your message and press Enter. Commands start with '/'."); self.display
.print_info("Type your message and press Enter. Commands start with '/'.");
self.display.print_info("Type /help for help."); self.display.print_info("Type /help for help.");
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
let display_name = get_display_name_for_model(&self.session.model); let display_name = get_display_name_for_model(&self.session.model);
self.display.print_model_info(&display_name, provider.as_str()); self.display
.print_model_info(&display_name, provider.as_str());
self.display.print_session_info(&self.session.name); self.display.print_session_info(&self.session.name);
println!(); println!();
loop { loop {
@ -77,6 +80,10 @@ impl ChatCLI {
} }
} }
self.save_and_cleanup()
}
pub fn save_and_cleanup(&mut self) -> Result<()> {
self.session.save()?; self.session.save()?;
self.input.cleanup()?; // Use cleanup instead of just save_history self.input.cleanup()?; // Use cleanup instead of just save_history
Ok(()) Ok(())
@ -86,6 +93,16 @@ impl ChatCLI {
self.session.add_user_message(message.to_string()); self.session.add_user_message(message.to_string());
self.session.save()?; self.session.save()?;
// Clone data needed for the API call before getting mutable client reference
let model = self.session.model.clone();
let messages = self.session.messages.clone();
let enable_web_search = self.session.enable_web_search;
let enable_reasoning_summary = self.session.enable_reasoning_summary;
let reasoning_effort = self.session.reasoning_effort.clone();
let enable_extended_thinking = self.session.enable_extended_thinking;
let thinking_budget_tokens = self.session.thinking_budget_tokens;
// Check if we should use streaming before getting client // Check if we should use streaming before getting client
let should_use_streaming = { let should_use_streaming = {
let client = self.get_client()?; let client = self.get_client()?;
@ -97,7 +114,7 @@ impl ChatCLI {
print!("{}> ", console::style("🤖").magenta()); print!("{}> ", console::style("🤖").magenta());
use std::io::{self, Write}; use std::io::{self, Write};
io::stdout().flush().ok(); io::stdout().flush().ok();
let stream_callback = { let stream_callback = {
use crate::core::StreamCallback; use crate::core::StreamCallback;
Box::new(move |chunk: &str| { Box::new(move |chunk: &str| {
@ -107,9 +124,11 @@ impl ChatCLI {
Box::pin(async move {}) as Pin<Box<dyn Future<Output = ()> + Send>> Box::pin(async move {}) as Pin<Box<dyn Future<Output = ()> + Send>>
}) as StreamCallback }) as StreamCallback
}; };
let client = self.get_client()?.clone(); let client = self.get_client()?.clone();
let response = client let response = client
.chat_completion_stream( .chat_completion_stream(
&self.session.model, &self.session.model,
&self.session.messages, &self.session.messages,
@ -129,15 +148,18 @@ impl ChatCLI {
} }
Err(e) => { Err(e) => {
println!(); // Add newline after failed streaming println!(); // Add newline after failed streaming
self.display.print_error(&format!("Streaming failed: {}", e)); self.display
.print_error(&format!("Streaming failed: {}", e));
return Err(e); return Err(e);
} }
} }
} else { } else {
// Fallback to non-streaming // Fallback to non-streaming
let spinner = self.display.show_spinner("Thinking"); let spinner = self.display.show_spinner("Thinking");
let client = self.get_client()?.clone(); let client = self.get_client()?.clone();
let response = client let response = client
.chat_completion( .chat_completion(
&self.session.model, &self.session.model,
&self.session.messages, &self.session.messages,
@ -208,7 +230,8 @@ impl ChatCLI {
self.handle_save_command(&parts)?; self.handle_save_command(&parts)?;
} }
_ => { _ => {
self.display.print_error(&format!("Unknown command: {} (see /help)", parts[0])); self.display
.print_error(&format!("Unknown command: {} (see /help)", parts[0]));
} }
} }
@ -217,15 +240,18 @@ impl ChatCLI {
async fn model_switcher(&mut self) -> Result<()> { async fn model_switcher(&mut self) -> Result<()> {
let model_info_list = get_model_info_list(); let model_info_list = get_model_info_list();
let display_names: Vec<String> = model_info_list.iter().map(|info| info.display_name.to_string()).collect(); let display_names: Vec<String> = model_info_list
.iter()
.map(|info| info.display_name.to_string())
.collect();
let current_display_name = get_display_name_for_model(&self.session.model); let current_display_name = get_display_name_for_model(&self.session.model);
let selection = self.input.select_from_list( let selection = self.input.select_from_list(
"Select a model:", "Select a model:",
&display_names, &display_names,
Some(&current_display_name), Some(&current_display_name),
)?; )?;
match selection { match selection {
Some(display_name) => { Some(display_name) => {
if let Some(model_id) = get_model_id_from_display_name(&display_name) { if let Some(model_id) = get_model_id_from_display_name(&display_name) {
@ -250,11 +276,10 @@ impl ChatCLI {
self.display.print_info("Model selection cancelled"); self.display.print_info("Model selection cancelled");
} }
} }
Ok(()) Ok(())
} }
fn handle_new_session(&mut self, parts: &[&str]) -> Result<()> { fn handle_new_session(&mut self, parts: &[&str]) -> Result<()> {
if parts.len() != 2 { if parts.len() != 2 {
self.display.print_error("Usage: /new <session_name>"); self.display.print_error("Usage: /new <session_name>");
@ -264,18 +289,16 @@ impl ChatCLI {
self.session.save()?; self.session.save()?;
let new_session = Session::new(parts[1].to_string(), self.session.model.clone()); let new_session = Session::new(parts[1].to_string(), self.session.model.clone());
self.session = new_session; self.session = new_session;
self.display.print_command_result(&format!("New session '{}' started", self.session.name)); self.display
.print_command_result(&format!("New session '{}' started", self.session.name));
Ok(()) Ok(())
} }
async fn session_manager(&mut self) -> Result<()> { async fn session_manager(&mut self) -> Result<()> {
loop { loop {
let sessions = Session::list_sessions()?; let sessions = Session::list_sessions()?;
let session_names: Vec<String> = sessions let session_names: Vec<String> = sessions.into_iter().map(|(name, _)| name).collect();
.into_iter()
.map(|(name, _)| name)
.collect();
if session_names.is_empty() { if session_names.is_empty() {
self.display.print_info("No sessions available"); self.display.print_info("No sessions available");
@ -294,7 +317,7 @@ impl ChatCLI {
self.display.print_info("Already in that session"); self.display.print_info("Already in that session");
return Ok(()); return Ok(());
} }
self.session.save()?; self.session.save()?;
match Session::load(&session_name) { match Session::load(&session_name) {
Ok(session) => { Ok(session) => {
@ -308,7 +331,8 @@ impl ChatCLI {
return Ok(()); return Ok(());
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to load session: {}", e)); self.display
.print_error(&format!("Failed to load session: {}", e));
// Don't return, allow user to try again or cancel // Don't return, allow user to try again or cancel
} }
} }
@ -316,8 +340,11 @@ impl ChatCLI {
SessionAction::Delete(session_name) => { SessionAction::Delete(session_name) => {
match Session::delete_session(&session_name) { match Session::delete_session(&session_name) {
Ok(()) => { Ok(()) => {
self.display.print_command_result(&format!("Session '{}' deleted", session_name)); self.display.print_command_result(&format!(
"Session '{}' deleted",
session_name
));
// If we deleted the current session, we need to handle this specially // If we deleted the current session, we need to handle this specially
if session_name == self.session.name { if session_name == self.session.name {
// Try to switch to another session or create a default one // Try to switch to another session or create a default one
@ -326,18 +353,23 @@ impl ChatCLI {
.into_iter() .into_iter()
.map(|(name, _)| name) .map(|(name, _)| name)
.collect(); .collect();
if remaining_names.is_empty() { if remaining_names.is_empty() {
// No sessions left, create a default one // No sessions left, create a default one
self.session = Session::new("default".to_string(), self.session.model.clone()); self.session = Session::new(
self.display.print_command_result("Created new default session"); "default".to_string(),
self.session.model.clone(),
);
self.display
.print_command_result("Created new default session");
return Ok(()); return Ok(());
} else { } else {
// Switch to the first available session // Switch to the first available session
match Session::load(&remaining_names[0]) { match Session::load(&remaining_names[0]) {
Ok(session) => { Ok(session) => {
self.session = session; self.session = session;
let display_name = get_display_name_for_model(&self.session.model); let display_name =
get_display_name_for_model(&self.session.model);
self.display.print_command_result(&format!( self.display.print_command_result(&format!(
"Switched to session '{}' (model={})", "Switched to session '{}' (model={})",
self.session.name, display_name self.session.name, display_name
@ -346,10 +378,18 @@ impl ChatCLI {
return Ok(()); return Ok(());
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to load fallback session: {}", e)); self.display.print_error(&format!(
"Failed to load fallback session: {}",
e
));
// Create a new default session as fallback // Create a new default session as fallback
self.session = Session::new("default".to_string(), self.session.model.clone()); self.session = Session::new(
self.display.print_command_result("Created new default session"); "default".to_string(),
self.session.model.clone(),
);
self.display.print_command_result(
"Created new default session",
);
return Ok(()); return Ok(());
} }
} }
@ -358,7 +398,8 @@ impl ChatCLI {
// Continue to show updated session list if we didn't delete current session // Continue to show updated session list if we didn't delete current session
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to delete session: {}", e)); self.display
.print_error(&format!("Failed to delete session: {}", e));
// Continue to allow retry // Continue to allow retry
} }
} }
@ -373,7 +414,8 @@ impl ChatCLI {
)); ));
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to set default session: {}", e)); self.display
.print_error(&format!("Failed to set default session: {}", e));
} }
} }
// Continue to show session list // Continue to show session list
@ -385,32 +427,48 @@ impl ChatCLI {
} }
} }
async fn tools_manager(&mut self) -> Result<()> { async fn tools_manager(&mut self) -> Result<()> {
loop { loop {
// Show current tool status // Show current tool status
self.display.print_info("Tool Management:"); self.display.print_info("Tool Management:");
let web_status = if self.session.enable_web_search { "✓ enabled" } else { "✗ disabled" }; let web_status = if self.session.enable_web_search {
let reasoning_status = if self.session.enable_reasoning_summary { "✓ enabled" } else { "✗ disabled" }; "✓ enabled"
let extended_thinking_status = if self.session.enable_extended_thinking { "✓ enabled" } else { "✗ disabled" }; } else {
"✗ disabled"
};
let reasoning_status = if self.session.enable_reasoning_summary {
"✓ enabled"
} else {
"✗ disabled"
};
let extended_thinking_status = if self.session.enable_extended_thinking {
"✓ enabled"
} else {
"✗ disabled"
};
println!(" Web Search: {}", web_status); println!(" Web Search: {}", web_status);
println!(" Reasoning Summaries: {}", reasoning_status); println!(" Reasoning Summaries: {}", reasoning_status);
println!(" Reasoning Effort: {}", self.session.reasoning_effort); println!(" Reasoning Effort: {}", self.session.reasoning_effort);
println!(" Extended Thinking: {}", extended_thinking_status); println!(" Extended Thinking: {}", extended_thinking_status);
println!(" Thinking Budget: {} tokens", self.session.thinking_budget_tokens); println!(
" Thinking Budget: {} tokens",
self.session.thinking_budget_tokens
);
// Check model compatibility // Check model compatibility
let model = self.session.model.clone(); let model = self.session.model.clone();
let provider = get_provider_for_model(&model); let provider = get_provider_for_model(&model);
let reasoning_enabled = self.session.enable_reasoning_summary; let reasoning_enabled = self.session.enable_reasoning_summary;
// Show compatibility warnings based on provider // Show compatibility warnings based on provider
match provider { match provider {
crate::core::provider::Provider::Anthropic => { crate::core::provider::Provider::Anthropic => {
if reasoning_enabled { if reasoning_enabled {
self.display.print_warning("Reasoning summaries are not supported by Anthropic models"); self.display.print_warning(
"Reasoning summaries are not supported by Anthropic models",
);
} }
if self.session.enable_extended_thinking { if self.session.enable_extended_thinking {
// Extended thinking is supported by Anthropic models // Extended thinking is supported by Anthropic models
@ -419,38 +477,47 @@ impl ChatCLI {
} }
crate::core::provider::Provider::OpenAI => { crate::core::provider::Provider::OpenAI => {
if self.session.enable_extended_thinking { if self.session.enable_extended_thinking {
self.display.print_warning("Extended thinking is not supported by OpenAI models"); self.display
.print_warning("Extended thinking is not supported by OpenAI models");
} }
// OpenAI models generally support other features // OpenAI models generally support other features
} }
} }
// Tool management options // Tool management options
let options = vec![ let options = vec![
"Toggle Web Search", "Toggle Web Search",
"Toggle Reasoning Summaries", "Toggle Reasoning Summaries",
"Set Reasoning Effort", "Set Reasoning Effort",
"Toggle Extended Thinking", "Toggle Extended Thinking",
"Set Thinking Budget", "Set Thinking Budget",
"Done" "Done",
]; ];
let selection = self.input.select_from_list( let selection = self
"Select an option:", .input
&options, .select_from_list("Select an option:", &options, None)?;
None,
)?;
match selection.as_deref() { match selection.as_deref() {
Some("Toggle Web Search") => { Some("Toggle Web Search") => {
self.session.enable_web_search = !self.session.enable_web_search; self.session.enable_web_search = !self.session.enable_web_search;
let state = if self.session.enable_web_search { "enabled" } else { "disabled" }; let state = if self.session.enable_web_search {
self.display.print_command_result(&format!("Web search {}", state)); "enabled"
} else {
"disabled"
};
self.display
.print_command_result(&format!("Web search {}", state));
} }
Some("Toggle Reasoning Summaries") => { Some("Toggle Reasoning Summaries") => {
self.session.enable_reasoning_summary = !self.session.enable_reasoning_summary; self.session.enable_reasoning_summary = !self.session.enable_reasoning_summary;
let state = if self.session.enable_reasoning_summary { "enabled" } else { "disabled" }; let state = if self.session.enable_reasoning_summary {
self.display.print_command_result(&format!("Reasoning summaries {}", state)); "enabled"
} else {
"disabled"
};
self.display
.print_command_result(&format!("Reasoning summaries {}", state));
} }
Some("Set Reasoning Effort") => { Some("Set Reasoning Effort") => {
let effort_options = vec!["low", "medium", "high"]; let effort_options = vec!["low", "medium", "high"];
@ -460,22 +527,32 @@ impl ChatCLI {
Some(&self.session.reasoning_effort), Some(&self.session.reasoning_effort),
)? { )? {
self.session.reasoning_effort = effort.to_string(); self.session.reasoning_effort = effort.to_string();
self.display.print_command_result(&format!("Reasoning effort set to {}", effort)); self.display
.print_command_result(&format!("Reasoning effort set to {}", effort));
if !self.session.model.starts_with("gpt-5") { if !self.session.model.starts_with("gpt-5") {
self.display.print_warning("Reasoning effort is only supported by GPT-5 models"); self.display.print_warning(
"Reasoning effort is only supported by GPT-5 models",
);
} }
} }
} }
Some("Toggle Extended Thinking") => { Some("Toggle Extended Thinking") => {
self.session.enable_extended_thinking = !self.session.enable_extended_thinking; self.session.enable_extended_thinking = !self.session.enable_extended_thinking;
let state = if self.session.enable_extended_thinking { "enabled" } else { "disabled" }; let state = if self.session.enable_extended_thinking {
self.display.print_command_result(&format!("Extended thinking {}", state)); "enabled"
} else {
"disabled"
};
self.display
.print_command_result(&format!("Extended thinking {}", state));
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
match provider { match provider {
crate::core::provider::Provider::OpenAI => { crate::core::provider::Provider::OpenAI => {
self.display.print_warning("Extended thinking is not supported by OpenAI models"); self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
} }
crate::core::provider::Provider::Anthropic => { crate::core::provider::Provider::Anthropic => {
// Supported // Supported
@ -492,12 +569,17 @@ impl ChatCLI {
)? { )? {
if let Ok(budget) = budget_str.parse::<u32>() { if let Ok(budget) = budget_str.parse::<u32>() {
self.session.thinking_budget_tokens = budget; self.session.thinking_budget_tokens = budget;
self.display.print_command_result(&format!("Thinking budget set to {} tokens", budget)); self.display.print_command_result(&format!(
"Thinking budget set to {} tokens",
budget
));
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
match provider { match provider {
crate::core::provider::Provider::OpenAI => { crate::core::provider::Provider::OpenAI => {
self.display.print_warning("Extended thinking is not supported by OpenAI models"); self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
} }
crate::core::provider::Provider::Anthropic => { crate::core::provider::Provider::Anthropic => {
if budget < 1024 { if budget < 1024 {
@ -513,18 +595,18 @@ impl ChatCLI {
} }
_ => {} _ => {}
} }
self.session.save()?; // Save changes after each modification self.session.save()?; // Save changes after each modification
println!(); // Add spacing println!(); // Add spacing
} }
Ok(()) Ok(())
} }
fn handle_history_command(&mut self, parts: &[&str]) -> Result<()> { fn handle_history_command(&mut self, parts: &[&str]) -> Result<()> {
let mut filter_role: Option<&str> = None; let mut filter_role: Option<&str> = None;
let mut limit: Option<usize> = None; let mut limit: Option<usize> = None;
// Parse parameters // Parse parameters
for &part in parts.iter().skip(1) { for &part in parts.iter().skip(1) {
match part { match part {
@ -533,37 +615,41 @@ impl ChatCLI {
if let Ok(num) = part.parse::<usize>() { if let Ok(num) = part.parse::<usize>() {
limit = Some(num); limit = Some(num);
} else { } else {
self.display.print_error(&format!("Invalid parameter: {}", part)); self.display
self.display.print_info("Usage: /history [user|assistant] [number]"); .print_error(&format!("Invalid parameter: {}", part));
self.display
.print_info("Usage: /history [user|assistant] [number]");
return Ok(()); return Ok(());
} }
} }
} }
} }
// Filter messages (skip system prompt at index 0) // Filter messages (skip system prompt at index 0)
let mut messages: Vec<(usize, &crate::core::Message)> = self.session.messages let mut messages: Vec<(usize, &crate::core::Message)> = self
.session
.messages
.iter() .iter()
.enumerate() .enumerate()
.skip(1) // Skip system prompt .skip(1) // Skip system prompt
.collect(); .collect();
// Apply role filter // Apply role filter
if let Some(role) = filter_role { if let Some(role) = filter_role {
messages.retain(|(_, msg)| msg.role == role); messages.retain(|(_, msg)| msg.role == role);
} }
// Apply limit // Apply limit
if let Some(limit_count) = limit { if let Some(limit_count) = limit {
let start_index = messages.len().saturating_sub(limit_count); let start_index = messages.len().saturating_sub(limit_count);
messages = messages[start_index..].to_vec(); messages = messages[start_index..].to_vec();
} }
if messages.is_empty() { if messages.is_empty() {
self.display.print_info("No messages to display"); self.display.print_info("No messages to display");
return Ok(()); return Ok(());
} }
// Format and display // Format and display
self.display.print_conversation_history(&messages); self.display.print_conversation_history(&messages);
Ok(()) Ok(())
@ -580,7 +666,7 @@ impl ChatCLI {
let valid_formats = ["markdown", "md", "json", "txt"]; let valid_formats = ["markdown", "md", "json", "txt"];
if !valid_formats.contains(&format.as_str()) { if !valid_formats.contains(&format.as_str()) {
self.display.print_error(&format!( self.display.print_error(&format!(
"Invalid format '{}'. Supported formats: markdown, json, txt", "Invalid format '{}'. Supported formats: markdown, json, txt",
format format
)); ));
self.display.print_info("Usage: /export [format]"); self.display.print_info("Usage: /export [format]");
@ -592,19 +678,20 @@ impl ChatCLI {
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
.as_secs(); .as_secs();
let extension = match format.as_str() { let extension = match format.as_str() {
"markdown" | "md" => "md", "markdown" | "md" => "md",
"json" => "json", "json" => "json",
"txt" => "txt", "txt" => "txt",
_ => "md", _ => "md",
}; };
// Create exports directory if it doesn't exist // Create exports directory if it doesn't exist
let exports_dir = std::path::Path::new("exports"); let exports_dir = std::path::Path::new("exports");
if !exports_dir.exists() { if !exports_dir.exists() {
if let Err(e) = std::fs::create_dir(exports_dir) { if let Err(e) = std::fs::create_dir(exports_dir) {
self.display.print_error(&format!("Failed to create exports directory: {}", e)); self.display
.print_error(&format!("Failed to create exports directory: {}", e));
return Ok(()); return Ok(());
} }
} }
@ -612,10 +699,13 @@ impl ChatCLI {
let filename = format!("{}_{}.{}", self.session.name, now, extension); let filename = format!("{}_{}.{}", self.session.name, now, extension);
let file_path = exports_dir.join(&filename); let file_path = exports_dir.join(&filename);
match self.session.export(&format, file_path.to_str().unwrap_or(&filename)) { match self
.session
.export(&format, file_path.to_str().unwrap_or(&filename))
{
Ok(()) => { Ok(()) => {
self.display.print_command_result(&format!( self.display.print_command_result(&format!(
"Conversation exported to '{}'", "Conversation exported to '{}'",
file_path.display() file_path.display()
)); ));
} }
@ -630,7 +720,8 @@ impl ChatCLI {
fn handle_save_command(&mut self, parts: &[&str]) -> Result<()> { fn handle_save_command(&mut self, parts: &[&str]) -> Result<()> {
if parts.len() != 2 { if parts.len() != 2 {
self.display.print_error("Usage: /save <new_session_name>"); self.display.print_error("Usage: /save <new_session_name>");
self.display.print_info("Example: /save my_important_conversation"); self.display
.print_info("Example: /save my_important_conversation");
return Ok(()); return Ok(());
} }
@ -646,7 +737,7 @@ impl ChatCLI {
if let Ok(sessions) = Session::list_sessions() { if let Ok(sessions) = Session::list_sessions() {
if sessions.iter().any(|(name, _)| name == new_session_name) { if sessions.iter().any(|(name, _)| name == new_session_name) {
self.display.print_error(&format!( self.display.print_error(&format!(
"Session '{}' already exists. Choose a different name.", "Session '{}' already exists. Choose a different name.",
new_session_name new_session_name
)); ));
return Ok(()); return Ok(());
@ -660,7 +751,7 @@ impl ChatCLI {
match self.session.save_as(new_session_name) { match self.session.save_as(new_session_name) {
Ok(()) => { Ok(()) => {
self.display.print_command_result(&format!( self.display.print_command_result(&format!(
"Current session saved as '{}' ({} messages copied)", "Current session saved as '{}' ({} messages copied)",
new_session_name, new_session_name,
self.session.messages.len().saturating_sub(1) // Exclude system prompt self.session.messages.len().saturating_sub(1) // Exclude system prompt
)); ));
@ -670,11 +761,11 @@ impl ChatCLI {
)); ));
} }
Err(e) => { Err(e) => {
self.display.print_error(&format!("Failed to save session: {}", e)); self.display
.print_error(&format!("Failed to save session: {}", e));
} }
} }
Ok(()) Ok(())
} }
}
}

View File

@ -1,5 +1,6 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
@ -12,6 +13,12 @@ ASCII art or concise bullet lists over heavy markup, and wrap code \
snippets in fenced blocks when helpful. Do not emit trailing spaces or \ snippets in fenced blocks when helpful. Do not emit trailing spaces or \
control characters."; control characters.";
static CONFIG: OnceCell<Config> = OnceCell::new();
fn get_config() -> &'static Config {
CONFIG.get_or_init(|| Config::load().unwrap_or_default())
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message { pub struct Message {
pub role: String, pub role: String,
@ -98,7 +105,7 @@ impl Session {
} }
pub fn sessions_dir() -> Result<PathBuf> { pub fn sessions_dir() -> Result<PathBuf> {
let config = Config::load().unwrap_or_default(); let config = get_config();
let home = dirs::home_dir().context("Could not find home directory")?; let home = dirs::home_dir().context("Could not find home directory")?;
let sessions_dir = home.join(&config.session.sessions_dir_name); let sessions_dir = home.join(&config.session.sessions_dir_name);
@ -111,11 +118,12 @@ impl Session {
} }
pub fn session_path(name: &str) -> Result<PathBuf> { pub fn session_path(name: &str) -> Result<PathBuf> {
let config = Config::load().unwrap_or_default(); let config = get_config();
Ok(Self::sessions_dir()?.join(format!("{}.{}", name, config.session.file_extension))) Ok(Self::sessions_dir()?.join(format!("{}.{}", name, config.session.file_extension)))
} }
pub fn save(&self) -> Result<()> { pub fn save(&self) -> Result<()> {
let _config = get_config();
let data = SessionData { let data = SessionData {
model: self.model.clone(), model: self.model.clone(),
messages: self.messages.clone(), messages: self.messages.clone(),
@ -143,6 +151,7 @@ impl Session {
} }
pub fn load(name: &str) -> Result<Self> { pub fn load(name: &str) -> Result<Self> {
let _config = get_config();
let path = Self::session_path(name)?; let path = Self::session_path(name)?;
if !path.exists() { if !path.exists() {
@ -195,7 +204,7 @@ impl Session {
/// Truncates conversation history to stay within configured limits /// Truncates conversation history to stay within configured limits
fn truncate_history_if_needed(&mut self) { fn truncate_history_if_needed(&mut self) {
let config = Config::load().unwrap_or_default(); let config = get_config();
let max_history = config.limits.max_conversation_history; let max_history = config.limits.max_conversation_history;
// Always preserve the system prompt (first message) // Always preserve the system prompt (first message)
@ -252,13 +261,14 @@ impl Session {
} }
pub fn list_sessions() -> Result<Vec<(String, DateTime<Utc>)>> { pub fn list_sessions() -> Result<Vec<(String, DateTime<Utc>)>> {
let _config = get_config();
let sessions = Self::list_sessions_lazy(false)?; let sessions = Self::list_sessions_lazy(false)?;
Ok(sessions.into_iter().map(|s| (s.name, s.last_modified)).collect()) Ok(sessions.into_iter().map(|s| (s.name, s.last_modified)).collect())
} }
/// Lists sessions with lazy loading - only loads full data if detailed=true /// Lists sessions with lazy loading - only loads full data if detailed=true
pub fn list_sessions_lazy(detailed: bool) -> Result<Vec<SessionInfo>> { pub fn list_sessions_lazy(detailed: bool) -> Result<Vec<SessionInfo>> {
let config = Config::load().unwrap_or_default(); let config = get_config();
let sessions_dir = Self::sessions_dir()?; let sessions_dir = Self::sessions_dir()?;
if !sessions_dir.exists() { if !sessions_dir.exists() {
@ -351,7 +361,7 @@ impl Session {
/// Checks if session needs cleanup based on size and age /// Checks if session needs cleanup based on size and age
pub fn needs_cleanup(&self) -> bool { pub fn needs_cleanup(&self) -> bool {
let stats = self.get_stats(); let stats = self.get_stats();
let config = Config::load().unwrap_or_default(); let config = get_config();
// Check if conversation is too long // Check if conversation is too long
if stats.total_messages > config.limits.max_conversation_history * 2 { if stats.total_messages > config.limits.max_conversation_history * 2 {
@ -368,7 +378,7 @@ impl Session {
/// Performs aggressive cleanup for memory optimization /// Performs aggressive cleanup for memory optimization
pub fn cleanup_for_memory(&mut self) { pub fn cleanup_for_memory(&mut self) {
let config = Config::load().unwrap_or_default(); let config = get_config();
let target_messages = config.limits.max_conversation_history / 2; let target_messages = config.limits.max_conversation_history / 2;
if self.messages.len() > target_messages + 1 { if self.messages.len() > target_messages + 1 {

View File

@ -5,9 +5,7 @@ mod utils;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use clap::Parser; use clap::Parser;
use signal_hook::{consts::SIGINT, iterator::Signals}; use tokio::signal::unix::{signal, SignalKind};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crate::cli::ChatCLI; use crate::cli::ChatCLI;
use crate::config::Config; use crate::config::Config;
@ -19,7 +17,11 @@ use crate::utils::Display;
#[command(about = "A lightweight command-line interface for chatting with AI models")] #[command(about = "A lightweight command-line interface for chatting with AI models")]
#[command(version)] #[command(version)]
struct Args { struct Args {
#[arg(short, long, help = "Session name (defaults to configured default session)")] #[arg(
short,
long,
help = "Session name (defaults to configured default session)"
)]
session: Option<String>, session: Option<String>,
#[arg(short, long, help = "Model name to use (overrides saved value)")] #[arg(short, long, help = "Model name to use (overrides saved value)")]
@ -34,16 +36,8 @@ async fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
let display = Display::new(); let display = Display::new();
// Set up signal handling for proper cleanup // Set up signal handling using Tokio's signal support
let term = Arc::new(AtomicBool::new(false)); let mut sigint = signal(SignalKind::interrupt())?;
let term_clone = term.clone();
std::thread::spawn(move || {
let mut signals = Signals::new(&[SIGINT]).unwrap();
for _ in signals.forever() {
term_clone.store(true, Ordering::Relaxed);
}
});
// Handle config creation // Handle config creation
if args.create_config { if args.create_config {
@ -53,14 +47,16 @@ async fn main() -> Result<()> {
// Load configuration // Load configuration
let config = Config::load().context("Failed to load configuration")?; let config = Config::load().context("Failed to load configuration")?;
// Validate environment variables // Validate environment variables
let env_vars = Config::validate_env_variables().context("Environment validation failed")?; let env_vars = Config::validate_env_variables().context("Environment validation failed")?;
// Load or create session // Load or create session
// Use configured default session if none specified // Use configured default session if none specified
let session_name = args.session.unwrap_or_else(|| config.defaults.default_session.clone()); let session_name = args
.session
.unwrap_or_else(|| config.defaults.default_session.clone());
let session = match Session::load(&session_name) { let session = match Session::load(&session_name) {
Ok(mut session) => { Ok(mut session) => {
if let Some(model) = args.model { if let Some(model) = args.model {
@ -99,9 +95,22 @@ async fn main() -> Result<()> {
// Print configuration info // Print configuration info
config.print_config_info(); config.print_config_info();
// Run the CLI // Run the CLI and handle SIGINT gracefully
let mut cli = ChatCLI::new(session, config).context("Failed to initialize CLI")?; let mut cli = ChatCLI::new(session, config).context("Failed to initialize CLI")?;
cli.run().await.context("CLI error")?;
let mut cli_run = Box::pin(cli.run());
tokio::select! {
res = &mut cli_run => {
res.context("CLI error")?;
}
_ = sigint.recv() => {
// Drop the CLI future before using `cli` again
drop(cli_run);
display.print_info("Received interrupt signal. Exiting...");
cli.save_and_cleanup()?;
}
}
Ok(()) Ok(())
} }