kiln_ai.adapters.langchain_adapters
import os
from os import getenv
from typing import Any, Dict

from langchain_aws import ChatBedrockConverse
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_base_url,
    ollama_model_installed,
)
from kiln_ai.utils.config import Config

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import KilnModelProvider, ModelProviderName
from .provider_tools import kiln_model_provider_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


class LangchainAdapter(BaseAdapter):
    _model: LangChainModelType | None = None

    def __init__(
        self,
        kiln_task: datamodel.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
    ):
        super().__init__(kiln_task, prompt_builder=prompt_builder)
        if custom_model is not None:
            self._model = custom_model

            # Attempt to infer model provider and name from custom model
            self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
            self.model_name = "custom.langchain:unknown_model"
            if hasattr(custom_model, "model_name") and isinstance(
                getattr(custom_model, "model_name"), str
            ):
                self.model_name = "custom.langchain:" + getattr(
                    custom_model, "model_name"
                )
            if hasattr(custom_model, "model") and isinstance(
                getattr(custom_model, "model"), str
            ):
                self.model_name = "custom.langchain:" + getattr(custom_model, "model")
        elif model_name is not None:
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name and provider must be provided if custom_model is not provided"
            )

    async def model(self) -> LangChainModelType:
        # cached model
        if self._model:
            return self._model

        self._model = await langchain_model_from(self.model_name, self.model_provider)

        if self.has_structured_output():
            if not hasattr(self._model, "with_structured_output") or not callable(
                getattr(self._model, "with_structured_output")
            ):
                raise ValueError(
                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                )
            # Langchain expects title/description to be at top level, on top of json schema
            output_schema = self.kiln_task.output_schema()
            if output_schema is None:
                raise ValueError(
                    f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                )
            output_schema["title"] = "task_response"
            output_schema["description"] = "A response from the task"
            with_structured_output_options = await get_structured_output_options(
                self.model_name, self.model_provider
            )
            self._model = self._model.with_structured_output(
                output_schema,
                include_raw=True,
                **with_structured_output_options,
            )
        return self._model

    async def _run(self, input: Dict | str) -> RunOutput:
        model = await self.model()
        chain = model
        intermediate_outputs = {}

        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=user_msg),
        ]

        # COT with structured output
        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
        if cot_prompt and self.has_structured_output():
            # Base model (without structured output) used for COT message
            base_model = await langchain_model_from(
                self.model_name, self.model_provider
            )
            messages.append(
                SystemMessage(content=cot_prompt),
            )

            cot_messages = [*messages]
            cot_response = await base_model.ainvoke(cot_messages)
            intermediate_outputs["chain_of_thought"] = cot_response.content
            messages.append(AIMessage(content=cot_response.content))
            messages.append(
                SystemMessage(content="Considering the above, return a final result.")
            )
        elif cot_prompt:
            messages.append(SystemMessage(content=cot_prompt))

        response = await chain.ainvoke(messages)

        if self.has_structured_output():
            if (
                not isinstance(response, dict)
                or "parsed" not in response
                or not isinstance(response["parsed"], dict)
            ):
                raise RuntimeError(f"structured response not returned: {response}")
            structured_response = response["parsed"]
            return RunOutput(
                output=self._munge_response(structured_response),
                intermediate_outputs=intermediate_outputs,
            )
        else:
            if not isinstance(response, BaseMessage):
                raise RuntimeError(f"response is not a BaseMessage: {response}")
            text_content = response.content
            if not isinstance(text_content, str):
                raise RuntimeError(f"response is not a string: {text_content}")
            return RunOutput(
                output=text_content,
                intermediate_outputs=intermediate_outputs,
            )

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            model_name=self.model_name,
            model_provider=self.model_provider,
            adapter_name="kiln_langchain_adapter",
            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
        )

    def _munge_response(self, response: Dict) -> Dict:
        # Mistral Large tool calling format is a bit different. Convert to standard format.
        if (
            "name" in response
            and response["name"] == "task_response"
            and "arguments" in response
        ):
            return response["arguments"]
        return response


async def get_structured_output_options(
    model_name: str, model_provider: str
) -> Dict[str, Any]:
    finetune_provider = await kiln_model_provider_from(model_name, model_provider)
    if finetune_provider and finetune_provider.adapter_options.get("langchain"):
        return finetune_provider.adapter_options["langchain"].get(
            "with_structured_output_options", {}
        )
    return {}


async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    provider = await kiln_model_provider_from(name, provider_name)
    return await langchain_model_from_provider(provider, name)


async def langchain_model_from_provider(
    provider: KilnModelProvider, model_name: str
) -> BaseChatModel:
    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.fireworks_ai:
        api_key = Config.shared().fireworks_api_key
        return ChatFireworks(**provider.provider_options, api_key=api_key)
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_installed(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {model_name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
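One detail of the listing worth calling out: when a chain-of-thought prompt is configured and the task has structured output, _run makes two model calls. The sketch below reconstructs the message sequence it builds; the placeholder strings stand in for the real prompts produced by the prompt builder and are not values from this module.

# Illustration only: the conversation _run assembles for the
# "chain of thought + structured output" path.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="<system prompt from build_prompt()>"),
    HumanMessage(content="<user message from build_user_message(input)>"),
    SystemMessage(content="<chain_of_thought_prompt()>"),
    # First call: the plain (unstructured) base model answers; its reply is
    # stored as the "chain_of_thought" intermediate output and appended here.
    AIMessage(content="<free-form reasoning from the base model>"),
    SystemMessage(content="Considering the above, return a final result."),
]
# Second call: the structured-output model is invoked on `messages` to produce
# the final, schema-conforming response.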
LangChainModelType =
typing.Union[
    langchain_core.language_models.chat_models.BaseChatModel,
    langchain_core.runnables.base.Runnable[
        typing.Union[
            langchain_core.prompt_values.PromptValue,
            str,
            collections.abc.Sequence[
                typing.Union[
                    langchain_core.messages.base.BaseMessage,
                    list[str],
                    tuple[str, str],
                    str,
                    dict[str, typing.Any],
                ]
            ],
        ],
        typing.Union[typing.Dict, pydantic.main.BaseModel],
    ],
]
class LangchainAdapter(BaseAdapter):
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs
LangchainAdapter(
    kiln_task: kiln_ai.datamodel.Task,
    custom_model: langchain_core.language_models.chat_models.BaseChatModel | None = None,
    model_name: str | None = None,
    provider: str | None = None,
    prompt_builder: kiln_ai.adapters.prompt_builders.BasePromptBuilder | None = None
)
    def __init__(
        self,
        kiln_task: datamodel.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
    ):
        super().__init__(kiln_task, prompt_builder=prompt_builder)
        if custom_model is not None:
            self._model = custom_model

            # Attempt to infer model provider and name from custom model
            self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
            self.model_name = "custom.langchain:unknown_model"
            if hasattr(custom_model, "model_name") and isinstance(
                getattr(custom_model, "model_name"), str
            ):
                self.model_name = "custom.langchain:" + getattr(
                    custom_model, "model_name"
                )
            if hasattr(custom_model, "model") and isinstance(
                getattr(custom_model, "model"), str
            ):
                self.model_name = "custom.langchain:" + getattr(custom_model, "model")
        elif model_name is not None:
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name and provider must be provided if custom_model is not provided"
            )
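A minimal construction sketch, assuming a Task can be built from just a name and an instruction and that an Ollama server is running locally; the model and provider names are illustrative, not taken from Kiln's model list.

import kiln_ai.datamodel as datamodel
from langchain_ollama import ChatOllama

from kiln_ai.adapters.langchain_adapters import LangchainAdapter

# Assumed task; real projects would typically load an existing Task from disk.
my_task = datamodel.Task(name="summarize", instruction="Summarize the input text.")

# Option 1: let Kiln resolve the model from its model list by name + provider.
adapter = LangchainAdapter(
    kiln_task=my_task,
    model_name="llama_3_1_8b",  # illustrative Kiln model name
    provider="ollama",          # illustrative provider name
)

# Option 2: bring your own LangChain chat model; the adapter infers
# model_name / model_provider from the object as shown in __init__ above.
custom = ChatOllama(model="llama3.1:8b", base_url="http://localhost:11434")
adapter = LangchainAdapter(kiln_task=my_task, custom_model=custom)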
async def model(
    self
) -> Union[
    langchain_core.language_models.chat_models.BaseChatModel,
    langchain_core.runnables.base.Runnable[
        Union[
            langchain_core.prompt_values.PromptValue,
            str,
            Sequence[
                Union[
                    langchain_core.messages.base.BaseMessage,
                    list[str],
                    tuple[str, str],
                    str,
                    dict[str, Any],
                ]
            ],
        ],
        Union[Dict, pydantic.main.BaseModel],
    ],
]:
    async def model(self) -> LangChainModelType:
        # cached model
        if self._model:
            return self._model

        self._model = await langchain_model_from(self.model_name, self.model_provider)

        if self.has_structured_output():
            if not hasattr(self._model, "with_structured_output") or not callable(
                getattr(self._model, "with_structured_output")
            ):
                raise ValueError(
                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                )
            # Langchain expects title/description to be at top level, on top of json schema
            output_schema = self.kiln_task.output_schema()
            if output_schema is None:
                raise ValueError(
                    f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                )
            output_schema["title"] = "task_response"
            output_schema["description"] = "A response from the task"
            with_structured_output_options = await get_structured_output_options(
                self.model_name, self.model_provider
            )
            self._model = self._model.with_structured_output(
                output_schema,
                include_raw=True,
                **with_structured_output_options,
            )
        return self._model
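A small sketch of what this wrapping does: the task's JSON schema gets a top-level title and description, and because with_structured_output is called with include_raw=True, LangChain returns a dict rather than a message. The schema below is illustrative, not from any Kiln task.

# Illustrative output schema for a task; only the "title"/"description" keys
# added below mirror what model() does.
output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
output_schema["title"] = "task_response"
output_schema["description"] = "A response from the task"

# With include_raw=True, the wrapped model's ainvoke() returns a dict shaped
# roughly like:
#   {"raw": <AIMessage>, "parsed": {"answer": "..."}, "parsing_error": None}
# which is why _run reads response["parsed"] instead of response.content.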
async def get_structured_output_options(model_name: str, model_provider: str) -> Dict[str, Any]:
async def get_structured_output_options(
    model_name: str, model_provider: str
) -> Dict[str, Any]:
    finetune_provider = await kiln_model_provider_from(model_name, model_provider)
    if finetune_provider and finetune_provider.adapter_options.get("langchain"):
        return finetune_provider.adapter_options["langchain"].get(
            "with_structured_output_options", {}
        )
    return {}
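The options come from the provider's adapter_options mapping. An illustrative shape is shown below; the "method" value is only an example of a with_structured_output keyword argument, not a setting used by any particular Kiln model.

# Hypothetical KilnModelProvider.adapter_options entry this helper would read.
adapter_options = {
    "langchain": {
        "with_structured_output_options": {"method": "json_mode"},
    }
}

# get_structured_output_options(...) returns {"method": "json_mode"} for such a
# provider, and {} when no langchain-specific options are configured.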
async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> langchain_core.language_models.chat_models.BaseChatModel:
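A minimal usage sketch, assuming the illustrative model name "llama_3_1_8b" is registered in Kiln's model list under the "ollama" provider and is served by a running Ollama instance.

import asyncio

from kiln_ai.adapters.langchain_adapters import langchain_model_from


async def main() -> None:
    # Resolve a Kiln model name + provider to a ready-to-use LangChain chat model.
    model = await langchain_model_from("llama_3_1_8b", "ollama")
    reply = await model.ainvoke("Say hello in one word.")
    print(reply.content)


if __name__ == "__main__":
    asyncio.run(main())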
async def langchain_model_from_provider(
    provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, model_name: str
) -> langchain_core.language_models.chat_models.BaseChatModel:
async def langchain_model_from_provider(
    provider: KilnModelProvider, model_name: str
) -> BaseChatModel:
    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.fireworks_ai:
        api_key = Config.shared().fireworks_api_key
        return ChatFireworks(**provider.provider_options, api_key=api_key)
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_installed(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {model_name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
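To make the Ollama branch concrete, a hedged sketch of a provider record follows. The field values are assumptions (real entries live in kiln_ai.adapters.ml_model_list), and any KilnModelProvider fields not shown are assumed to have defaults.

import asyncio

from kiln_ai.adapters.langchain_adapters import langchain_model_from_provider
from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName


async def main() -> None:
    # Hypothetical provider entry: "model" is tried first, then each alias,
    # against the list of models the local Ollama server reports.
    provider = KilnModelProvider(
        name=ModelProviderName.ollama,
        provider_options={
            "model": "llama3.1",
            "model_aliases": ["llama3.1:8b", "llama3.1:latest"],
        },
    )
    chat_model = await langchain_model_from_provider(provider, "llama_3_1_8b")
    print(type(chat_model).__name__)  # ChatOllama when one of the names is installed


if __name__ == "__main__":
    asyncio.run(main())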