kiln_ai.adapters.langchain_adapters

import os
from os import getenv
from typing import Any, Dict

from langchain_aws import ChatBedrockConverse
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_base_url,
    ollama_model_installed,
)
from kiln_ai.utils.config import Config

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import KilnModelProvider, ModelProviderName
from .provider_tools import kiln_model_provider_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


class LangchainAdapter(BaseAdapter):
    _model: LangChainModelType | None = None

    def __init__(
        self,
        kiln_task: datamodel.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
    ):
        super().__init__(kiln_task, prompt_builder=prompt_builder)
        if custom_model is not None:
            self._model = custom_model

            # Attempt to infer model provider and name from custom model
            self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
            self.model_name = "custom.langchain:unknown_model"
            if hasattr(custom_model, "model_name") and isinstance(
                getattr(custom_model, "model_name"), str
            ):
                self.model_name = "custom.langchain:" + getattr(
                    custom_model, "model_name"
                )
            if hasattr(custom_model, "model") and isinstance(
                getattr(custom_model, "model"), str
            ):
                self.model_name = "custom.langchain:" + getattr(custom_model, "model")
        elif model_name is not None:
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name and provider must be provided if custom_model is not provided"
            )

    async def model(self) -> LangChainModelType:
        # cached model
        if self._model:
            return self._model

        self._model = await langchain_model_from(self.model_name, self.model_provider)

        if self.has_structured_output():
            if not hasattr(self._model, "with_structured_output") or not callable(
                getattr(self._model, "with_structured_output")
            ):
                raise ValueError(
                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                )
            # Langchain expects title/description to be at top level, on top of json schema
            output_schema = self.kiln_task.output_schema()
            if output_schema is None:
                raise ValueError(
                    f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                )
            output_schema["title"] = "task_response"
            output_schema["description"] = "A response from the task"
            with_structured_output_options = await get_structured_output_options(
                self.model_name, self.model_provider
            )
            self._model = self._model.with_structured_output(
                output_schema,
                include_raw=True,
                **with_structured_output_options,
            )
        return self._model

    async def _run(self, input: Dict | str) -> RunOutput:
        model = await self.model()
        chain = model
        intermediate_outputs = {}

        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=user_msg),
        ]

        # COT with structured output
        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
        if cot_prompt and self.has_structured_output():
            # Base model (without structured output) used for COT message
            base_model = await langchain_model_from(
                self.model_name, self.model_provider
            )
            messages.append(
                SystemMessage(content=cot_prompt),
            )

            cot_messages = [*messages]
            cot_response = await base_model.ainvoke(cot_messages)
            intermediate_outputs["chain_of_thought"] = cot_response.content
            messages.append(AIMessage(content=cot_response.content))
            messages.append(
                SystemMessage(content="Considering the above, return a final result.")
            )
        elif cot_prompt:
            messages.append(SystemMessage(content=cot_prompt))

        response = await chain.ainvoke(messages)

        if self.has_structured_output():
            if (
                not isinstance(response, dict)
                or "parsed" not in response
                or not isinstance(response["parsed"], dict)
            ):
                raise RuntimeError(f"structured response not returned: {response}")
            structured_response = response["parsed"]
            return RunOutput(
                output=self._munge_response(structured_response),
                intermediate_outputs=intermediate_outputs,
            )
        else:
            if not isinstance(response, BaseMessage):
                raise RuntimeError(f"response is not a BaseMessage: {response}")
            text_content = response.content
            if not isinstance(text_content, str):
                raise RuntimeError(f"response is not a string: {text_content}")
            return RunOutput(
                output=text_content,
                intermediate_outputs=intermediate_outputs,
            )

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            model_name=self.model_name,
            model_provider=self.model_provider,
            adapter_name="kiln_langchain_adapter",
            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
        )

    def _munge_response(self, response: Dict) -> Dict:
        # Mistral Large tool calling format is a bit different. Convert to standard format.
        if (
            "name" in response
            and response["name"] == "task_response"
            and "arguments" in response
        ):
            return response["arguments"]
        return response


async def get_structured_output_options(
    model_name: str, model_provider: str
) -> Dict[str, Any]:
    finetune_provider = await kiln_model_provider_from(model_name, model_provider)
    if finetune_provider and finetune_provider.adapter_options.get("langchain"):
        return finetune_provider.adapter_options["langchain"].get(
            "with_structured_output_options", {}
        )
    return {}


async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    provider = await kiln_model_provider_from(name, provider_name)
    return await langchain_model_from_provider(provider, name)


async def langchain_model_from_provider(
    provider: KilnModelProvider, model_name: str
) -> BaseChatModel:
    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.fireworks_ai:
        api_key = Config.shared().fireworks_api_key
        return ChatFireworks(**provider.provider_options, api_key=api_key)
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_installed(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {model_name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
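A minimal usage sketch of the module (not taken from the library's own docs): it assumes a Task already loaded from a Kiln project, and the "llama_3_1_8b"/"ollama" identifiers are placeholders for entries defined in ml_model_list. Resolving an Ollama-backed model also requires a running Ollama server.

import asyncio

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.langchain_adapters import LangchainAdapter


async def inspect_adapter(my_task: datamodel.Task) -> None:
    # "llama_3_1_8b" and "ollama" are illustrative; use a model/provider pair
    # defined in kiln_ai.adapters.ml_model_list.
    adapter = LangchainAdapter(my_task, model_name="llama_3_1_8b", provider="ollama")

    # adapter_info() describes how runs produced by this adapter will be labelled.
    info = adapter.adapter_info()
    print(info.model_name, info.model_provider, info.adapter_name)

    # model() resolves (and caches) the underlying LangChain model; tasks with an
    # output JSON schema get it wrapped by with_structured_output(..., include_raw=True).
    model = await adapter.model()
    print(type(model).__name__)

# To run: asyncio.run(inspect_adapter(my_task)) with a Task loaded from a Kiln project.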
class LangchainAdapter(kiln_ai.adapters.base_adapter.BaseAdapter):

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs

LangchainAdapter(kiln_task: kiln_ai.datamodel.Task, custom_model: langchain_core.language_models.chat_models.BaseChatModel | None = None, model_name: str | None = None, provider: str | None = None, prompt_builder: kiln_ai.adapters.prompt_builders.BasePromptBuilder | None = None)
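When custom_model is supplied, the constructor infers the provider and model names from the object's class and its model_name/model attributes. A sketch under that assumption; the Ollama tag "llama3.1" and the base URL are illustrative.

from langchain_ollama import ChatOllama

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.langchain_adapters import LangchainAdapter


def adapter_for_custom_model(my_task: datamodel.Task) -> LangchainAdapter:
    # Bring-your-own LangChain chat model; the tag "llama3.1" is illustrative.
    custom = ChatOllama(model="llama3.1", base_url="http://localhost:11434")
    adapter = LangchainAdapter(my_task, custom_model=custom)
    # ChatOllama exposes a `model` attribute, so the adapter records:
    #   adapter.model_provider == "custom.langchain:ChatOllama"
    #   adapter.model_name     == "custom.langchain:llama3.1"
    return adapter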
async def model(self) -> LangChainModelType:
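For tasks with an output schema, the cached model is wrapped with with_structured_output(schema, include_raw=True, ...); invoking the wrapped runnable returns a dict whose "parsed" entry holds the structured output (this is what _run reads). Plain-text tasks get the bare chat model. A sketch of inspecting the resolved model:

from langchain_core.language_models.chat_models import BaseChatModel

from kiln_ai.adapters.langchain_adapters import LangchainAdapter


async def describe_resolved_model(adapter: LangchainAdapter) -> None:
    model = await adapter.model()  # resolved once, then cached on the adapter
    if adapter.has_structured_output():
        # A Runnable from with_structured_output(include_raw=True): invoking it returns
        # a dict with "raw" (the AIMessage) and "parsed" (a dict matching the task's
        # output_json_schema).
        print("structured-output runnable:", type(model).__name__)
    else:
        assert isinstance(model, BaseChatModel)
        print("plain chat model:", type(model).__name__)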
def adapter_info(self) -> kiln_ai.adapters.base_adapter.AdapterInfo:
async def get_structured_output_options(model_name: str, model_provider: str) -> Dict[str, Any]:
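These options are read from the provider's adapter_options["langchain"]["with_structured_output_options"] entry and forwarded to with_structured_output. A short sketch; the "json_mode" value in the comment is illustrative, not a statement about what ml_model_list actually contains.

from kiln_ai.adapters.langchain_adapters import get_structured_output_options


async def show_options(model_name: str, model_provider: str) -> None:
    options = await get_structured_output_options(model_name, model_provider)
    # {} for most providers, or e.g. {"method": "json_mode"} when the model list
    # declares with_structured_output_options for the langchain adapter.
    print(options)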
async def langchain_model_from(name: str, provider_name: str | None = None) -> langchain_core.language_models.chat_models.BaseChatModel:
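A sketch of resolving a chat model by Kiln model name. The identifiers are placeholders for ml_model_list entries, and the Ollama path requires a running Ollama server with the model installed.

from kiln_ai.adapters.langchain_adapters import langchain_model_from


async def resolve_model() -> None:
    # "llama_3_1_8b" and "ollama" are placeholders for entries in ml_model_list.
    chat_model = await langchain_model_from("llama_3_1_8b", "ollama")
    reply = await chat_model.ainvoke("Say hello in one word.")
    print(reply.content)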
async def langchain_model_from_provider(provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, model_name: str) -> langchain_core.language_models.chat_models.BaseChatModel:
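A sketch of calling the provider-level factory directly, assuming KilnModelProvider (a pydantic model in ml_model_list) can be constructed with just name and provider_options, other fields left at their defaults; the model tag is illustrative.

from kiln_ai.adapters.langchain_adapters import langchain_model_from_provider
from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName


async def model_from_explicit_provider() -> None:
    # Assumed constructor fields; real provider entries are defined in ml_model_list.
    provider = KilnModelProvider(
        name=ModelProviderName.ollama,
        provider_options={"model": "llama3.1"},
    )
    # Requires a running Ollama server with the model pulled; otherwise this raises ValueError.
    chat_model = await langchain_model_from_provider(provider, "llama_3_1_8b")
    print(type(chat_model).__name__)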