kiln_ai.adapters.fine_tune.dataset_formatter

  1import json
  2import tempfile
  3from enum import Enum
  4from pathlib import Path
  5from typing import Any, Dict, Protocol
  6from uuid import uuid4
  7
  8from kiln_ai.datamodel import DatasetSplit, TaskRun
  9
 10
 11class DatasetFormat(str, Enum):
 12    """Formats for dataset generation. Both for file format (like JSONL), and internal structure (like chat/toolcall)"""
 13
 14    """OpenAI chat format with plaintext response"""
 15    OPENAI_CHAT_JSONL = "openai_chat_jsonl"
 16
 17    """OpenAI chat format with tool call response"""
 18    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
 19
 20    """HuggingFace chat template in JSONL"""
 21    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"
 22
 23    """HuggingFace chat template with tool calls in JSONL"""
 24    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = (
 25        "huggingface_chat_template_toolcall_jsonl"
 26    )
 27
 28
 29class FormatGenerator(Protocol):
 30    """Protocol for format generators"""
 31
 32    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...
 33
 34
 35def generate_chat_message_response(
 36    task_run: TaskRun, system_message: str
 37) -> Dict[str, Any]:
 38    """Generate OpenAI chat format with plaintext response"""
 39    return {
 40        "messages": [
 41            {"role": "system", "content": system_message},
 42            {"role": "user", "content": task_run.input},
 43            {"role": "assistant", "content": task_run.output.output},
 44        ]
 45    }
 46
 47
 48def generate_chat_message_toolcall(
 49    task_run: TaskRun, system_message: str
 50) -> Dict[str, Any]:
 51    """Generate OpenAI chat format with tool call response"""
 52    try:
 53        arguments = json.loads(task_run.output.output)
 54    except json.JSONDecodeError as e:
 55        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
 56
 57    return {
 58        "messages": [
 59            {"role": "system", "content": system_message},
 60            {"role": "user", "content": task_run.input},
 61            {
 62                "role": "assistant",
 63                "content": None,
 64                "tool_calls": [
 65                    {
 66                        "id": "call_1",
 67                        "type": "function",
 68                        "function": {
 69                            "name": "task_response",
 70                            # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
 71                            "arguments": json.dumps(arguments),
 72                        },
 73                    }
 74                ],
 75            },
 76        ]
 77    }
 78
 79
 80def generate_huggingface_chat_template(
 81    task_run: TaskRun, system_message: str
 82) -> Dict[str, Any]:
 83    """Generate HuggingFace chat template"""
 84    return {
 85        "conversations": [
 86            {"role": "system", "content": system_message},
 87            {"role": "user", "content": task_run.input},
 88            {"role": "assistant", "content": task_run.output.output},
 89        ]
 90    }
 91
 92
 93def generate_huggingface_chat_template_toolcall(
 94    task_run: TaskRun, system_message: str
 95) -> Dict[str, Any]:
 96    """Generate HuggingFace chat template with tool calls"""
 97    try:
 98        arguments = json.loads(task_run.output.output)
 99    except json.JSONDecodeError as e:
