kiln_ai.adapters.prompt_builders

  1import json
  2from abc import ABCMeta, abstractmethod
  3from typing import Dict
  4
  5from kiln_ai.datamodel import Task
  6from kiln_ai.utils.formatting import snake_case
  7
  8
  9class BasePromptBuilder(metaclass=ABCMeta):
 10    def __init__(self, task: Task):
 11        self.task = task
 12
 13    @abstractmethod
 14    def build_prompt(self) -> str:
 15        pass
 16
 17    # override to change the name of the prompt builder (if changing class names)
 18    @classmethod
 19    def prompt_builder_name(cls) -> str:
 20        return snake_case(cls.__name__)
 21
 22    # Can be overridden to add more information to the user message
 23    def build_user_message(self, input: Dict | str) -> str:
 24        if isinstance(input, Dict):
 25            return f"The input is:\n{json.dumps(input, indent=2)}"
 26
 27        return f"The input is:\n{input}"
 28
 29
 30class SimplePromptBuilder(BasePromptBuilder):
 31    def build_prompt(self) -> str:
 32        base_prompt = self.task.instruction
 33
 34        # TODO: this is just a quick version. Formatting and best practices TBD
 35        if len(self.task.requirements) > 0:
 36            base_prompt += (
 37                "\n\nYour response should respect the following requirements:\n"
 38            )
 39            # iterate requirements, formatting them in numbereed list like 1) task.instruction\n2)...
 40            for i, requirement in enumerate(self.task.requirements):
 41                base_prompt += f"{i+1}) {requirement.instruction}\n"
 42
 43        return base_prompt
 44
 45
 46class MultiShotPromptBuilder(BasePromptBuilder):
 47    @classmethod
 48    def example_count(cls) -> int:
 49        return 25
 50
 51    def build_prompt(self) -> str:
 52        base_prompt = f"# Instruction\n\n{ self.task.instruction }\n\n"
 53
 54        if len(self.task.requirements) > 0:
 55            base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
 56            for i, requirement in enumerate(self.task.requirements):
 57                base_prompt += f"{i+1}) {requirement.instruction}\n"
 58            base_prompt += "\n"
 59
 60        valid_examples: list[tuple[str, str]] = []
 61        runs = self.task.runs()
 62
 63        # first pass, we look for repaired outputs. These are the best examples.
 64        for run in runs:
 65            if len(valid_examples) >= self.__class__.example_count():
 66                break
 67            if run.repaired_output is not None:
 68                valid_examples.append((run.input, run.repaired_output.output))
 69
 70        # second pass, we look for high quality outputs (rating based)
 71        # Minimum is "high_quality" (4 star in star rating scale), then sort by rating
 72        # exclude repaired outputs as they were used above
 73        runs_with_rating = [
 74            run
 75            for run in runs
 76            if run.output.rating is not None
 77            and run.output.rating.value is not None
 78            and run.output.rating.is_high_quality()
 79            and run.repaired_output is None
 80        ]
 81        runs_with_rating.sort(
 82            key=lambda x: (x.output.rating and x.output.rating.value) or 0, reverse=True
 83        )
 84        for run in runs_with_rating:
 85            if len(valid_examples) >= self.__class__.example_count():
 86                break
 87            valid_examples.append((run.input, run.output.output))
 88
 89        if len(valid_examples) > 0:
 90            base_prompt += "# Example Outputs\n\n"
 91            for i, example in enumerate(valid_examples):
 92                base_prompt += (
 93                    f"## Example {i+1}\n\nInput: {example[0]}\nOutput: {example[1]}\n\n"
 94                )
 95
 96        return base_prompt
 97
 98
 99class FewShotPromptBuilder(MultiShotPromptBuilder):
100    @classmethod
101    def example_count(cls) -> int:
102        return 4
103
104
# Lookup table from a builder name (string key) to the prompt builder class.
# NOTE(review): the keys look like the snake_case form of each class name, i.e.
# what prompt_builder_name() would return - confirm against snake_case.
prompt_builder_registry = {
    "simple_prompt_builder": SimplePromptBuilder,
    "multi_shot_prompt_builder": MultiShotPromptBuilder,
    "few_shot_prompt_builder": FewShotPromptBuilder,
}
110
111
112# Our UI has some names that are not the same as the class names, which also hint parameters.
def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
    """Return the prompt builder class that corresponds to a UI name.

    Args:
        ui_name: One of "basic", "few_shot" or "many_shot".

    Returns:
        The matching builder class (the class itself, not an instance).

    Raises:
        ValueError: If ``ui_name`` is not one of the known names.
    """
    if ui_name == "basic":
        return SimplePromptBuilder
    if ui_name == "few_shot":
        return FewShotPromptBuilder
    if ui_name == "many_shot":
        return MultiShotPromptBuilder
    raise ValueError(f"Unknown prompt builder: {ui_name}")
