diff --git a/src/cli.rs b/src/cli.rs index 1f6a559..409e332 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -93,7 +93,6 @@ impl ChatCLI { self.session.add_user_message(message.to_string()); self.session.save()?; - // Clone data needed for the API call before getting mutable client reference let model = self.session.model.clone(); let messages = self.session.messages.clone(); @@ -125,10 +124,8 @@ impl ChatCLI { }) as StreamCallback }; - let client = self.get_client()?.clone(); let response = client - .chat_completion_stream( &self.session.model, &self.session.messages, @@ -159,7 +156,6 @@ impl ChatCLI { let client = self.get_client()?.clone(); let response = client - .chat_completion( &self.session.model, &self.session.messages, @@ -460,27 +456,33 @@ impl ChatCLI { // Check model compatibility let model = self.session.model.clone(); let provider = get_provider_for_model(&model); - let reasoning_enabled = self.session.enable_reasoning_summary; + let capabilities = provider.capabilities(); - // Show compatibility warnings based on provider - match provider { - crate::core::provider::Provider::Anthropic => { - if reasoning_enabled { - self.display.print_warning( - "Reasoning summaries are not supported by Anthropic models", - ); - } - if self.session.enable_extended_thinking { - // Extended thinking is supported by Anthropic models - } - // Web search is now supported by Anthropic models - } - crate::core::provider::Provider::OpenAI => { - if self.session.enable_extended_thinking { - self.display - .print_warning("Extended thinking is not supported by OpenAI models"); - } - // OpenAI models generally support other features + if self.session.enable_reasoning_summary && !capabilities.supports_reasoning_summaries { + self.display.print_warning(&format!( + "Reasoning summaries are not supported by {} models", + format!("{:?}", provider) + )); + } + if self.session.enable_extended_thinking && !capabilities.supports_extended_thinking { + self.display.print_warning(&format!( + "Extended thinking is not supported by {} models", + format!("{:?}", provider) + )); + } + if self.session.enable_web_search && !capabilities.supports_web_search { + self.display.print_warning(&format!( + "Web search is not supported by {} models", + format!("{:?}", provider) + )); + } + if let Some(min) = capabilities.min_thinking_budget { + if self.session.thinking_budget_tokens < min { + self.display.print_warning(&format!( + "Minimum thinking budget is {} tokens for {} models", + min, + format!("{:?}", provider) + )); } } @@ -548,15 +550,12 @@ impl ChatCLI { .print_command_result(&format!("Extended thinking {}", state)); let provider = get_provider_for_model(&self.session.model); - match provider { - crate::core::provider::Provider::OpenAI => { - self.display.print_warning( - "Extended thinking is not supported by OpenAI models", - ); - } - crate::core::provider::Provider::Anthropic => { - // Supported - } + let capabilities = provider.capabilities(); + if !capabilities.supports_extended_thinking { + self.display.print_warning(&format!( + "Extended thinking is not supported by {} models", + format!("{:?}", provider) + )); } } Some("Set Thinking Budget") => { @@ -575,16 +574,19 @@ impl ChatCLI { )); let provider = get_provider_for_model(&self.session.model); - match provider { - crate::core::provider::Provider::OpenAI => { - self.display.print_warning( - "Extended thinking is not supported by OpenAI models", - ); - } - crate::core::provider::Provider::Anthropic => { - if budget < 1024 { - self.display.print_warning("Minimum thinking budget is 1024 tokens for Anthropic models"); - } + let capabilities = provider.capabilities(); + if !capabilities.supports_extended_thinking { + self.display.print_warning(&format!( + "Extended thinking is not supported by {} models", + format!("{:?}", provider) + )); + } else if let Some(min) = capabilities.min_thinking_budget { + if budget < min { + self.display.print_warning(&format!( + "Minimum thinking budget is {} tokens for {} models", + min, + format!("{:?}", provider) + )); } } } diff --git a/src/core/provider.rs b/src/core/provider.rs index 1470d84..81f4b2b 100644 --- a/src/core/provider.rs +++ b/src/core/provider.rs @@ -6,6 +6,14 @@ pub enum Provider { Anthropic, } +#[derive(Debug, Clone, Copy)] +pub struct ProviderCapabilities { + pub supports_web_search: bool, + pub supports_reasoning_summaries: bool, + pub supports_extended_thinking: bool, + pub min_thinking_budget: Option, +} + pub struct ModelInfo { pub model_id: &'static str, pub display_name: &'static str, @@ -18,11 +26,28 @@ impl Provider { Provider::Anthropic => "anthropic", } } + + pub fn capabilities(&self) -> ProviderCapabilities { + match self { + Provider::OpenAI => ProviderCapabilities { + supports_web_search: true, + supports_reasoning_summaries: true, + supports_extended_thinking: false, + min_thinking_budget: None, + }, + Provider::Anthropic => ProviderCapabilities { + supports_web_search: true, + supports_reasoning_summaries: false, + supports_extended_thinking: true, + min_thinking_budget: Some(1024), + }, + } + } } pub fn get_supported_models() -> HashMap> { let mut models = HashMap::new(); - + models.insert( Provider::OpenAI, vec![ @@ -37,7 +62,7 @@ pub fn get_supported_models() -> HashMap> { "o3-mini", ], ); - + models.insert( Provider::Anthropic, vec![ @@ -48,29 +73,70 @@ pub fn get_supported_models() -> HashMap> { "claude-3-haiku-20240307", ], ); - + models } pub fn get_model_info_list() -> Vec { vec![ // OpenAI models - ModelInfo { model_id: "gpt-4.1", display_name: "GPT-4.1" }, - ModelInfo { model_id: "gpt-4.1-mini", display_name: "GPT-4.1 Mini" }, - ModelInfo { model_id: "gpt-4o", display_name: "GPT-4o" }, - ModelInfo { model_id: "gpt-5", display_name: "GPT-5" }, - ModelInfo { model_id: "gpt-5-chat-latest", display_name: "GPT-5 Chat Latest" }, - ModelInfo { model_id: "o1", display_name: "o1" }, - ModelInfo { model_id: "o3", display_name: "o3" }, - ModelInfo { model_id: "o4-mini", display_name: "o4 Mini" }, - ModelInfo { model_id: "o3-mini", display_name: "o3 Mini" }, - + ModelInfo { + model_id: "gpt-4.1", + display_name: "GPT-4.1", + }, + ModelInfo { + model_id: "gpt-4.1-mini", + display_name: "GPT-4.1 Mini", + }, + ModelInfo { + model_id: "gpt-4o", + display_name: "GPT-4o", + }, + ModelInfo { + model_id: "gpt-5", + display_name: "GPT-5", + }, + ModelInfo { + model_id: "gpt-5-chat-latest", + display_name: "GPT-5 Chat Latest", + }, + ModelInfo { + model_id: "o1", + display_name: "o1", + }, + ModelInfo { + model_id: "o3", + display_name: "o3", + }, + ModelInfo { + model_id: "o4-mini", + display_name: "o4 Mini", + }, + ModelInfo { + model_id: "o3-mini", + display_name: "o3 Mini", + }, // Anthropic models with friendly names - ModelInfo { model_id: "claude-opus-4-1-20250805", display_name: "Claude Opus 4.1" }, - ModelInfo { model_id: "claude-sonnet-4-20250514", display_name: "Claude Sonnet 4.0" }, - ModelInfo { model_id: "claude-3-7-sonnet-20250219", display_name: "Claude 3.7 Sonnet" }, - ModelInfo { model_id: "claude-3-5-haiku-20241022", display_name: "Claude 3.5 Haiku" }, - ModelInfo { model_id: "claude-3-haiku-20240307", display_name: "Claude 3.0 Haiku" }, + ModelInfo { + model_id: "claude-opus-4-1-20250805", + display_name: "Claude Opus 4.1", + }, + ModelInfo { + model_id: "claude-sonnet-4-20250514", + display_name: "Claude Sonnet 4.0", + }, + ModelInfo { + model_id: "claude-3-7-sonnet-20250219", + display_name: "Claude 3.7 Sonnet", + }, + ModelInfo { + model_id: "claude-3-5-haiku-20241022", + display_name: "Claude 3.5 Haiku", + }, + ModelInfo { + model_id: "claude-3-haiku-20240307", + display_name: "Claude 3.0 Haiku", + }, ] } @@ -99,13 +165,13 @@ pub fn get_all_models() -> Vec<&'static str> { pub fn get_provider_for_model(model: &str) -> Provider { let supported = get_supported_models(); - + for (provider, models) in supported { if models.contains(&model) { return provider; } } - + Provider::OpenAI // default fallback } @@ -126,18 +192,18 @@ mod tests { #[test] fn test_get_supported_models() { let models = get_supported_models(); - + // Check that both providers are present assert!(models.contains_key(&Provider::OpenAI)); assert!(models.contains_key(&Provider::Anthropic)); - + // Check OpenAI models let openai_models = models.get(&Provider::OpenAI).unwrap(); assert!(openai_models.contains(&"gpt-5")); assert!(openai_models.contains(&"gpt-4o")); assert!(openai_models.contains(&"o1")); assert!(openai_models.len() > 0); - + // Check Anthropic models let anthropic_models = models.get(&Provider::Anthropic).unwrap(); assert!(anthropic_models.contains(&"claude-sonnet-4-20250514")); @@ -148,15 +214,17 @@ mod tests { #[test] fn test_get_model_info_list() { let model_infos = get_model_info_list(); - + assert!(model_infos.len() > 0); - + // Check some specific models let gpt5_info = model_infos.iter().find(|info| info.model_id == "gpt-5"); assert!(gpt5_info.is_some()); assert_eq!(gpt5_info.unwrap().display_name, "GPT-5"); - - let claude_info = model_infos.iter().find(|info| info.model_id == "claude-sonnet-4-20250514"); + + let claude_info = model_infos + .iter() + .find(|info| info.model_id == "claude-sonnet-4-20250514"); assert!(claude_info.is_some()); assert_eq!(claude_info.unwrap().display_name, "Claude Sonnet 4.0"); } @@ -165,20 +233,32 @@ mod tests { fn test_get_display_name_for_model() { // Test known models assert_eq!(get_display_name_for_model("gpt-5"), "GPT-5"); - assert_eq!(get_display_name_for_model("claude-sonnet-4-20250514"), "Claude Sonnet 4.0"); + assert_eq!( + get_display_name_for_model("claude-sonnet-4-20250514"), + "Claude Sonnet 4.0" + ); assert_eq!(get_display_name_for_model("o1"), "o1"); - + // Test unknown model (should return the model_id itself) - assert_eq!(get_display_name_for_model("unknown-model-123"), "unknown-model-123"); + assert_eq!( + get_display_name_for_model("unknown-model-123"), + "unknown-model-123" + ); } #[test] fn test_get_model_id_from_display_name() { // Test known display names - assert_eq!(get_model_id_from_display_name("GPT-5"), Some("gpt-5".to_string())); - assert_eq!(get_model_id_from_display_name("Claude Sonnet 4.0"), Some("claude-sonnet-4-20250514".to_string())); + assert_eq!( + get_model_id_from_display_name("GPT-5"), + Some("gpt-5".to_string()) + ); + assert_eq!( + get_model_id_from_display_name("Claude Sonnet 4.0"), + Some("claude-sonnet-4-20250514".to_string()) + ); assert_eq!(get_model_id_from_display_name("o1"), Some("o1".to_string())); - + // Test unknown display name assert_eq!(get_model_id_from_display_name("Unknown Model"), None); } @@ -186,9 +266,9 @@ mod tests { #[test] fn test_get_all_models() { let all_models = get_all_models(); - + assert!(all_models.len() > 0); - + // Check that models from both providers are included assert!(all_models.contains(&"gpt-5")); assert!(all_models.contains(&"gpt-4o")); @@ -203,14 +283,26 @@ mod tests { assert_eq!(get_provider_for_model("gpt-4o"), Provider::OpenAI); assert_eq!(get_provider_for_model("o1"), Provider::OpenAI); assert_eq!(get_provider_for_model("gpt-4.1"), Provider::OpenAI); - + // Test Anthropic models - assert_eq!(get_provider_for_model("claude-sonnet-4-20250514"), Provider::Anthropic); - assert_eq!(get_provider_for_model("claude-3-5-haiku-20241022"), Provider::Anthropic); - assert_eq!(get_provider_for_model("claude-opus-4-1-20250805"), Provider::Anthropic); - + assert_eq!( + get_provider_for_model("claude-sonnet-4-20250514"), + Provider::Anthropic + ); + assert_eq!( + get_provider_for_model("claude-3-5-haiku-20241022"), + Provider::Anthropic + ); + assert_eq!( + get_provider_for_model("claude-opus-4-1-20250805"), + Provider::Anthropic + ); + // Test unknown model (should default to OpenAI) - assert_eq!(get_provider_for_model("unknown-model-123"), Provider::OpenAI); + assert_eq!( + get_provider_for_model("unknown-model-123"), + Provider::OpenAI + ); } #[test] @@ -220,7 +312,7 @@ mod tests { assert!(is_model_supported("claude-sonnet-4-20250514")); assert!(is_model_supported("o1")); assert!(is_model_supported("claude-3-haiku-20240307")); - + // Test unsupported models assert!(!is_model_supported("unsupported-model")); assert!(!is_model_supported("gpt-6")); @@ -238,11 +330,11 @@ mod tests { #[test] fn test_provider_hash() { use std::collections::HashMap; - + let mut map = HashMap::new(); map.insert(Provider::OpenAI, "openai_value"); map.insert(Provider::Anthropic, "anthropic_value"); - + assert_eq!(map.get(&Provider::OpenAI), Some(&"openai_value")); assert_eq!(map.get(&Provider::Anthropic), Some(&"anthropic_value")); } @@ -253,7 +345,7 @@ mod tests { model_id: "test-model", display_name: "Test Model", }; - + assert_eq!(model_info.model_id, "test-model"); assert_eq!(model_info.display_name, "Test Model"); } @@ -261,11 +353,14 @@ mod tests { #[test] fn test_all_model_infos_have_valid_display_names() { let model_infos = get_model_info_list(); - + for info in model_infos { assert!(!info.model_id.is_empty(), "Model ID should not be empty"); - assert!(!info.display_name.is_empty(), "Display name should not be empty"); - + assert!( + !info.display_name.is_empty(), + "Display name should not be empty" + ); + // Display name should be different from model_id for most cases // (though some might be the same like "o1") assert!(info.display_name.len() > 0); @@ -277,23 +372,28 @@ mod tests { let supported_models = get_supported_models(); let all_models = get_all_models(); let model_infos = get_model_info_list(); - + // All models in get_all_models should be in supported_models for model in &all_models { - let found = supported_models.values().any(|models| models.contains(model)); + let found = supported_models + .values() + .any(|models| models.contains(model)); assert!(found, "Model {} not found in supported_models", model); } - + // All models in model_infos should be in all_models for info in &model_infos { - assert!(all_models.contains(&info.model_id), - "Model {} from model_infos not found in all_models", info.model_id); + assert!( + all_models.contains(&info.model_id), + "Model {} from model_infos not found in all_models", + info.model_id + ); } - + // All models in all_models should have corresponding model_info for model in &all_models { let found = model_infos.iter().any(|info| info.model_id == *model); assert!(found, "Model {} not found in model_infos", model); } } -} \ No newline at end of file +}