100        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
101
102    # See https://huggingface.co/docs/transformers/en/chat_templating
103    return {
104        "conversations": [
105            {"role": "system", "content": system_message},
106            {"role": "user", "content": task_run.input},
107            {
108                "role": "assistant",
109                "tool_calls": [
110                    {
111                        "type": "function",
112                        "function": {
113                            "name": "task_response",
114                            "id": str(uuid4()).replace("-", "")[:9],
115                            "arguments": arguments,
116                        },
117                    }
118                ],
119            },
120        ]
121    }
122
123
# Dispatch table: maps each supported DatasetFormat to the generator
# function that renders one TaskRun into that format's per-line dict.
FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
    DatasetFormat.OPENAI_CHAT_JSONL: generate_chat_message_response,
    DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
    DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
}
130
131
class DatasetFormatter:
    """Handles formatting of datasets into various output formats"""

    def __init__(self, dataset: DatasetSplit, system_message: str):
        """
        Args:
            dataset: The dataset split to format; must have a parent task.
            system_message: System prompt included in every generated example.

        Raises:
            ValueError: If the dataset has no parent task.
        """
        self.dataset = dataset
        self.system_message = system_message

        task = dataset.parent_task()
        if task is None:
            raise ValueError("Dataset has no parent task")
        self.task = task

    def dump_to_file(
        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
    ) -> Path:
        """
        Format the dataset into the specified format.

        Args:
            split_name: Name of the split to dump
            format_type: Format to generate the dataset in
            path: Optional path to write to. If None, writes to temp directory

        Returns:
            Path to the generated file

        Raises:
            ValueError: If the format is unsupported, the split is unknown,
                or a run listed in the split is missing from the task.
        """
        if format_type not in FORMAT_GENERATORS:
            raise ValueError(f"Unsupported format: {format_type}")
        if split_name not in self.dataset.split_contents:
            raise ValueError(f"Split {split_name} not found in dataset")

        generator = FORMAT_GENERATORS[format_type]

        # Write to a temp file if no path is provided
        output_path = (
            path
            or Path(tempfile.gettempdir())
            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
        )

        runs = self.task.runs()
        runs_by_id = {run.id: run for run in runs}

        # Generate formatted output with UTF-8 encoding
        with open(output_path, "w", encoding="utf-8") as f:
            for run_id in self.dataset.split_contents[split_name]:
                # Use .get() so a missing run triggers the descriptive
                # ValueError below instead of an opaque KeyError (the
                # previous direct indexing made the None-check unreachable).
                task_run = runs_by_id.get(run_id)
                if task_run is None:
                    raise ValueError(
                        f"Task run {run_id} not found. This is required by this dataset."
                    )

                example = generator(task_run, self.system_message)
                f.write(json.dumps(example) + "\n")

        return output_path
class DatasetFormat(builtins.str, enum.Enum):
12class DatasetFormat(str, Enum):
13    """Formats for dataset generation. Both for file format (like JSONL), and internal structure (like chat/toolcall)"""
14
15    """OpenAI chat format with plaintext response"""
16    OPENAI_CHAT_JSONL = "openai_chat_jsonl"
17
18    """OpenAI chat format with tool call response"""
19    OPENAI_CHAT_TOOLCALL_JSONL = "openai_chat_toolcall_jsonl"
20
21    """HuggingFace chat template in JSONL"""
22    HUGGINGFACE_CHAT_TEMPLATE_JSONL = "huggingface_chat_template_jsonl"
23
24    """HuggingFace chat template with tool calls in JSONL"""
25    HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = (
26        "huggingface_chat_template_toolcall_jsonl"
27    )

Formats for dataset generation. Both for file format (like JSONL), and internal structure (like chat/toolcall)

OPENAI_CHAT_JSONL = <DatasetFormat.OPENAI_CHAT_JSONL: 'openai_chat_jsonl'>

OpenAI chat format with tool call response

OPENAI_CHAT_TOOLCALL_JSONL = <DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: 'openai_chat_toolcall_jsonl'>

HuggingFace chat template in JSONL

HUGGINGFACE_CHAT_TEMPLATE_JSONL = <DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: 'huggingface_chat_template_jsonl'>

HuggingFace chat template with tool calls in JSONL

HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL = <DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: 'huggingface_chat_template_toolcall_jsonl'>
class FormatGenerator(typing.Protocol):
30class FormatGenerator(Protocol):
31    """Protocol for format generators"""
32
33    def __call__(self, task_run: TaskRun, system_message: str) -> Dict[str, Any]: ...

Protocol for format generators

