Centralize provider capabilities

This commit is contained in:
Christopher 2025-08-25 01:27:09 -04:00
parent 1ac7914646
commit 4618f273f3
2 changed files with 200 additions and 98 deletions

View File

@ -93,7 +93,6 @@ impl ChatCLI {
self.session.add_user_message(message.to_string());
self.session.save()?;
// Clone data needed for the API call before getting mutable client reference
let model = self.session.model.clone();
let messages = self.session.messages.clone();
@ -125,10 +124,8 @@ impl ChatCLI {
}) as StreamCallback
};
let client = self.get_client()?.clone();
let response = client
.chat_completion_stream(
&self.session.model,
&self.session.messages,
@ -159,7 +156,6 @@ impl ChatCLI {
let client = self.get_client()?.clone();
let response = client
.chat_completion(
&self.session.model,
&self.session.messages,
@ -460,27 +456,33 @@ impl ChatCLI {
// Check model compatibility
let model = self.session.model.clone();
let provider = get_provider_for_model(&model);
let reasoning_enabled = self.session.enable_reasoning_summary;
let capabilities = provider.capabilities();
// Show compatibility warnings based on provider
match provider {
crate::core::provider::Provider::Anthropic => {
if reasoning_enabled {
self.display.print_warning(
"Reasoning summaries are not supported by Anthropic models",
);
}
if self.session.enable_extended_thinking {
// Extended thinking is supported by Anthropic models
}
// Web search is now supported by Anthropic models
}
crate::core::provider::Provider::OpenAI => {
if self.session.enable_extended_thinking {
self.display
.print_warning("Extended thinking is not supported by OpenAI models");
}
// OpenAI models generally support other features
if self.session.enable_reasoning_summary && !capabilities.supports_reasoning_summaries {
self.display.print_warning(&format!(
"Reasoning summaries are not supported by {} models",
format!("{:?}", provider)
));
}
if self.session.enable_extended_thinking && !capabilities.supports_extended_thinking {
self.display.print_warning(&format!(
"Extended thinking is not supported by {} models",
format!("{:?}", provider)
));
}
if self.session.enable_web_search && !capabilities.supports_web_search {
self.display.print_warning(&format!(
"Web search is not supported by {} models",
format!("{:?}", provider)
));
}
if let Some(min) = capabilities.min_thinking_budget {
if self.session.thinking_budget_tokens < min {
self.display.print_warning(&format!(
"Minimum thinking budget is {} tokens for {} models",
min,
format!("{:?}", provider)
));
}
}
@ -548,15 +550,12 @@ impl ChatCLI {
.print_command_result(&format!("Extended thinking {}", state));
let provider = get_provider_for_model(&self.session.model);
match provider {
crate::core::provider::Provider::OpenAI => {
self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
}
crate::core::provider::Provider::Anthropic => {
// Supported
}
let capabilities = provider.capabilities();
if !capabilities.supports_extended_thinking {
self.display.print_warning(&format!(
"Extended thinking is not supported by {} models",
format!("{:?}", provider)
));
}
}
Some("Set Thinking Budget") => {
@ -575,16 +574,19 @@ impl ChatCLI {
));
let provider = get_provider_for_model(&self.session.model);
match provider {
crate::core::provider::Provider::OpenAI => {
self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
}
crate::core::provider::Provider::Anthropic => {
if budget < 1024 {
self.display.print_warning("Minimum thinking budget is 1024 tokens for Anthropic models");
}
let capabilities = provider.capabilities();
if !capabilities.supports_extended_thinking {
self.display.print_warning(&format!(
"Extended thinking is not supported by {} models",
format!("{:?}", provider)
));
} else if let Some(min) = capabilities.min_thinking_budget {
if budget < min {
self.display.print_warning(&format!(
"Minimum thinking budget is {} tokens for {} models",
min,
format!("{:?}", provider)
));
}
}
}

View File

@ -6,6 +6,14 @@ pub enum Provider {
Anthropic,
}
#[derive(Debug, Clone, Copy)]
pub struct ProviderCapabilities {
pub supports_web_search: bool,
pub supports_reasoning_summaries: bool,
pub supports_extended_thinking: bool,
pub min_thinking_budget: Option<u32>,
}
pub struct ModelInfo {
pub model_id: &'static str,
pub display_name: &'static str,
@ -18,11 +26,28 @@ impl Provider {
Provider::Anthropic => "anthropic",
}
}
pub fn capabilities(&self) -> ProviderCapabilities {
match self {
Provider::OpenAI => ProviderCapabilities {
supports_web_search: true,
supports_reasoning_summaries: true,
supports_extended_thinking: false,
min_thinking_budget: None,
},
Provider::Anthropic => ProviderCapabilities {
supports_web_search: true,
supports_reasoning_summaries: false,
supports_extended_thinking: true,
min_thinking_budget: Some(1024),
},
}
}
}
pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
let mut models = HashMap::new();
models.insert(
Provider::OpenAI,
vec![
@ -37,7 +62,7 @@ pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
"o3-mini",
],
);
models.insert(
Provider::Anthropic,
vec![
@ -48,29 +73,70 @@ pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
"claude-3-haiku-20240307",
],
);
models
}
pub fn get_model_info_list() -> Vec<ModelInfo> {
vec![
// OpenAI models
ModelInfo { model_id: "gpt-4.1", display_name: "GPT-4.1" },
ModelInfo { model_id: "gpt-4.1-mini", display_name: "GPT-4.1 Mini" },
ModelInfo { model_id: "gpt-4o", display_name: "GPT-4o" },
ModelInfo { model_id: "gpt-5", display_name: "GPT-5" },
ModelInfo { model_id: "gpt-5-chat-latest", display_name: "GPT-5 Chat Latest" },
ModelInfo { model_id: "o1", display_name: "o1" },
ModelInfo { model_id: "o3", display_name: "o3" },
ModelInfo { model_id: "o4-mini", display_name: "o4 Mini" },
ModelInfo { model_id: "o3-mini", display_name: "o3 Mini" },
ModelInfo {
model_id: "gpt-4.1",
display_name: "GPT-4.1",
},
ModelInfo {
model_id: "gpt-4.1-mini",
display_name: "GPT-4.1 Mini",
},
ModelInfo {
model_id: "gpt-4o",
display_name: "GPT-4o",
},
ModelInfo {
model_id: "gpt-5",
display_name: "GPT-5",
},
ModelInfo {
model_id: "gpt-5-chat-latest",
display_name: "GPT-5 Chat Latest",
},
ModelInfo {
model_id: "o1",
display_name: "o1",
},
ModelInfo {
model_id: "o3",
display_name: "o3",
},
ModelInfo {
model_id: "o4-mini",
display_name: "o4 Mini",
},
ModelInfo {
model_id: "o3-mini",
display_name: "o3 Mini",
},
// Anthropic models with friendly names
ModelInfo { model_id: "claude-opus-4-1-20250805", display_name: "Claude Opus 4.1" },
ModelInfo { model_id: "claude-sonnet-4-20250514", display_name: "Claude Sonnet 4.0" },
ModelInfo { model_id: "claude-3-7-sonnet-20250219", display_name: "Claude 3.7 Sonnet" },
ModelInfo { model_id: "claude-3-5-haiku-20241022", display_name: "Claude 3.5 Haiku" },
ModelInfo { model_id: "claude-3-haiku-20240307", display_name: "Claude 3.0 Haiku" },
ModelInfo {
model_id: "claude-opus-4-1-20250805",
display_name: "Claude Opus 4.1",
},
ModelInfo {
model_id: "claude-sonnet-4-20250514",
display_name: "Claude Sonnet 4.0",
},
ModelInfo {
model_id: "claude-3-7-sonnet-20250219",
display_name: "Claude 3.7 Sonnet",
},
ModelInfo {
model_id: "claude-3-5-haiku-20241022",
display_name: "Claude 3.5 Haiku",
},
ModelInfo {
model_id: "claude-3-haiku-20240307",
display_name: "Claude 3.0 Haiku",
},
]
}
@ -99,13 +165,13 @@ pub fn get_all_models() -> Vec<&'static str> {
pub fn get_provider_for_model(model: &str) -> Provider {
let supported = get_supported_models();
for (provider, models) in supported {
if models.contains(&model) {
return provider;
}
}
Provider::OpenAI // default fallback
}
@ -126,18 +192,18 @@ mod tests {
#[test]
fn test_get_supported_models() {
let models = get_supported_models();
// Check that both providers are present
assert!(models.contains_key(&Provider::OpenAI));
assert!(models.contains_key(&Provider::Anthropic));
// Check OpenAI models
let openai_models = models.get(&Provider::OpenAI).unwrap();
assert!(openai_models.contains(&"gpt-5"));
assert!(openai_models.contains(&"gpt-4o"));
assert!(openai_models.contains(&"o1"));
assert!(openai_models.len() > 0);
// Check Anthropic models
let anthropic_models = models.get(&Provider::Anthropic).unwrap();
assert!(anthropic_models.contains(&"claude-sonnet-4-20250514"));
@ -148,15 +214,17 @@ mod tests {
#[test]
fn test_get_model_info_list() {
let model_infos = get_model_info_list();
assert!(model_infos.len() > 0);
// Check some specific models
let gpt5_info = model_infos.iter().find(|info| info.model_id == "gpt-5");
assert!(gpt5_info.is_some());
assert_eq!(gpt5_info.unwrap().display_name, "GPT-5");
let claude_info = model_infos.iter().find(|info| info.model_id == "claude-sonnet-4-20250514");
let claude_info = model_infos
.iter()
.find(|info| info.model_id == "claude-sonnet-4-20250514");
assert!(claude_info.is_some());
assert_eq!(claude_info.unwrap().display_name, "Claude Sonnet 4.0");
}
@ -165,20 +233,32 @@ mod tests {
fn test_get_display_name_for_model() {
// Test known models
assert_eq!(get_display_name_for_model("gpt-5"), "GPT-5");
assert_eq!(get_display_name_for_model("claude-sonnet-4-20250514"), "Claude Sonnet 4.0");
assert_eq!(
get_display_name_for_model("claude-sonnet-4-20250514"),
"Claude Sonnet 4.0"
);
assert_eq!(get_display_name_for_model("o1"), "o1");
// Test unknown model (should return the model_id itself)
assert_eq!(get_display_name_for_model("unknown-model-123"), "unknown-model-123");
assert_eq!(
get_display_name_for_model("unknown-model-123"),
"unknown-model-123"
);
}
#[test]
fn test_get_model_id_from_display_name() {
// Test known display names
assert_eq!(get_model_id_from_display_name("GPT-5"), Some("gpt-5".to_string()));
assert_eq!(get_model_id_from_display_name("Claude Sonnet 4.0"), Some("claude-sonnet-4-20250514".to_string()));
assert_eq!(
get_model_id_from_display_name("GPT-5"),
Some("gpt-5".to_string())
);
assert_eq!(
get_model_id_from_display_name("Claude Sonnet 4.0"),
Some("claude-sonnet-4-20250514".to_string())
);
assert_eq!(get_model_id_from_display_name("o1"), Some("o1".to_string()));
// Test unknown display name
assert_eq!(get_model_id_from_display_name("Unknown Model"), None);
}
@ -186,9 +266,9 @@ mod tests {
#[test]
fn test_get_all_models() {
let all_models = get_all_models();
assert!(all_models.len() > 0);
// Check that models from both providers are included
assert!(all_models.contains(&"gpt-5"));
assert!(all_models.contains(&"gpt-4o"));
@ -203,14 +283,26 @@ mod tests {
assert_eq!(get_provider_for_model("gpt-4o"), Provider::OpenAI);
assert_eq!(get_provider_for_model("o1"), Provider::OpenAI);
assert_eq!(get_provider_for_model("gpt-4.1"), Provider::OpenAI);
// Test Anthropic models
assert_eq!(get_provider_for_model("claude-sonnet-4-20250514"), Provider::Anthropic);
assert_eq!(get_provider_for_model("claude-3-5-haiku-20241022"), Provider::Anthropic);
assert_eq!(get_provider_for_model("claude-opus-4-1-20250805"), Provider::Anthropic);
assert_eq!(
get_provider_for_model("claude-sonnet-4-20250514"),
Provider::Anthropic
);
assert_eq!(
get_provider_for_model("claude-3-5-haiku-20241022"),
Provider::Anthropic
);
assert_eq!(
get_provider_for_model("claude-opus-4-1-20250805"),
Provider::Anthropic
);
// Test unknown model (should default to OpenAI)
assert_eq!(get_provider_for_model("unknown-model-123"), Provider::OpenAI);
assert_eq!(
get_provider_for_model("unknown-model-123"),
Provider::OpenAI
);
}
#[test]
@ -220,7 +312,7 @@ mod tests {
assert!(is_model_supported("claude-sonnet-4-20250514"));
assert!(is_model_supported("o1"));
assert!(is_model_supported("claude-3-haiku-20240307"));
// Test unsupported models
assert!(!is_model_supported("unsupported-model"));
assert!(!is_model_supported("gpt-6"));
@ -238,11 +330,11 @@ mod tests {
#[test]
fn test_provider_hash() {
use std::collections::HashMap;
let mut map = HashMap::new();
map.insert(Provider::OpenAI, "openai_value");
map.insert(Provider::Anthropic, "anthropic_value");
assert_eq!(map.get(&Provider::OpenAI), Some(&"openai_value"));
assert_eq!(map.get(&Provider::Anthropic), Some(&"anthropic_value"));
}
@ -253,7 +345,7 @@ mod tests {
model_id: "test-model",
display_name: "Test Model",
};
assert_eq!(model_info.model_id, "test-model");
assert_eq!(model_info.display_name, "Test Model");
}
@ -261,11 +353,14 @@ mod tests {
#[test]
fn test_all_model_infos_have_valid_display_names() {
let model_infos = get_model_info_list();
for info in model_infos {
assert!(!info.model_id.is_empty(), "Model ID should not be empty");
assert!(!info.display_name.is_empty(), "Display name should not be empty");
assert!(
!info.display_name.is_empty(),
"Display name should not be empty"
);
// Display name should be different from model_id for most cases
// (though some might be the same like "o1")
assert!(info.display_name.len() > 0);
@ -277,23 +372,28 @@ mod tests {
let supported_models = get_supported_models();
let all_models = get_all_models();
let model_infos = get_model_info_list();
// All models in get_all_models should be in supported_models
for model in &all_models {
let found = supported_models.values().any(|models| models.contains(model));
let found = supported_models
.values()
.any(|models| models.contains(model));
assert!(found, "Model {} not found in supported_models", model);
}
// All models in model_infos should be in all_models
for info in &model_infos {
assert!(all_models.contains(&info.model_id),
"Model {} from model_infos not found in all_models", info.model_id);
assert!(
all_models.contains(&info.model_id),
"Model {} from model_infos not found in all_models",
info.model_id
);
}
// All models in all_models should have corresponding model_info
for model in &all_models {
let found = model_infos.iter().any(|info| info.model_id == *model);
assert!(found, "Model {} not found in model_infos", model);
}
}
}
}