From 7d237f692c9d26b5e987f7bb7b6d25c6373166fc Mon Sep 17 00:00:00 2001 From: leach Date: Fri, 15 Aug 2025 15:41:32 -0400 Subject: [PATCH] config file, bugfixes --- Cargo.lock | 60 +++++++++++ Cargo.toml | 1 + src/cli.rs | 7 +- src/config.rs | 235 ++++++++++++++++++++++++++++++++++++++++++++ src/core/client.rs | 46 +++++---- src/core/session.rs | 8 +- src/main.rs | 36 +++++-- 7 files changed, 364 insertions(+), 29 deletions(-) create mode 100644 src/config.rs diff --git a/Cargo.lock b/Cargo.lock index 28b9105..fcfe9a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -457,6 +457,7 @@ dependencies = [ "serde", "serde_json", "tokio", + "toml", ] [[package]] @@ -1199,6 +1200,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1422,6 +1432,47 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "tower-service" version = "0.3.3" @@ -1930,6 +1981,15 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +[[package]] +name = "winnow" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.50.0" diff --git a/Cargo.toml b/Cargo.toml index 5084059..1f6a8c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,4 @@ console = "0.15" indicatif = "0.17" dirs = "5.0" rustyline = "13.0" +toml = "0.8" diff --git a/src/cli.rs b/src/cli.rs index 6a76da5..18a8e3a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,5 +1,6 @@ use anyhow::Result; +use crate::config::Config; use crate::core::{ create_client, get_provider_for_model, provider::get_all_models, provider::get_supported_models, provider::is_model_supported, ChatClient, Session, @@ -12,22 +13,24 @@ pub struct ChatCLI { current_model: Option, display: Display, input: InputHandler, + config: Config, } impl ChatCLI { - pub fn new(session: Session) -> Result { + pub fn new(session: Session, config: Config) -> Result { Ok(Self { session, client: None, current_model: None, display: Display::new(), input: InputHandler::new()?, + config, }) } fn get_client(&mut self) -> Result<&ChatClient> { if self.client.is_none() || self.current_model.as_ref() != Some(&self.session.model) { - let client = create_client(&self.session.model)?; + let client = create_client(&self.session.model, &self.config)?; self.current_model = Some(self.session.model.clone()); self.client = Some(client); } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..3fb11ff --- /dev/null +++ b/src/config.rs @@ -0,0 +1,235 @@ +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::env; +use std::fs; +use std::path::PathBuf; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + pub api: ApiConfig, + pub defaults: DefaultsConfig, + pub limits: LimitsConfig, + pub session: SessionConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ApiConfig { + pub openai_base_url: String, + pub anthropic_base_url: String, + pub anthropic_version: String, + pub request_timeout_seconds: u64, + pub max_retries: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DefaultsConfig { + pub model: String, + pub reasoning_effort: String, + pub enable_web_search: bool, + pub enable_reasoning_summary: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LimitsConfig { + pub max_tokens_anthropic: u32, + pub max_conversation_history: usize, + pub max_sessions_to_list: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionConfig { + pub sessions_dir_name: String, + pub file_extension: String, +} + +impl Default for Config { + fn default() -> Self { + Self { + api: ApiConfig::default(), + defaults: DefaultsConfig::default(), + limits: LimitsConfig::default(), + session: SessionConfig::default(), + } + } +} + +impl Default for ApiConfig { + fn default() -> Self { + Self { + openai_base_url: "https://api.openai.com/v1".to_string(), + anthropic_base_url: "https://api.anthropic.com/v1".to_string(), + anthropic_version: "2023-06-01".to_string(), + request_timeout_seconds: 120, + max_retries: 3, + } + } +} + +impl Default for DefaultsConfig { + fn default() -> Self { + Self { + model: "gpt-5".to_string(), + reasoning_effort: "medium".to_string(), + enable_web_search: true, + enable_reasoning_summary: false, + } + } +} + +impl Default for LimitsConfig { + fn default() -> Self { + Self { + max_tokens_anthropic: 4096, + max_conversation_history: 100, + max_sessions_to_list: 50, + } + } +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + sessions_dir_name: ".chat_cli_sessions".to_string(), + file_extension: "json".to_string(), + } + } +} + +#[derive(Debug)] +pub struct EnvVariables { + pub openai_api_key: Option, + pub anthropic_api_key: Option, + pub openai_base_url: Option, + pub default_model: Option, +} + +impl Config { + pub fn load() -> Result { + let config_path = Self::config_file_path()?; + + if config_path.exists() { + let config_content = fs::read_to_string(&config_path) + .with_context(|| format!("Failed to read config file: {:?}", config_path))?; + + let mut config: Config = toml::from_str(&config_content) + .with_context(|| format!("Failed to parse config file: {:?}", config_path))?; + + // Override with environment variables if present + config.apply_env_overrides()?; + + Ok(config) + } else { + let mut config = Config::default(); + config.apply_env_overrides()?; + Ok(config) + } + } + + pub fn save(&self) -> Result<()> { + let config_path = Self::config_file_path()?; + + // Create config directory if it doesn't exist + if let Some(parent) = config_path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create config directory: {:?}", parent))?; + } + + let config_content = toml::to_string_pretty(self) + .context("Failed to serialize config")?; + + fs::write(&config_path, config_content) + .with_context(|| format!("Failed to write config file: {:?}", config_path))?; + + Ok(()) + } + + pub fn config_file_path() -> Result { + let home = dirs::home_dir().context("Could not find home directory")?; + Ok(home.join(".config").join("gpt-cli-rust").join("config.toml")) + } + + fn apply_env_overrides(&mut self) -> Result<()> { + // Override API URLs + if let Ok(openai_base_url) = env::var("OPENAI_BASE_URL") { + self.api.openai_base_url = openai_base_url; + } + + // Override defaults + if let Ok(default_model) = env::var("DEFAULT_MODEL") { + self.defaults.model = default_model; + } + + Ok(()) + } + + pub fn validate_env_variables() -> Result { + let openai_api_key = env::var("OPENAI_API_KEY").ok(); + let anthropic_api_key = env::var("ANTHROPIC_API_KEY").ok(); + let openai_base_url = env::var("OPENAI_BASE_URL").ok(); + let default_model = env::var("DEFAULT_MODEL").ok(); + + // At least one API key must be present + if openai_api_key.is_none() && anthropic_api_key.is_none() { + return Err(anyhow::anyhow!( + "At least one API key must be set: OPENAI_API_KEY or ANTHROPIC_API_KEY" + )); + } + + Ok(EnvVariables { + openai_api_key, + anthropic_api_key, + openai_base_url, + default_model, + }) + } + + pub fn validate_model_availability(&self, env: &EnvVariables, model: &str) -> Result<()> { + use crate::core::provider::{get_provider_for_model, Provider}; + + let provider = get_provider_for_model(model); + + match provider { + Provider::OpenAI => { + if env.openai_api_key.is_none() { + return Err(anyhow::anyhow!( + "OPENAI_API_KEY is required for OpenAI model: {}", model + )); + } + } + Provider::Anthropic => { + if env.anthropic_api_key.is_none() { + return Err(anyhow::anyhow!( + "ANTHROPIC_API_KEY is required for Anthropic model: {}", model + )); + } + } + } + + Ok(()) + } + + pub fn create_example_config() -> Result<()> { + let config_path = Self::config_file_path()?; + + if config_path.exists() { + return Ok(()); // Don't overwrite existing config + } + + let example_config = Config::default(); + example_config.save()?; + + println!("Created example config file at: {:?}", config_path); + println!("You can customize it to change default settings."); + + Ok(()) + } + + pub fn print_config_info(&self) { + println!("📋 Configuration:"); + println!(" Default model: {}", self.defaults.model); + println!(" Web search: {}", if self.defaults.enable_web_search { "enabled" } else { "disabled" }); + println!(" Reasoning summaries: {}", if self.defaults.enable_reasoning_summary { "enabled" } else { "disabled" }); + println!(" Request timeout: {}s", self.api.request_timeout_seconds); + println!(" Max conversation history: {}", self.limits.max_conversation_history); + } +} \ No newline at end of file diff --git a/src/core/client.rs b/src/core/client.rs index 0f7aab6..0fc8bbf 100644 --- a/src/core/client.rs +++ b/src/core/client.rs @@ -3,7 +3,9 @@ use reqwest::Client; use serde::Deserialize; use serde_json::{json, Value}; use std::env; +use std::time::Duration; +use crate::config::Config; use super::{provider::Provider, session::Message}; #[derive(Debug)] @@ -128,7 +130,8 @@ struct SearchAction { #[allow(dead_code)] #[serde(rename = "type")] action_type: String, - query: String, + #[serde(default)] + query: Option, } #[derive(Deserialize)] @@ -164,19 +167,19 @@ struct AnthropicContent { } impl OpenAIClient { - pub fn new() -> Result { + pub fn new(config: &Config) -> Result { let api_key = env::var("OPENAI_API_KEY") .context("OPENAI_API_KEY environment variable is required")?; - - let base_url = env::var("OPENAI_BASE_URL") - .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()); - let client = Client::new(); + let client = Client::builder() + .timeout(Duration::from_secs(config.api.request_timeout_seconds)) + .build() + .context("Failed to create HTTP client")?; Ok(Self { client, api_key, - base_url, + base_url: config.api.openai_base_url.clone(), }) } @@ -368,7 +371,11 @@ impl OpenAIClient { if item.status.as_deref() == Some("completed") { search_count += 1; if let Some(action) = &item.action { - final_content.push_str(&format!("🔍 Search {}: \"{}\"\n", search_count, action.query)); + if let Some(query) = &action.query { + final_content.push_str(&format!("🔍 Search {}: \"{}\"\n", search_count, query)); + } else { + final_content.push_str(&format!("🔍 Search {}: [no query specified]\n", search_count)); + } } } } @@ -439,17 +446,19 @@ impl OpenAIClient { } impl AnthropicClient { - pub fn new() -> Result { + pub fn new(config: &Config) -> Result { let api_key = env::var("ANTHROPIC_API_KEY") .context("ANTHROPIC_API_KEY environment variable is required")?; - - let base_url = "https://api.anthropic.com/v1".to_string(); - let client = Client::new(); + + let client = Client::builder() + .timeout(Duration::from_secs(config.api.request_timeout_seconds)) + .build() + .context("Failed to create HTTP client")?; Ok(Self { client, api_key, - base_url, + base_url: config.api.anthropic_base_url.clone(), }) } @@ -489,9 +498,10 @@ impl AnthropicClient { let (system_prompt, user_messages) = Self::convert_messages(messages); + let config = crate::config::Config::load().unwrap_or_default(); let mut payload = json!({ "model": model, - "max_tokens": 4096, + "max_tokens": config.limits.max_tokens_anthropic, "messages": user_messages }); @@ -504,7 +514,7 @@ impl AnthropicClient { .post(&url) .header("x-api-key", &self.api_key) .header("Content-Type", "application/json") - .header("anthropic-version", "2023-06-01") + .header("anthropic-version", &config.api.anthropic_version) .json(&payload) .send() .await @@ -544,16 +554,16 @@ impl AnthropicClient { } } -pub fn create_client(model: &str) -> Result { +pub fn create_client(model: &str, config: &Config) -> Result { let provider = super::provider::get_provider_for_model(model); match provider { Provider::OpenAI => { - let client = OpenAIClient::new()?; + let client = OpenAIClient::new(config)?; Ok(ChatClient::OpenAI(client)) } Provider::Anthropic => { - let client = AnthropicClient::new()?; + let client = AnthropicClient::new(config)?; Ok(ChatClient::Anthropic(client)) } } diff --git a/src/core/session.rs b/src/core/session.rs index ebabc35..4de9878 100644 --- a/src/core/session.rs +++ b/src/core/session.rs @@ -4,6 +4,8 @@ use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; +use crate::config::Config; + const SYSTEM_PROMPT: &str = "You are an AI assistant running in a terminal (CLI) environment. \ Optimise all answers for 80‑column readability, prefer plain text, \ ASCII art or concise bullet lists over heavy markup, and wrap code \ @@ -62,8 +64,9 @@ impl Session { } pub fn sessions_dir() -> Result { + let config = Config::load().unwrap_or_default(); let home = dirs::home_dir().context("Could not find home directory")?; - let sessions_dir = home.join(".chat_cli_sessions"); + let sessions_dir = home.join(&config.session.sessions_dir_name); if !sessions_dir.exists() { fs::create_dir_all(&sessions_dir) @@ -74,7 +77,8 @@ impl Session { } pub fn session_path(name: &str) -> Result { - Ok(Self::sessions_dir()?.join(format!("{}.json", name))) + let config = Config::load().unwrap_or_default(); + Ok(Self::sessions_dir()?.join(format!("{}.{}", name, config.session.file_extension))) } pub fn save(&self) -> Result<()> { diff --git a/src/main.rs b/src/main.rs index fe96c1d..46b4f64 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,13 @@ mod cli; +mod config; mod core; mod utils; use anyhow::{Context, Result}; use clap::Parser; -use std::env; use crate::cli::ChatCLI; +use crate::config::Config; use crate::core::{provider::is_model_supported, Session}; use crate::utils::Display; @@ -20,6 +21,9 @@ struct Args { #[arg(short, long, help = "Model name to use (overrides saved value)")] model: Option, + + #[arg(long, help = "Create example configuration file and exit")] + create_config: bool, } #[tokio::main] @@ -27,6 +31,18 @@ async fn main() -> Result<()> { let args = Args::parse(); let display = Display::new(); + // Handle config creation + if args.create_config { + Config::create_example_config()?; + return Ok(()); + } + + // Load configuration + let config = Config::load().context("Failed to load configuration")?; + + // Validate environment variables + let env_vars = Config::validate_env_variables().context("Environment validation failed")?; + // Load or create session let session = match Session::load(&args.session) { Ok(mut session) => { @@ -45,23 +61,29 @@ async fn main() -> Result<()> { Err(_) => { let default_model = args .model - .or_else(|| env::var("DEFAULT_MODEL").ok()) - .unwrap_or_else(|| "gpt-5".to_string()); + .or_else(|| env_vars.default_model.clone()) + .unwrap_or_else(|| config.defaults.model.clone()); if !is_model_supported(&default_model) { display.print_warning(&format!( - "Model '{}' is not supported. Falling back to 'gpt-5'", - default_model + "Model '{}' is not supported. Falling back to '{}'", + default_model, config.defaults.model )); - Session::new(args.session, "gpt-5".to_string()) + Session::new(args.session, config.defaults.model.clone()) } else { Session::new(args.session, default_model) } } }; + // Validate model availability + config.validate_model_availability(&env_vars, &session.model)?; + + // Print configuration info + config.print_config_info(); + // Run the CLI - let mut cli = ChatCLI::new(session).context("Failed to initialize CLI")?; + let mut cli = ChatCLI::new(session, config).context("Failed to initialize CLI")?; cli.run().await.context("CLI error")?; Ok(())