Centralize provider capabilities

This commit is contained in:
Christopher 2025-08-25 01:27:09 -04:00
parent 1ac7914646
commit 4618f273f3
2 changed files with 200 additions and 98 deletions

View File

@ -93,7 +93,6 @@ impl ChatCLI {
self.session.add_user_message(message.to_string()); self.session.add_user_message(message.to_string());
self.session.save()?; self.session.save()?;
// Clone data needed for the API call before getting mutable client reference // Clone data needed for the API call before getting mutable client reference
let model = self.session.model.clone(); let model = self.session.model.clone();
let messages = self.session.messages.clone(); let messages = self.session.messages.clone();
@ -125,10 +124,8 @@ impl ChatCLI {
}) as StreamCallback }) as StreamCallback
}; };
let client = self.get_client()?.clone(); let client = self.get_client()?.clone();
let response = client let response = client
.chat_completion_stream( .chat_completion_stream(
&self.session.model, &self.session.model,
&self.session.messages, &self.session.messages,
@ -159,7 +156,6 @@ impl ChatCLI {
let client = self.get_client()?.clone(); let client = self.get_client()?.clone();
let response = client let response = client
.chat_completion( .chat_completion(
&self.session.model, &self.session.model,
&self.session.messages, &self.session.messages,
@ -460,27 +456,33 @@ impl ChatCLI {
// Check model compatibility // Check model compatibility
let model = self.session.model.clone(); let model = self.session.model.clone();
let provider = get_provider_for_model(&model); let provider = get_provider_for_model(&model);
let reasoning_enabled = self.session.enable_reasoning_summary; let capabilities = provider.capabilities();
// Show compatibility warnings based on provider if self.session.enable_reasoning_summary && !capabilities.supports_reasoning_summaries {
match provider { self.display.print_warning(&format!(
crate::core::provider::Provider::Anthropic => { "Reasoning summaries are not supported by {} models",
if reasoning_enabled { format!("{:?}", provider)
self.display.print_warning( ));
"Reasoning summaries are not supported by Anthropic models", }
); if self.session.enable_extended_thinking && !capabilities.supports_extended_thinking {
} self.display.print_warning(&format!(
if self.session.enable_extended_thinking { "Extended thinking is not supported by {} models",
// Extended thinking is supported by Anthropic models format!("{:?}", provider)
} ));
// Web search is now supported by Anthropic models }
} if self.session.enable_web_search && !capabilities.supports_web_search {
crate::core::provider::Provider::OpenAI => { self.display.print_warning(&format!(
if self.session.enable_extended_thinking { "Web search is not supported by {} models",
self.display format!("{:?}", provider)
.print_warning("Extended thinking is not supported by OpenAI models"); ));
} }
// OpenAI models generally support other features if let Some(min) = capabilities.min_thinking_budget {
if self.session.thinking_budget_tokens < min {
self.display.print_warning(&format!(
"Minimum thinking budget is {} tokens for {} models",
min,
format!("{:?}", provider)
));
} }
} }
@ -548,15 +550,12 @@ impl ChatCLI {
.print_command_result(&format!("Extended thinking {}", state)); .print_command_result(&format!("Extended thinking {}", state));
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
match provider { let capabilities = provider.capabilities();
crate::core::provider::Provider::OpenAI => { if !capabilities.supports_extended_thinking {
self.display.print_warning( self.display.print_warning(&format!(
"Extended thinking is not supported by OpenAI models", "Extended thinking is not supported by {} models",
); format!("{:?}", provider)
} ));
crate::core::provider::Provider::Anthropic => {
// Supported
}
} }
} }
Some("Set Thinking Budget") => { Some("Set Thinking Budget") => {
@ -575,16 +574,19 @@ impl ChatCLI {
)); ));
let provider = get_provider_for_model(&self.session.model); let provider = get_provider_for_model(&self.session.model);
match provider { let capabilities = provider.capabilities();
crate::core::provider::Provider::OpenAI => { if !capabilities.supports_extended_thinking {
self.display.print_warning( self.display.print_warning(&format!(
"Extended thinking is not supported by OpenAI models", "Extended thinking is not supported by {} models",
); format!("{:?}", provider)
} ));
crate::core::provider::Provider::Anthropic => { } else if let Some(min) = capabilities.min_thinking_budget {
if budget < 1024 { if budget < min {
self.display.print_warning("Minimum thinking budget is 1024 tokens for Anthropic models"); self.display.print_warning(&format!(
} "Minimum thinking budget is {} tokens for {} models",
min,
format!("{:?}", provider)
));
} }
} }
} }

View File

@ -6,6 +6,14 @@ pub enum Provider {
Anthropic, Anthropic,
} }
/// Feature flags and limits describing what a chat provider supports.
///
/// Callers read these fields to decide whether to warn about a session
/// option (web search, reasoning summaries, extended thinking, thinking
/// budget) that the selected provider cannot honor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Whether the provider's models accept web search.
    pub supports_web_search: bool,
    /// Whether the provider's models can return reasoning summaries.
    pub supports_reasoning_summaries: bool,
    /// Whether the provider's models support extended thinking.
    pub supports_extended_thinking: bool,
    /// Smallest thinking budget (in tokens) the provider accepts, or `None`
    /// when the provider declares no minimum.
    pub min_thinking_budget: Option<u32>,
}
pub struct ModelInfo { pub struct ModelInfo {
pub model_id: &'static str, pub model_id: &'static str,
pub display_name: &'static str, pub display_name: &'static str,
@ -18,11 +26,28 @@ impl Provider {
Provider::Anthropic => "anthropic", Provider::Anthropic => "anthropic",
} }
} }
/// Returns the feature set supported by this provider.
///
/// * `OpenAI` — web search and reasoning summaries; no extended thinking and
///   no minimum thinking budget.
/// * `Anthropic` — web search and extended thinking with a minimum thinking
///   budget of 1024; no reasoning summaries.
pub fn capabilities(&self) -> ProviderCapabilities {
    // Each arm yields (web search, reasoning summaries, extended thinking,
    // minimum thinking budget). The match stays exhaustive so adding a new
    // provider variant forces this table to be updated.
    let (web_search, reasoning_summaries, extended_thinking, budget_floor) = match self {
        Provider::OpenAI => (true, true, false, None),
        Provider::Anthropic => (true, false, true, Some(1024)),
    };

    ProviderCapabilities {
        supports_web_search: web_search,
        supports_reasoning_summaries: reasoning_summaries,
        supports_extended_thinking: extended_thinking,
        min_thinking_budget: budget_floor,
    }
}
} }
pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> { pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
let mut models = HashMap::new(); let mut models = HashMap::new();
models.insert( models.insert(
Provider::OpenAI, Provider::OpenAI,
vec![ vec![
@ -37,7 +62,7 @@ pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
"o3-mini", "o3-mini",
], ],
); );
models.insert( models.insert(
Provider::Anthropic, Provider::Anthropic,
vec![ vec![
@ -48,29 +73,70 @@ pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
], ],
); );
models models
} }
pub fn get_model_info_list() -> Vec<ModelInfo> { pub fn get_model_info_list() -> Vec<ModelInfo> {
vec![ vec![
// OpenAI models // OpenAI models
ModelInfo { model_id: "gpt-4.1", display_name: "GPT-4.1" }, ModelInfo {
ModelInfo { model_id: "gpt-4.1-mini", display_name: "GPT-4.1 Mini" }, model_id: "gpt-4.1",
ModelInfo { model_id: "gpt-4o", display_name: "GPT-4o" }, display_name: "GPT-4.1",
ModelInfo { model_id: "gpt-5", display_name: "GPT-5" }, },
ModelInfo { model_id: "gpt-5-chat-latest", display_name: "GPT-5 Chat Latest" }, ModelInfo {
ModelInfo { model_id: "o1", display_name: "o1" }, model_id: "gpt-4.1-mini",
ModelInfo { model_id: "o3", display_name: "o3" }, display_name: "GPT-4.1 Mini",
ModelInfo { model_id: "o4-mini", display_name: "o4 Mini" }, },
ModelInfo { model_id: "o3-mini", display_name: "o3 Mini" }, ModelInfo {
model_id: "gpt-4o",
display_name: "GPT-4o",
},
ModelInfo {
model_id: "gpt-5",
display_name: "GPT-5",
},
ModelInfo {
model_id: "gpt-5-chat-latest",
display_name: "GPT-5 Chat Latest",
},
ModelInfo {
model_id: "o1",
display_name: "o1",
},
ModelInfo {
model_id: "o3",
display_name: "o3",
},
ModelInfo {
model_id: "o4-mini",
display_name: "o4 Mini",
},
ModelInfo {
model_id: "o3-mini",
display_name: "o3 Mini",
},
// Anthropic models with friendly names // Anthropic models with friendly names
ModelInfo { model_id: "claude-opus-4-1-20250805", display_name: "Claude Opus 4.1" }, ModelInfo {
ModelInfo { model_id: "claude-sonnet-4-20250514", display_name: "Claude Sonnet 4.0" }, model_id: "claude-opus-4-1-20250805",
ModelInfo { model_id: "claude-3-7-sonnet-20250219", display_name: "Claude 3.7 Sonnet" }, display_name: "Claude Opus 4.1",
ModelInfo { model_id: "claude-3-5-haiku-20241022", display_name: "Claude 3.5 Haiku" }, },
ModelInfo { model_id: "claude-3-haiku-20240307", display_name: "Claude 3.0 Haiku" }, ModelInfo {
model_id: "claude-sonnet-4-20250514",
display_name: "Claude Sonnet 4.0",
},
ModelInfo {
model_id: "claude-3-7-sonnet-20250219",
display_name: "Claude 3.7 Sonnet",
},
ModelInfo {
model_id: "claude-3-5-haiku-20241022",
display_name: "Claude 3.5 Haiku",
},
ModelInfo {
model_id: "claude-3-haiku-20240307",
display_name: "Claude 3.0 Haiku",
},
] ]
} }
@ -99,13 +165,13 @@ pub fn get_all_models() -> Vec<&'static str> {
pub fn get_provider_for_model(model: &str) -> Provider { pub fn get_provider_for_model(model: &str) -> Provider {
let supported = get_supported_models(); let supported = get_supported_models();
for (provider, models) in supported { for (provider, models) in supported {
if models.contains(&model) { if models.contains(&model) {
return provider; return provider;
} }
} }
Provider::OpenAI // default fallback Provider::OpenAI // default fallback
} }
@ -126,18 +192,18 @@ mod tests {
#[test] #[test]
fn test_get_supported_models() { fn test_get_supported_models() {
let models = get_supported_models(); let models = get_supported_models();
// Check that both providers are present // Check that both providers are present
assert!(models.contains_key(&Provider::OpenAI)); assert!(models.contains_key(&Provider::OpenAI));
assert!(models.contains_key(&Provider::Anthropic)); assert!(models.contains_key(&Provider::Anthropic));
// Check OpenAI models // Check OpenAI models
let openai_models = models.get(&Provider::OpenAI).unwrap(); let openai_models = models.get(&Provider::OpenAI).unwrap();
assert!(openai_models.contains(&"gpt-5")); assert!(openai_models.contains(&"gpt-5"));
assert!(openai_models.contains(&"gpt-4o")); assert!(openai_models.contains(&"gpt-4o"));
assert!(openai_models.contains(&"o1")); assert!(openai_models.contains(&"o1"));
assert!(openai_models.len() > 0); assert!(openai_models.len() > 0);
// Check Anthropic models // Check Anthropic models
let anthropic_models = models.get(&Provider::Anthropic).unwrap(); let anthropic_models = models.get(&Provider::Anthropic).unwrap();
assert!(anthropic_models.contains(&"claude-sonnet-4-20250514")); assert!(anthropic_models.contains(&"claude-sonnet-4-20250514"));
@ -148,15 +214,17 @@ mod tests {
#[test] #[test]
fn test_get_model_info_list() { fn test_get_model_info_list() {
let model_infos = get_model_info_list(); let model_infos = get_model_info_list();
assert!(model_infos.len() > 0); assert!(model_infos.len() > 0);
// Check some specific models // Check some specific models
let gpt5_info = model_infos.iter().find(|info| info.model_id == "gpt-5"); let gpt5_info = model_infos.iter().find(|info| info.model_id == "gpt-5");
assert!(gpt5_info.is_some()); assert!(gpt5_info.is_some());
assert_eq!(gpt5_info.unwrap().display_name, "GPT-5"); assert_eq!(gpt5_info.unwrap().display_name, "GPT-5");
let claude_info = model_infos.iter().find(|info| info.model_id == "claude-sonnet-4-20250514"); let claude_info = model_infos
.iter()
.find(|info| info.model_id == "claude-sonnet-4-20250514");
assert!(claude_info.is_some()); assert!(claude_info.is_some());
assert_eq!(claude_info.unwrap().display_name, "Claude Sonnet 4.0"); assert_eq!(claude_info.unwrap().display_name, "Claude Sonnet 4.0");
} }
@ -165,20 +233,32 @@ mod tests {
fn test_get_display_name_for_model() { fn test_get_display_name_for_model() {
// Test known models // Test known models
assert_eq!(get_display_name_for_model("gpt-5"), "GPT-5"); assert_eq!(get_display_name_for_model("gpt-5"), "GPT-5");
assert_eq!(get_display_name_for_model("claude-sonnet-4-20250514"), "Claude Sonnet 4.0"); assert_eq!(
get_display_name_for_model("claude-sonnet-4-20250514"),
"Claude Sonnet 4.0"
);
assert_eq!(get_display_name_for_model("o1"), "o1"); assert_eq!(get_display_name_for_model("o1"), "o1");
// Test unknown model (should return the model_id itself) // Test unknown model (should return the model_id itself)
assert_eq!(get_display_name_for_model("unknown-model-123"), "unknown-model-123"); assert_eq!(
get_display_name_for_model("unknown-model-123"),
"unknown-model-123"
);
} }
#[test] #[test]
fn test_get_model_id_from_display_name() { fn test_get_model_id_from_display_name() {
// Test known display names // Test known display names
assert_eq!(get_model_id_from_display_name("GPT-5"), Some("gpt-5".to_string())); assert_eq!(
assert_eq!(get_model_id_from_display_name("Claude Sonnet 4.0"), Some("claude-sonnet-4-20250514".to_string())); get_model_id_from_display_name("GPT-5"),
Some("gpt-5".to_string())
);
assert_eq!(
get_model_id_from_display_name("Claude Sonnet 4.0"),
Some("claude-sonnet-4-20250514".to_string())
);
assert_eq!(get_model_id_from_display_name("o1"), Some("o1".to_string())); assert_eq!(get_model_id_from_display_name("o1"), Some("o1".to_string()));
// Test unknown display name // Test unknown display name
assert_eq!(get_model_id_from_display_name("Unknown Model"), None); assert_eq!(get_model_id_from_display_name("Unknown Model"), None);
} }
@ -186,9 +266,9 @@ mod tests {
#[test] #[test]
fn test_get_all_models() { fn test_get_all_models() {
let all_models = get_all_models(); let all_models = get_all_models();
assert!(all_models.len() > 0); assert!(all_models.len() > 0);
// Check that models from both providers are included // Check that models from both providers are included
assert!(all_models.contains(&"gpt-5")); assert!(all_models.contains(&"gpt-5"));
assert!(all_models.contains(&"gpt-4o")); assert!(all_models.contains(&"gpt-4o"));
@ -203,14 +283,26 @@ mod tests {
assert_eq!(get_provider_for_model("gpt-4o"), Provider::OpenAI); assert_eq!(get_provider_for_model("gpt-4o"), Provider::OpenAI);
assert_eq!(get_provider_for_model("o1"), Provider::OpenAI); assert_eq!(get_provider_for_model("o1"), Provider::OpenAI);
assert_eq!(get_provider_for_model("gpt-4.1"), Provider::OpenAI); assert_eq!(get_provider_for_model("gpt-4.1"), Provider::OpenAI);
// Test Anthropic models // Test Anthropic models
assert_eq!(get_provider_for_model("claude-sonnet-4-20250514"), Provider::Anthropic); assert_eq!(
assert_eq!(get_provider_for_model("claude-3-5-haiku-20241022"), Provider::Anthropic); get_provider_for_model("claude-sonnet-4-20250514"),
assert_eq!(get_provider_for_model("claude-opus-4-1-20250805"), Provider::Anthropic); Provider::Anthropic
);
assert_eq!(
get_provider_for_model("claude-3-5-haiku-20241022"),
Provider::Anthropic
);
assert_eq!(
get_provider_for_model("claude-opus-4-1-20250805"),
Provider::Anthropic
);
// Test unknown model (should default to OpenAI) // Test unknown model (should default to OpenAI)
assert_eq!(get_provider_for_model("unknown-model-123"), Provider::OpenAI); assert_eq!(
get_provider_for_model("unknown-model-123"),
Provider::OpenAI
);
} }
#[test] #[test]
@ -220,7 +312,7 @@ mod tests {
assert!(is_model_supported("claude-sonnet-4-20250514")); assert!(is_model_supported("claude-sonnet-4-20250514"));
assert!(is_model_supported("o1")); assert!(is_model_supported("o1"));
assert!(is_model_supported("claude-3-haiku-20240307")); assert!(is_model_supported("claude-3-haiku-20240307"));
// Test unsupported models // Test unsupported models
assert!(!is_model_supported("unsupported-model")); assert!(!is_model_supported("unsupported-model"));
assert!(!is_model_supported("gpt-6")); assert!(!is_model_supported("gpt-6"));
@ -238,11 +330,11 @@ mod tests {
#[test] #[test]
fn test_provider_hash() { fn test_provider_hash() {
use std::collections::HashMap; use std::collections::HashMap;
let mut map = HashMap::new(); let mut map = HashMap::new();
map.insert(Provider::OpenAI, "openai_value"); map.insert(Provider::OpenAI, "openai_value");
map.insert(Provider::Anthropic, "anthropic_value"); map.insert(Provider::Anthropic, "anthropic_value");
assert_eq!(map.get(&Provider::OpenAI), Some(&"openai_value")); assert_eq!(map.get(&Provider::OpenAI), Some(&"openai_value"));
assert_eq!(map.get(&Provider::Anthropic), Some(&"anthropic_value")); assert_eq!(map.get(&Provider::Anthropic), Some(&"anthropic_value"));
} }
@ -253,7 +345,7 @@ mod tests {
model_id: "test-model", model_id: "test-model",
display_name: "Test Model", display_name: "Test Model",
}; };
assert_eq!(model_info.model_id, "test-model"); assert_eq!(model_info.model_id, "test-model");
assert_eq!(model_info.display_name, "Test Model"); assert_eq!(model_info.display_name, "Test Model");
} }
@ -261,11 +353,14 @@ mod tests {
#[test] #[test]
fn test_all_model_infos_have_valid_display_names() { fn test_all_model_infos_have_valid_display_names() {
let model_infos = get_model_info_list(); let model_infos = get_model_info_list();
for info in model_infos { for info in model_infos {
assert!(!info.model_id.is_empty(), "Model ID should not be empty"); assert!(!info.model_id.is_empty(), "Model ID should not be empty");
assert!(!info.display_name.is_empty(), "Display name should not be empty"); assert!(
!info.display_name.is_empty(),
"Display name should not be empty"
);
// Display name should be different from model_id for most cases // Display name should be different from model_id for most cases
// (though some might be the same like "o1") // (though some might be the same like "o1")
assert!(info.display_name.len() > 0); assert!(info.display_name.len() > 0);
@ -277,23 +372,28 @@ mod tests {
let supported_models = get_supported_models(); let supported_models = get_supported_models();
let all_models = get_all_models(); let all_models = get_all_models();
let model_infos = get_model_info_list(); let model_infos = get_model_info_list();
// All models in get_all_models should be in supported_models // All models in get_all_models should be in supported_models
for model in &all_models { for model in &all_models {
let found = supported_models.values().any(|models| models.contains(model)); let found = supported_models
.values()
.any(|models| models.contains(model));
assert!(found, "Model {} not found in supported_models", model); assert!(found, "Model {} not found in supported_models", model);
} }
// All models in model_infos should be in all_models // All models in model_infos should be in all_models
for info in &model_infos { for info in &model_infos {
assert!(all_models.contains(&info.model_id), assert!(
"Model {} from model_infos not found in all_models", info.model_id); all_models.contains(&info.model_id),
"Model {} from model_infos not found in all_models",
info.model_id
);
} }
// All models in all_models should have corresponding model_info // All models in all_models should have corresponding model_info
for model in &all_models { for model in &all_models {
let found = model_infos.iter().any(|info| info.model_id == *model); let found = model_infos.iter().any(|info| info.model_id == *model);
assert!(found, "Model {} not found in model_infos", model); assert!(found, "Model {} not found in model_infos", model);
} }
} }
} }