class BasePromptBuilder:
class BasePromptBuilder(metaclass=ABCMeta):
    """Abstract base for classes that turn a task into prompt text.

    Subclasses must implement `build_prompt`. The other methods have default
    implementations that can be overridden.
    """

    def __init__(self, task: Task):
        """Store the task that prompts are built from.

        Args:
            task: The task this builder reads its data from.
        """
        self.task = task

    @abstractmethod
    def build_prompt(self) -> str:
        """Return the prompt text for `self.task`."""
        pass

    # override to change the name of the prompt builder (if changing class names)
    @classmethod
    def prompt_builder_name(cls) -> str:
        """Return the builder name: the class name passed through `snake_case`."""
        return snake_case(cls.__name__)

    # Can be overridden to add more information to the user message
    def build_user_message(self, input: Dict | str) -> str:
        """Format the raw input as the user message.

        Args:
            input: Either a dict, which is rendered as 2-space indented JSON,
                or any other value, which is interpolated as-is.

        Returns:
            The text "The input is:" followed by a newline and the rendered input.
        """
        if isinstance(input, dict):
            # ensure_ascii=False keeps non-ASCII text as written instead of
            # turning it into \uXXXX escape sequences inside the prompt.
            rendered = json.dumps(input, indent=2, ensure_ascii=False)
            return f"The input is:\n{rendered}"

        return f"The input is:\n{input}"
task
@abstractmethod
def build_prompt(self) -> str:
14    @abstractmethod
15    def build_prompt(self) -> str:
16        pass
@classmethod
def prompt_builder_name(cls) -> str:
19    @classmethod
20    def prompt_builder_name(cls) -> str:
21        return snake_case(cls.__name__)
def build_user_message(self, input: Union[Dict, str]) -> str:
24    def build_user_message(self, input: Dict | str) -> str:
25        if isinstance(input, Dict):
26            return f"The input is:\n{json.dumps(input, indent=2)}"
27
28        return f"The input is:\n{input}"
class SimplePromptBuilder(BasePromptBuilder):
class SimplePromptBuilder(BasePromptBuilder):
    """Prompt builder that emits the task instruction plus a numbered requirement list."""

    def build_prompt(self) -> str:
        """Return the task instruction, followed by numbered requirements if any exist."""
        task = self.task
        prompt = task.instruction

        # With no requirements the instruction alone is the whole prompt.
        if len(task.requirements) == 0:
            return prompt

        # TODO: this is just a quick version. Formatting and best practices TBD
        header = "\n\nYour response should respect the following requirements:\n"
        # Render requirements as a numbered list: "1) ...\n2) ...\n"
        numbered = "".join(
            f"{position}) {requirement.instruction}\n"
            for position, requirement in enumerate(task.requirements, start=1)
        )
        return prompt + header + numbered
def build_prompt(self) -> str:
32    def build_prompt(self) -> str:
33        base_prompt = self.task.instruction
34
35        # TODO: this is just a quick version. Formatting and best practices TBD
36        if len(self.task.requirements) > 0:
37            base_prompt += (
38                "\n\nYour response should respect the following requirements:\n"
39            )
40            # iterate requirements, formatting them in numbereed list like 1) task.instruction\n2)...
41            for i, requirement in enumerate(self.task.requirements):
42                base_prompt += f"{i+1}) {requirement.instruction}\n"
43
44        return base_prompt
class MultiShotPromptBuilder(BasePromptBuilder):
class MultiShotPromptBuilder(BasePromptBuilder):
    """Prompt builder that appends example input/output pairs taken from the task's runs."""

    @classmethod
    def example_count(cls) -> int:
        """Return the maximum number of examples placed in the prompt."""
        return 25

    def build_prompt(self) -> str:
        """Assemble the instruction, requirements and example sections into one prompt.

        Examples come first from runs that have a repaired output (in run
        order), then from runs without one whose rating is high quality,
        highest rating first. At most ``example_count()`` examples are used.
        """
        task = self.task
        sections = [f"# Instruction\n\n{task.instruction}\n\n"]

        if len(task.requirements) > 0:
            sections.append(
                "# Requirements\n\nYour response should respect the following requirements:\n"
            )
            sections.extend(
                f"{number}) {requirement.instruction}\n"
                for number, requirement in enumerate(task.requirements, start=1)
            )
            sections.append("\n")

        all_runs = task.runs()
        limit = type(self).example_count()
        examples: list[tuple[str, str]] = []

        # Pass 1: repaired outputs are preferred as examples.
        for run in all_runs:
            if len(examples) >= limit:
                break
            repaired = run.repaired_output
            if repaired is not None:
                examples.append((run.input, repaired.output))

        # Pass 2: rated runs that are at least "high_quality" (4 star in star
        # rating scale), best rating first. Repaired runs were handled above.
        def qualifies(run):
            return (
                run.output.rating is not None
                and run.output.rating.value is not None
                and run.output.rating.is_high_quality()
                and run.repaired_output is None
            )

        ranked = sorted(
            filter(qualifies, all_runs),
            key=lambda candidate: (
                candidate.output.rating and candidate.output.rating.value
            )
            or 0,
            reverse=True,
        )
        for run in ranked:
            if len(examples) >= limit:
                break
            examples.append((run.input, run.output.output))

        if len(examples) > 0:
            sections.append("# Example Outputs\n\n")
            sections.extend(
                f"## Example {number}\n\nInput: {example_input}\nOutput: {example_output}\n\n"
                for number, (example_input, example_output) in enumerate(
                    examples, start=1
                )
            )

        return "".join(sections)
