kiln_ai.adapters.ml_model_list

import os
from dataclasses import dataclass
from enum import Enum
from os import getenv
from typing import Dict, List, NoReturn

import httpx
from langchain_aws import ChatBedrockConverse
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from ..utils.config import Config


class ModelProviderName(str, Enum):
    openai = "openai"
    groq = "groq"
    amazon_bedrock = "amazon_bedrock"
    ollama = "ollama"
    openrouter = "openrouter"


class ModelFamily(str, Enum):
    gpt = "gpt"
    llama = "llama"
    phi = "phi"
    mistral = "mistral"
    gemma = "gemma"


# Where models have instruct and raw versions, instruct is default and raw is specified
class ModelName(str, Enum):
    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    gemma_2_2b = "gemma_2_2b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"
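
Because these enums subclass str, members compare equal to (and hash like) their raw string values; langchain_model_from below relies on this to validate plain-string input. A minimal sketch:

# str-subclass enum members behave like plain strings
assert ModelName.gpt_4o == "gpt_4o"
assert "llama_3_1_8b" in ModelName.__members__
assert ModelName("gpt_4o") is ModelName.gpt_4o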

class KilnModelProvider(BaseModel):
    name: ModelProviderName
    # Allow overriding the model level setting
    supports_structured_output: bool = True
    provider_options: Dict = {}


class KilnModel(BaseModel):
    family: str
    name: str
    friendly_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True
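
For illustration, a hypothetical catalog entry built from these models: provider_options carries the provider-specific model identifier, and supports_structured_output can be overridden per provider or per model. The model ID shown ("gemma2:9b") mirrors the commented-out Ollama entry below.

example_model = KilnModel(
    family=ModelFamily.gemma,
    name=ModelName.gemma_2_9b,
    friendly_name="Gemma 2 9B",
    supports_structured_output=False,
    providers=[
        KilnModelProvider(
            name=ModelProviderName.ollama,
            provider_options={"model": "gemma2:9b"},
        ),
    ],
)
print(example_model.model_dump())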

built_in_models: List[KilnModel] = [
    # GPT 4o Mini
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o_mini,
        friendly_name="GPT 4o Mini",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o-mini"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-mini"},
            ),
        ],
    ),
    # GPT 4o
    KilnModel(
        family=ModelFamily.gpt,
        name=ModelName.gpt_4o,
        friendly_name="GPT 4o",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openai,
                provider_options={"model": "gpt-4o"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "openai/gpt-4o-2024-08-06"},
            ),
        ],
    ),
    # Llama 3.1 8b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_8b,
        friendly_name="Llama 3.1 8B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-8b-instant"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "meta.llama3-1-8b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in us-west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "llama3.1"},  # 8b is the default tag
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
            ),
        ],
    ),
    # Llama 3.1 70b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_70b,
        friendly_name="Llama 3.1 70B",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.groq,
                provider_options={"model": "llama-3.1-70b-versatile"},
            ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                # TODO: this should work, but a bug in the Bedrock response schema breaks structured output
                supports_structured_output=False,
                provider_options={
                    "model": "meta.llama3-1-70b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in us-west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={"model": "llama3.1:70b"},
            # ),
        ],
    ),
    # Llama 3.1 405b
    KilnModel(
        family=ModelFamily.llama,
        name=ModelName.llama_3_1_405b,
        friendly_name="Llama 3.1 405B",
        providers=[
            # TODO: bring back when groq does: https://console.groq.com/docs/models
            # KilnModelProvider(
            #     name=ModelProviderName.groq,
            #     provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
            # ),
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "meta.llama3-1-405b-instruct-v1:0",
                    "region_name": "us-west-2",  # Llama 3.1 only in us-west-2
                },
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={"model": "llama3.1:405b"},
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
            ),
        ],
    ),
    # Mistral Nemo
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_nemo,
        friendly_name="Mistral Nemo",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-nemo"},
            ),
        ],
    ),
    # Mistral Large
    KilnModel(
        family=ModelFamily.mistral,
        name=ModelName.mistral_large,
        friendly_name="Mistral Large",
        providers=[
            KilnModelProvider(
                name=ModelProviderName.amazon_bedrock,
                provider_options={
                    "model": "mistral.mistral-large-2407-v1:0",
                    "region_name": "us-west-2",  # only in us-west-2
                },
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "mistralai/mistral-large"},
            ),
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={"model": "mistral-large"},
            # ),
        ],
    ),
    # Phi 3.5
    KilnModel(
        family=ModelFamily.phi,
        name=ModelName.phi_3_5,
        friendly_name="Phi 3.5",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={"model": "phi3.5"},
            ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
            ),
        ],
    ),
    # Gemma 2 2.6b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_2b,
        friendly_name="Gemma 2 2B",
        supports_structured_output=False,
        providers=[
            KilnModelProvider(
                name=ModelProviderName.ollama,
                provider_options={
                    "model": "gemma2:2b",
                },
            ),
        ],
    ),
    # Gemma 2 9b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_9b,
        friendly_name="Gemma 2 9B",
        supports_structured_output=False,
        providers=[
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={
            #         "model": "gemma2:9b",
            #     },
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "google/gemma-2-9b-it"},
            ),
        ],
    ),
    # Gemma 2 27b
    KilnModel(
        family=ModelFamily.gemma,
        name=ModelName.gemma_2_27b,
        friendly_name="Gemma 2 27B",
        supports_structured_output=False,
        providers=[
            # TODO: enable once tests update to check if model is available
            # KilnModelProvider(
            #     name=ModelProviderName.ollama,
            #     provider_options={
            #         "model": "gemma2:27b",
            #     },
            # ),
            KilnModelProvider(
                name=ModelProviderName.openrouter,
                provider_options={"model": "google/gemma-2-27b-it"},
            ),
        ],
    ),
]
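
Entries can be looked up by ModelName, mirroring the lookup langchain_model_from performs below. A small sketch:

model = next(
    filter(lambda m: m.name == ModelName.llama_3_1_8b, built_in_models), None
)
assert model is not None
print(model.friendly_name)  # Llama 3.1 8B
print([provider.name.value for provider in model.providers])
# ['groq', 'amazon_bedrock', 'ollama', 'openrouter']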

def provider_name_from_id(id: str) -> str:
    if id in ModelProviderName.__members__:
        enum_id = ModelProviderName(id)
        match enum_id:
            case ModelProviderName.amazon_bedrock:
                return "Amazon Bedrock"
            case ModelProviderName.openrouter:
                return "OpenRouter"
            case ModelProviderName.groq:
                return "Groq"
            case ModelProviderName.ollama:
                return "Ollama"
            case ModelProviderName.openai:
                return "OpenAI"
            case _:
                # triggers a pyright warning if a case is missed
                raise_exhaustive_error(enum_id)

    return "Unknown provider: " + id
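
Example outputs ("anthropic" here is just an arbitrary unrecognized id):

print(provider_name_from_id("groq"))            # Groq
print(provider_name_from_id("amazon_bedrock"))  # Amazon Bedrock
print(provider_name_from_id("anthropic"))       # Unknown provider: anthropic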

def raise_exhaustive_error(value: NoReturn) -> NoReturn:
    raise ValueError(f"Unhandled enum value: {value}")


@dataclass
class ModelProviderWarning:
    required_config_keys: List[str]
    message: str


provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
    ModelProviderName.amazon_bedrock: ModelProviderWarning(
        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
    ),
    ModelProviderName.openrouter: ModelProviderWarning(
        required_config_keys=["open_router_api_key"],
        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
    ),
    ModelProviderName.groq: ModelProviderWarning(
        required_config_keys=["groq_api_key"],
        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
    ),
    ModelProviderName.openai: ModelProviderWarning(
        required_config_keys=["open_ai_api_key"],
        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
    ),
}


def get_config_value(key: str):
    try:
        # getattr() resolves normal attributes as well as Config.__getattr__;
        # calling __getattr__ directly would skip regular attribute lookup
        return getattr(Config.shared(), key)
    except AttributeError:
        return None


def check_provider_warnings(provider_name: ModelProviderName):
    warning_check = provider_warnings.get(provider_name)
    if warning_check is None:
        return
    for key in warning_check.required_config_keys:
        if get_config_value(key) is None:
            raise ValueError(warning_check.message)
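
A sketch of how a caller might surface the warning, assuming no Groq API key is configured:

try:
    check_provider_warnings(ModelProviderName.groq)
except ValueError as e:
    # "Attempted to use Groq without an API key set. ..."
    print(e)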

def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
    if name not in ModelName.__members__:
        raise ValueError(f"Invalid name: {name}")

    # Select the model from built_in_models using the name.
    # Pass a default of None: a bare next() would raise StopIteration
    # and the None check below would never run.
    model = next(filter(lambda m: m.name == name, built_in_models), None)
    if model is None:
        raise ValueError(f"Model {name} not found")

    # If a provider is provided, select it from the model's providers list
    provider: KilnModelProvider | None = None
    if model.providers is None or len(model.providers) == 0:
        raise ValueError(f"Model {name} has no providers")
    elif provider_name is None:
        # TODO: priority order
        provider = model.providers[0]
    else:
        provider = next(
            filter(lambda p: p.name == provider_name, model.providers), None
        )
    if provider is None:
        raise ValueError(f"Provider {provider_name} not found for model {name}")

    check_provider_warnings(provider.name)

    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these directly, so set env vars as a workaround
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.ollama:
        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://kiln-ai.com/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
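
Typical usage, assuming a local Ollama server is running (other providers need the matching API key in Config). Both enum members and their plain string values are accepted:

model = langchain_model_from(ModelName.llama_3_1_8b, ModelProviderName.ollama)
response = model.invoke("Reply with one word: hello")
print(response.content)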

def ollama_base_url():
    env_base_url = os.getenv("OLLAMA_BASE_URL")
    if env_base_url is not None:
        return env_base_url
    return "http://localhost:11434"


async def ollama_online():
    try:
        # use the async client so this check doesn't block the event loop
        async with httpx.AsyncClient() as client:
            await client.get(ollama_base_url() + "/api/tags")
    except httpx.RequestError:
        return False
    return True
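
A minimal connectivity check using the helpers above:

import asyncio

if asyncio.run(ollama_online()):
    print("Ollama is reachable at " + ollama_base_url())
else:
    print("Ollama is offline or unreachable")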