FormatGenerator(*args, **kwargs)
1767def _no_init_or_replace_init(self, *args, **kwargs):
1768    cls = type(self)
1769
1770    if cls._is_protocol:
1771        raise TypeError('Protocols cannot be instantiated')
1772
1773    # Already using a custom `__init__`. No need to calculate correct
1774    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1775    if cls.__init__ is not _no_init_or_replace_init:
1776        return
1777
1778    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1779    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1780    # searches for a proper new `__init__` in the MRO. The new `__init__`
1781    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1782    # instantiation of the protocol subclass will thus use the new
1783    # `__init__` and no longer call `_no_init_or_replace_init`.
1784    for base in cls.__mro__:
1785        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1786        if init is not _no_init_or_replace_init:
1787            cls.__init__ = init
1788            break
1789    else:
1790        # should not happen
1791        cls.__init__ = object.__init__
1792
1793    cls.__init__(self, *args, **kwargs)
def generate_chat_message_response( task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:
36def generate_chat_message_response(
37    task_run: TaskRun, system_message: str
38) -> Dict[str, Any]:
39    """Generate OpenAI chat format with plaintext response"""
40    return {
41        "messages": [
42            {"role": "system", "content": system_message},
43            {"role": "user", "content": task_run.input},
44            {"role": "assistant", "content": task_run.output.output},
45        ]
46    }

Generate OpenAI chat format with plaintext response

def generate_chat_message_toolcall( task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:
49def generate_chat_message_toolcall(
50    task_run: TaskRun, system_message: str
51) -> Dict[str, Any]:
52    """Generate OpenAI chat format with tool call response"""
53    try:
54        arguments = json.loads(task_run.output.output)
55    except json.JSONDecodeError as e:
56        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
57
58    return {
59        "messages": [
60            {"role": "system", "content": system_message},
61            {"role": "user", "content": task_run.input},
62            {
63                "role": "assistant",
64                "content": None,
65                "tool_calls": [
66                    {
67                        "id": "call_1",
68                        "type": "function",
69                        "function": {
70                            "name": "task_response",
71                            # Yes we parse then dump again. This ensures it's valid JSON, and ensures it goes to 1 line
72                            "arguments": json.dumps(arguments),
73                        },
74                    }
75                ],
76            },
77        ]
78    }

Generate OpenAI chat format with tool call response

def generate_huggingface_chat_template( task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:
81def generate_huggingface_chat_template(
82    task_run: TaskRun, system_message: str
83) -> Dict[str, Any]:
84    """Generate HuggingFace chat template"""
85    return {
86        "conversations": [
87            {"role": "system", "content": system_message},
88            {"role": "user", "content": task_run.input},
89            {"role": "assistant", "content": task_run.output.output},
90        ]
91    }

Generate HuggingFace chat template

def generate_huggingface_chat_template_toolcall( task_run: kiln_ai.datamodel.TaskRun, system_message: str) -> Dict[str, Any]:
 94def generate_huggingface_chat_template_toolcall(
 95    task_run: TaskRun, system_message: str
 96) -> Dict[str, Any]:
 97    """Generate HuggingFace chat template with tool calls"""
 98    try:
 99        arguments = json.loads(task_run.output.output)
100    except json.JSONDecodeError as e:
101        raise ValueError(f"Invalid JSON in for tool call: {e}") from e
102
103    # See https://huggingface.co/docs/transformers/en/chat_templating
104    return {
105        "conversations": [
106            {"role": "system", "content": system_message},
107            {"role": "user", "content": task_run.input},
108            {
109                "role": "assistant",
110                "tool_calls": [
111                    {
112                        "type": "function",
113                        "function": {
114                            "name": "task_response",
115                            "id": str(uuid4()).replace("-", "")[:9],
116                            "arguments": arguments,
117                        },
118                    }
119                ],
120            },
121        ]
122    }

Generate HuggingFace chat template with tool calls

FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {<DatasetFormat.OPENAI_CHAT_JSONL: 'openai_chat_jsonl'>: <function generate_chat_message_response>, <DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: 'openai_chat_toolcall_jsonl'>: <function generate_chat_message_toolcall>, <DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: 'huggingface_chat_template_jsonl'>: <function generate_huggingface_chat_template>, <DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: 'huggingface_chat_template_toolcall_jsonl'>: <function generate_huggingface_chat_template_toolcall>}
class DatasetFormatter:
133class DatasetFormatter:
134    """Handles formatting of datasets into various output formats"""
135
136    def __init__(self, dataset: DatasetSplit, system_message: str):
137        self.dataset = dataset
138        self.system_message = system_message
139
140        task = dataset.parent_task()
141        if task is None:
142            raise ValueError("Dataset has no parent task")
143        self.task = task
144
145    def dump_to_file(
146        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
147    ) -> Path:
148        """
149        Format the dataset into the specified format.
150
151        Args:
152            split_name: Name of the split to dump
153            format_type: Format to generate the dataset in
154            path: Optional path to write to. If None, writes to temp directory
155
156        Returns:
157            Path to the generated file
158        """
159        if format_type not in FORMAT_GENERATORS:
160            raise ValueError(f"Unsupported format: {format_type}")
161        if split_name not in self.dataset.split_contents:
162            raise ValueError(f"Split {split_name} not found in dataset")
163
164        generator = FORMAT_GENERATORS[format_type]
165
166        # Write to a temp file if no path is provided
167        output_path = (
168            path
169            or Path(tempfile.gettempdir())
170            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
171        )
172
173        runs = self.task.runs()
174        runs_by_id = {run.id: run for run in runs}
175
176        # Generate formatted output with UTF-8 encoding
177        with open(output_path, "w", encoding="utf-8") as f:
178            for run_id in self.dataset.split_contents[split_name]:
179                task_run = runs_by_id[run_id]
180                if task_run is None:
181                    raise ValueError(
182                        f"Task run {run_id} not found. This is required by this dataset."
183                    )
184
185                example = generator(task_run, self.system_message)
186                f.write(json.dumps(example) + "\n")
187
188        return output_path

Handles formatting of datasets into various output formats

DatasetFormatter(dataset: kiln_ai.datamodel.DatasetSplit, system_message: str)
136    def __init__(self, dataset: DatasetSplit, system_message: str):
137        self.dataset = dataset
138        self.system_message = system_message
139
140        task = dataset.parent_task()
141        if task is None:
142            raise ValueError("Dataset has no parent task")
143        self.task = task
dataset
system_message
task
def dump_to_file( self, split_name: str, format_type: DatasetFormat, path: pathlib.Path | None = None) -> pathlib.Path:
145    def dump_to_file(
146        self, split_name: str, format_type: DatasetFormat, path: Path | None = None
147    ) -> Path:
148        """
149        Format the dataset into the specified format.
150
151        Args:
152            split_name: Name of the split to dump
153            format_type: Format to generate the dataset in
154            path: Optional path to write to. If None, writes to temp directory
155
156        Returns:
157            Path to the generated file
158        """
159        if format_type not in FORMAT_GENERATORS:
160            raise ValueError(f"Unsupported format: {format_type}")
161        if split_name not in self.dataset.split_contents:
162            raise ValueError(f"Split {split_name} not found in dataset")
163
164        generator = FORMAT_GENERATORS[format_type]
165
166        # Write to a temp file if no path is provided
167        output_path = (
168            path
169            or Path(tempfile.gettempdir())
170            / f"{self.dataset.name}_{split_name}_{format_type}.jsonl"
171        )
172
173        runs = self.task.runs()
174        runs_by_id = {run.id: run for run in runs}
175
176        # Generate formatted output with UTF-8 encoding
177        with open(output_path, "w", encoding="utf-8") as f:
178            for run_id in self.dataset.split_contents[split_name]:
179                task_run = runs_by_id[run_id]
180                if task_run is None:
181                    raise ValueError(
182                        f"Task run {run_id} not found. This is required by this dataset."
183                    )
184
185                example = generator(task_run, self.system_message)
186                f.write(json.dumps(example) + "\n")
187
188        return output_path

Format the dataset into the specified format.

Args: split_name: Name of the split to dump format_type: Format to generate the dataset in path: Optional path to write to. If None, writes to temp directory

Returns: Path to the generated file