@classmethod
def example_count(cls) -> int:
48    @classmethod
49    def example_count(cls) -> int:
50        return 25
def build_prompt(self) -> str:
52    def build_prompt(self) -> str:
53        base_prompt = f"# Instruction\n\n{ self.task.instruction }\n\n"
54
55        if len(self.task.requirements) > 0:
56            base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
57            for i, requirement in enumerate(self.task.requirements):
58                base_prompt += f"{i+1}) {requirement.instruction}\n"
59            base_prompt += "\n"
60
61        valid_examples: list[tuple[str, str]] = []
62        runs = self.task.runs()
63
64        # first pass, we look for repaired outputs. These are the best examples.
65        for run in runs:
66            if len(valid_examples) >= self.__class__.example_count():
67                break
68            if run.repaired_output is not None:
69                valid_examples.append((run.input, run.repaired_output.output))
70
71        # second pass, we look for high quality outputs (rating based)
72        # Minimum is "high_quality" (4 star in star rating scale), then sort by rating
73        # exclude repaired outputs as they were used above
74        runs_with_rating = [
75            run
76            for run in runs
77            if run.output.rating is not None
78            and run.output.rating.value is not None
79            and run.output.rating.is_high_quality()
80            and run.repaired_output is None
81        ]
82        runs_with_rating.sort(
83            key=lambda x: (x.output.rating and x.output.rating.value) or 0, reverse=True
84        )
85        for run in runs_with_rating:
86            if len(valid_examples) >= self.__class__.example_count():
87                break
88            valid_examples.append((run.input, run.output.output))
89
90        if len(valid_examples) > 0:
91            base_prompt += "# Example Outputs\n\n"
92            for i, example in enumerate(valid_examples):
93                base_prompt += (
94                    f"## Example {i+1}\n\nInput: {example[0]}\nOutput: {example[1]}\n\n"
95                )
96
97        return base_prompt
class FewShotPromptBuilder(MultiShotPromptBuilder):
class FewShotPromptBuilder(MultiShotPromptBuilder):
    """MultiShotPromptBuilder variant with a smaller example limit."""

    @classmethod
    def example_count(cls) -> int:
        """Return the example limit for this builder."""
        few_shot_limit = 4
        return few_shot_limit
@classmethod
def example_count(cls) -> int:
101    @classmethod
102    def example_count(cls) -> int:
103        return 4
prompt_builder_registry = {'simple_prompt_builder': <class 'SimplePromptBuilder'>, 'multi_shot_prompt_builder': <class 'MultiShotPromptBuilder'>, 'few_shot_prompt_builder': <class 'FewShotPromptBuilder'>}
def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
    """Return the prompt builder class that corresponds to a UI name.

    Args:
        ui_name: One of "basic", "few_shot" or "many_shot".

    Returns:
        The matching builder class (the class itself, not an instance).

    Raises:
        ValueError: If ``ui_name`` is not one of the known names.
    """
    if ui_name == "basic":
        return SimplePromptBuilder
    if ui_name == "few_shot":
        return FewShotPromptBuilder
    if ui_name == "many_shot":
        return MultiShotPromptBuilder
    raise ValueError(f"Unknown prompt builder: {ui_name}")