
1""" 

2Batch processing functionality for handling multiple requests efficiently. 
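
Example (illustrative sketch; the model name and api_key are placeholders):
    requests = create_chat_completion_requests(
        model="gpt-4o",
        prompts=["Hello!", "How are you?"],
        max_tokens=50,
    )

    async with ChatLimiter.for_model("gpt-4o", api_key) as limiter:
        results = await process_chat_completion_batch(limiter, requests)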

3""" 

4 

5import asyncio 

6import logging 

7import traceback 

8from abc import ABC, abstractmethod 

9from collections.abc import Callable 

10from concurrent.futures import ThreadPoolExecutor, as_completed 

11from dataclasses import dataclass, field 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Generic, 

16 TypeVar, 

17) 

18 

19import httpx 

20 

21from tqdm.asyncio import tqdm as atqdm 

22from tqdm import tqdm 

23 

24if TYPE_CHECKING: 

25 pass 

26 

27from .limiter import ChatLimiter 

28from .types import ChatCompletionRequest, ChatCompletionResponse 

29 

30logger = logging.getLogger(__name__) 

31 

32# Type variables for generic batch processing 

33BatchItemT = TypeVar("BatchItemT") 

34BatchResultT = TypeVar("BatchResultT") 

35 

36 

37@dataclass 

38class BatchConfig: 

39 """Configuration for batch processing.""" 


    # Concurrency settings
    max_concurrent_requests: int = 10
    max_workers: int = 4  # For sync processing

    # Retry settings
    max_retries_per_item: int = 3
    retry_delay: float = 1.0

    # Progress tracking
    show_progress: bool = True
    progress_desc: str = "Processing batch"

    # Error handling
    stop_on_first_error: bool = False
    collect_errors: bool = True
    verbose: bool = False

    # Batch size optimization
    adaptive_batch_size: bool = True
    min_batch_size: int = 1
    max_batch_size: int = 100

    # Request grouping
    group_by_model: bool = True
    group_by_provider: bool = True


@dataclass
class BatchItem(Generic[BatchItemT]):
    """A single item in a batch request."""

    # Item data
    data: BatchItemT

    # Request configuration
    method: str = "POST"
    url: str = "/chat/completions"
    json_data: dict[str, Any] | None = None

    # Metadata
    id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    # Processing state
    attempt_count: int = 0
    last_error: Exception | None = None


@dataclass
class BatchResult(Generic[BatchResultT]):
    """Result of processing a batch item."""

    # Original item
    item: "BatchItem[Any]"

    # Result data
    result: BatchResultT | None = None

    # Processing metadata
    duration: float = 0.0
    attempt_count: int = 0

    # Error information
    success: bool = True
    error_message: str | None = None

    # Response metadata
    response_headers: dict[str, str] = field(default_factory=dict)
    status_code: int | None = None


class BatchProcessor(ABC, Generic[BatchItemT, BatchResultT]):

113 """Abstract base class for batch processing.""" 


    def __init__(
        self,
        limiter: ChatLimiter,
        config: BatchConfig | None = None,
    ):
        self.limiter = limiter
        self.config = config or BatchConfig()
        self._results: list[BatchResult[BatchResultT]] = []
        self._errors: list[Exception] = []

        # Enable verbose mode on limiter if config specifies it
        if hasattr(self.limiter, 'set_verbose_mode'):
            self.limiter.set_verbose_mode(self.config.verbose)

    @abstractmethod
    async def process_item(self, item: BatchItem[BatchItemT]) -> BatchResultT:
        """Process a single batch item."""
        pass

    @abstractmethod
    def process_item_sync(self, item: BatchItem[BatchItemT]) -> BatchResultT:
        """Process a single batch item synchronously."""
        pass

    def create_batch_items(
        self,
        items: list[BatchItemT],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchItem[BatchItemT]]:

144 """Create batch items from raw data.""" 

        batch_items = []

        for i, item in enumerate(items):
            batch_item = BatchItem(
                data=item,
                id=f"item_{i}",
            )

            # Configure request if function provided
            if request_fn:
                method, url, json_data = request_fn(item)
                batch_item.method = method
                batch_item.url = url
                batch_item.json_data = json_data

            batch_items.append(batch_item)

        return batch_items

    async def process_batch(
        self,
        items: list[BatchItemT] | list[BatchItem[BatchItemT]],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchResult[BatchResultT]]:
        """Process a batch of items asynchronously."""
        # Convert to batch items if needed
        if items and not isinstance(items[0], BatchItem):
            batch_items = self.create_batch_items(items, request_fn)  # type: ignore
        else:
            batch_items = items  # type: ignore

        # Group items if configured
        if self.config.group_by_model or self.config.group_by_provider:
            grouped_items = self._group_items(batch_items)
        else:
            grouped_items = {"default": batch_items}

        # Process groups
        all_results = []

        # Calculate total items for progress tracking
        total_items = sum(len(group_items) for group_items in grouped_items.values())

        # Initialize progress bar if enabled
        progress_bar = None
        if self.config.show_progress:
            progress_bar = tqdm(
                total=total_items,
                desc=self.config.progress_desc,
                unit="item"
            )

        for group_name, group_items in grouped_items.items():
            logger.info(
                f"Processing group '{group_name}' with {len(group_items)} items"
            )

            # Create semaphore for concurrency control
            semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)

            # Process items with concurrency control and progress tracking
            tasks = [
                self._process_item_with_retry(item, semaphore, progress_bar) for item in group_items
            ]

            # Wait for all tasks to complete
            group_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Handle exceptions from gather
            for i, result in enumerate(group_results):
                if isinstance(result, Exception):
                    # Create error result
                    error_result: BatchResult[BatchResultT] = BatchResult(
                        item=group_items[i],
                        success=False,
                        error_message=str(result),
                        attempt_count=group_items[i].attempt_count,
                    )
                    all_results.append(error_result)
                else:
                    all_results.append(result)  # type: ignore

        # Close progress bar if it was created
        if progress_bar:
            progress_bar.close()

        self._results = all_results
        return all_results


    def process_batch_sync(
        self,
        items: list[BatchItemT] | list[BatchItem[BatchItemT]],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchResult[BatchResultT]]:
        """Process a batch of items synchronously."""
        # Convert to batch items if needed
        if items and not isinstance(items[0], BatchItem):
            batch_items = self.create_batch_items(items, request_fn)  # type: ignore
        else:
            batch_items = items  # type: ignore

        # Group items if configured
        if self.config.group_by_model or self.config.group_by_provider:
            grouped_items = self._group_items(batch_items)
        else:
            grouped_items = {"default": batch_items}

        # Calculate total items for progress tracking
        total_items = sum(len(group_items) for group_items in grouped_items.values())

        # Initialize progress bar if enabled
        progress_bar = None
        if self.config.show_progress:
            progress_bar = tqdm(
                total=total_items,
                desc=self.config.progress_desc,
                unit="item"
            )

        # Process groups
        all_results = []
        for group_name, group_items in grouped_items.items():
            logger.info(
                f"Processing group '{group_name}' with {len(group_items)} items"
            )

            # Use ThreadPoolExecutor for concurrent processing
            with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
                # Submit all tasks
                future_to_item = {
                    executor.submit(self._process_item_sync_with_retry, item, progress_bar): item
                    for item in group_items
                }

                # Collect results
                for future in as_completed(future_to_item):
                    item = future_to_item[future]
                    try:
                        result = future.result()
                        all_results.append(result)
                    except Exception as e:
                        error_result: BatchResult[BatchResultT] = BatchResult(
                            item=item,
                            success=False,
                            error_message=str(e),
                            attempt_count=item.attempt_count,
                        )
                        all_results.append(error_result)

        # Close progress bar if it was created
        if progress_bar:
            progress_bar.close()

        self._results = all_results
        return all_results

    def _group_items(
        self, items: list[BatchItem[BatchItemT]]
    ) -> dict[str, list[BatchItem[BatchItemT]]]:
        """Group items by model or provider."""
        groups: dict[str, list[BatchItem[BatchItemT]]] = {}

        for item in items:
            # Determine group key
            group_key = "default"

            if (
                self.config.group_by_model
                and item.json_data
                and "model" in item.json_data
            ):
                group_key = item.json_data["model"]
            elif self.config.group_by_provider:
                group_key = self.limiter.provider.value

            # Add to group
            if group_key not in groups:
                groups[group_key] = []
            groups[group_key].append(item)

        return groups


    async def _process_item_with_retry(
        self,
        item: BatchItem[BatchItemT],
        semaphore: asyncio.Semaphore,
        progress_bar: tqdm | None = None,
    ) -> BatchResult[BatchResultT]:
        """Process a single item with retry logic."""
        async with semaphore:
            import time

            start_time = time.time()

            for attempt in range(self.config.max_retries_per_item + 1):
                item.attempt_count = attempt + 1

                try:
                    # Process the item
                    result = await self.process_item(item)

                    # Update progress bar on success
                    if progress_bar:
                        progress_bar.update(1)

                    # Success
                    return BatchResult(
                        item=item,
                        result=result,
                        success=True,
                        duration=time.time() - start_time,
                        attempt_count=item.attempt_count,
                    )

                except Exception as e:
                    item.last_error = e

                    # Check if this is a timeout error
                    is_timeout_error = (
                        isinstance(e, (httpx.ReadTimeout, httpx.ConnectTimeout)) or
                        (hasattr(e, '__cause__') and isinstance(e.__cause__, (httpx.ReadTimeout, httpx.ConnectTimeout))) or
                        'ReadTimeout' in str(type(e)) or 'timeout' in str(e).lower()
                    )

                    # Print user-friendly error messages
                    if self.config.verbose:
                        if is_timeout_error:
                            # Get current timeout from the limiter
                            current_timeout = getattr(self.limiter, '_user_timeout', 120.0)
                            print(f"⏱️ TIMEOUT ERROR in batch item {item.id} (attempt {attempt + 1}):")
                            print(f"   Current timeout setting: {current_timeout} seconds")
                            print(f"   The request took longer than {current_timeout}s to complete.")
                            print()
                            print("💡 How to fix this:")
                            print(f"   1. Increase timeout: ChatLimiter.for_model('{getattr(self.limiter, 'provider', 'your-model')}', timeout={current_timeout + 60})")
                            print(f"   2. Reduce concurrency: BatchConfig(max_concurrent_requests={max(1, self.config.max_concurrent_requests // 2)})")
                            print(f"   3. Current concurrency: {self.config.max_concurrent_requests} requests")
                            print()
                        else:
                            print(f"❌ Exception in batch item {item.id} (attempt {attempt + 1}):")

                        traceback.print_exc()

                    # If this is the last attempt or we should stop on error
                    if (
                        attempt == self.config.max_retries_per_item
                        or self.config.stop_on_first_error
                    ):
                        # Update progress bar on final failure
                        if progress_bar:
                            progress_bar.update(1)

                        return BatchResult(
                            item=item,
                            success=False,
                            error_message=str(e),
                            duration=time.time() - start_time,
                            attempt_count=item.attempt_count,
                        )

                    # Wait before retry - longer for timeout errors
                    if is_timeout_error:
                        # For timeout errors, back off more aggressively
                        retry_delay = self.config.retry_delay * (3**attempt)
                    else:
                        retry_delay = self.config.retry_delay * (2**attempt)

                    await asyncio.sleep(retry_delay)

            # This should never be reached, but added for type checking
            return BatchResult(
                item=item,
                success=False,
                error_message="Unexpected error in retry logic",
                duration=time.time() - start_time,
                attempt_count=item.attempt_count,
            )


    def _process_item_sync_with_retry(
        self,
        item: BatchItem[BatchItemT],
        progress_bar: tqdm | None = None,
    ) -> BatchResult[BatchResultT]:
        """Process a single item with retry logic (sync)."""
        import time

        start_time = time.time()

        for attempt in range(self.config.max_retries_per_item + 1):
            item.attempt_count = attempt + 1

            try:
                # Process the item
                result = self.process_item_sync(item)

                # Update progress bar on success
                if progress_bar:
                    progress_bar.update(1)

                # Success
                return BatchResult(
                    item=item,
                    result=result,
                    success=True,
                    duration=time.time() - start_time,
                    attempt_count=item.attempt_count,
                )

            except Exception as e:
                item.last_error = e

                # Print traceback if verbose mode is enabled
                if self.config.verbose:
                    print(f"Exception in batch item {item.id} (attempt {attempt + 1}):")
                    traceback.print_exc()

                # If this is the last attempt or we should stop on error
                if (
                    attempt == self.config.max_retries_per_item
                    or self.config.stop_on_first_error
                ):
                    # Update progress bar on final failure
                    if progress_bar:
                        progress_bar.update(1)

                    return BatchResult(
                        item=item,
                        success=False,
                        error_message=str(e),
                        duration=time.time() - start_time,
                        attempt_count=item.attempt_count,
                    )

                # Wait before retry
                time.sleep(self.config.retry_delay * (2**attempt))

        # This should never be reached, but added for type checking
        return BatchResult(
            item=item,
            success=False,
            error_message="Unexpected error in retry logic",
            duration=time.time() - start_time,
            attempt_count=item.attempt_count,
        )


    def get_success_rate(self) -> float:
        """Get the success rate of the last batch."""
        if not self._results:
            return 0.0

        successful = sum(1 for r in self._results if r.success)
        return successful / len(self._results)

    def get_successful_results(self) -> list[BatchResult[BatchResultT]]:
        """Get only successful results."""
        return [r for r in self._results if r.success]

    def get_failed_results(self) -> list[BatchResult[BatchResultT]]:
        """Get only failed results."""
        return [r for r in self._results if not r.success]

    def get_stats(self) -> dict[str, Any]:
        """Get comprehensive processing statistics.

        Returns a dict with the keys: total, successful, failed, success_rate,
        avg_duration, total_duration, and avg_attempts.
        """
        if not self._results:
            return {"total": 0, "successful": 0, "failed": 0, "success_rate": 0.0}

        successful = self.get_successful_results()
        failed = self.get_failed_results()

        # Calculate timing statistics
        durations = [r.duration for r in self._results]
        avg_duration = sum(durations) / len(durations) if durations else 0

        return {
            "total": len(self._results),
            "successful": len(successful),
            "failed": len(failed),
            "success_rate": self.get_success_rate(),
            "avg_duration": avg_duration,
            "total_duration": sum(durations),
            "avg_attempts": sum(r.attempt_count for r in self._results)
            / len(self._results),
        }



class ChatBatchProcessor(BatchProcessor[dict[str, Any], dict[str, Any]]):
    """Batch processor for chat completion requests."""

    async def process_item(self, item: BatchItem[dict[str, Any]]) -> dict[str, Any]:
        """Process a single chat completion request."""
        request_data = item.json_data or item.data

        # Log prompt if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- PROMPT (Item {item.id}) ---")
            if "messages" in request_data:
                for msg in request_data["messages"]:
                    role = msg.get("role", "unknown")
                    content = msg.get("content", "")
                    print(f"{role.upper()}: {content}")
            else:
                print(f"REQUEST DATA: {request_data}")
            print("--- END PROMPT ---\n")

        # Make the request using the limiter
        response = await self.limiter.request(
            method=item.method,
            url=item.url,
            json=request_data,
        )

        # Parse response
        response.raise_for_status()
        result: dict[str, Any] = response.json()

        # Log response if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- RESPONSE (Item {item.id}) ---")
            if "choices" in result and result["choices"]:
                for i, choice in enumerate(result["choices"]):
                    if "message" in choice:
                        content = choice["message"].get("content", "")
                        print(f"CHOICE {i}: {content}")
                    elif "text" in choice:
                        print(f"CHOICE {i}: {choice['text']}")
            else:
                print(f"FULL RESPONSE: {result}")
            print("--- END RESPONSE ---\n")

        # Store response metadata
        item.metadata["response_headers"] = dict(response.headers)
        item.metadata["status_code"] = response.status_code

        return result


    def process_item_sync(self, item: BatchItem[dict[str, Any]]) -> dict[str, Any]:
        """Process a single chat completion request synchronously."""
        request_data = item.json_data or item.data

        # Log prompt if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- PROMPT (Item {item.id}) ---")
            if "messages" in request_data:
                for msg in request_data["messages"]:
                    role = msg.get("role", "unknown")
                    content = msg.get("content", "")
                    print(f"{role.upper()}: {content}")
            else:
                print(f"REQUEST DATA: {request_data}")
            print("--- END PROMPT ---\n")

        # Make the request using the limiter
        response = self.limiter.request_sync(
            method=item.method,
            url=item.url,
            json=request_data,
        )

        # Parse response
        response.raise_for_status()
        result: dict[str, Any] = response.json()

        # Log response if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- RESPONSE (Item {item.id}) ---")
            if "choices" in result and result["choices"]:
                for i, choice in enumerate(result["choices"]):
                    if "message" in choice:
                        content = choice["message"].get("content", "")
                        print(f"CHOICE {i}: {content}")
                    elif "text" in choice:
                        print(f"CHOICE {i}: {choice['text']}")
            else:
                print(f"FULL RESPONSE: {result}")
            print("--- END RESPONSE ---\n")

        # Store response metadata
        item.metadata["response_headers"] = dict(response.headers)
        item.metadata["status_code"] = response.status_code

        return result



# Convenience functions for common use cases
async def process_chat_batch(
    limiter: ChatLimiter,
    requests: list[dict[str, Any]],
    config: BatchConfig | None = None,
) -> list[BatchResult[dict[str, Any]]]:
    """
    Process a batch of chat completion requests.

    Args:
        limiter: Configured ChatLimiter instance
        requests: List of request data (will be sent as JSON)
        config: Optional batch processing configuration

    Returns:
        List of batch results
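
    Example (illustrative; assumes an OpenAI-style request payload and uses
    api_key as a placeholder):
        requests = [
            {
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "Hello!"}],
                "max_tokens": 50,
            }
        ]

        async with ChatLimiter.for_model("gpt-4o", api_key) as limiter:
            results = await process_chat_batch(limiter, requests)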

644 """ 

645 processor = ChatBatchProcessor(limiter, config) 

646 return await processor.process_batch(requests) 

647 

648 

649def process_chat_batch_sync( 

650 limiter: ChatLimiter, 

651 requests: list[dict[str, Any]], 

652 config: BatchConfig | None = None, 

653) -> list[BatchResult[dict[str, Any]]]: 

654 """ 

655 Process a batch of chat completion requests synchronously. 

656 

657 Args: 

658 limiter: Configured ChatLimiter instance 

659 requests: List of request data (will be sent as JSON) 

660 config: Optional batch processing configuration 

661 

662 Returns: 

663 List of batch results 

664 """ 

665 processor = ChatBatchProcessor(limiter, config) 

666 return processor.process_batch_sync(requests) 

667 

668 

# High-level chat completion batch processing
class ChatCompletionBatchProcessor(BatchProcessor[ChatCompletionRequest, ChatCompletionResponse]):
    """High-level batch processor for chat completion requests."""

    async def process_item(self, item: BatchItem[ChatCompletionRequest]) -> ChatCompletionResponse:
        """Process a single chat completion request using high-level interface."""
        request = item.data

        # Log prompt if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- PROMPT (Item {item.id}) ---")
            print(f"MODEL: {request.model}")
            for msg in request.messages:
                print(f"{msg.role.value.upper()}: {msg.content}")
            print("--- END PROMPT ---\n")

        # Use the high-level chat completion method
        response = await self.limiter.chat_completion(
            model=request.model,
            messages=request.messages,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
            stream=request.stream,
            # Provider-specific parameters
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            top_k=request.top_k,
        )

        # Check for errors in the response
        if not response.success:
            raise Exception(f"Chat completion failed: {response.error_message}")

        # Log response if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- RESPONSE (Item {item.id}) ---")
            print(f"MODEL: {response.model}")
            if response.choices:
                for i, choice in enumerate(response.choices):
                    print(f"CHOICE {i}: {choice.message.content}")
            print("--- END RESPONSE ---\n")

        return response

    def process_item_sync(self, item: BatchItem[ChatCompletionRequest]) -> ChatCompletionResponse:
        """Process a single chat completion request synchronously using high-level interface."""
        request = item.data

        # Log prompt if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- PROMPT (Item {item.id}) ---")
            print(f"MODEL: {request.model}")
            for msg in request.messages:
                print(f"{msg.role.value.upper()}: {msg.content}")
            print("--- END PROMPT ---\n")

        # Use the high-level chat completion method (sync)
        response = self.limiter.chat_completion_sync(
            model=request.model,
            messages=request.messages,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
            stream=request.stream,
            # Provider-specific parameters
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            top_k=request.top_k,
        )

        # Check for errors in the response
        if not response.success:
            raise Exception(f"Chat completion failed: {response.error_message}")

        # Log response if verbose mode is enabled
        if self.config.verbose:
            print(f"\n--- RESPONSE (Item {item.id}) ---")
            print(f"MODEL: {response.model}")
            if response.choices:
                for i, choice in enumerate(response.choices):
                    print(f"CHOICE {i}: {choice.message.content}")
            print("--- END RESPONSE ---\n")

        return response



# Convenience functions for high-level chat completion batches
async def process_chat_completion_batch(
    limiter: ChatLimiter,
    requests: list[ChatCompletionRequest],
    config: BatchConfig | None = None,
) -> list[BatchResult[ChatCompletionResponse]]:
    """
    Process a batch of high-level chat completion requests.

    Args:
        limiter: Configured ChatLimiter instance
        requests: List of ChatCompletionRequest objects
        config: Optional batch processing configuration

    Returns:
        List of batch results containing ChatCompletionResponse objects

    Example:
        from chat_limiter import ChatLimiter, Message, MessageRole, ChatCompletionRequest

        requests = [
            ChatCompletionRequest(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="Hello!")],
                max_tokens=50
            ),
            ChatCompletionRequest(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="How are you?")],
                max_tokens=50
            )
        ]

        async with ChatLimiter.for_model("gpt-4o", api_key) as limiter:
            results = await process_chat_completion_batch(limiter, requests)
    """
    processor = ChatCompletionBatchProcessor(limiter, config)
    return await processor.process_batch(requests)


def process_chat_completion_batch_sync(
    limiter: ChatLimiter,
    requests: list[ChatCompletionRequest],
    config: BatchConfig | None = None,
) -> list[BatchResult[ChatCompletionResponse]]:
    """
    Process a batch of high-level chat completion requests synchronously.

    Args:
        limiter: Configured ChatLimiter instance
        requests: List of ChatCompletionRequest objects
        config: Optional batch processing configuration

    Returns:
        List of batch results containing ChatCompletionResponse objects
    """
    processor = ChatCompletionBatchProcessor(limiter, config)
    return processor.process_batch_sync(requests)


# Helper function for creating chat completion requests from simple data
def create_chat_completion_requests(
    model: str,
    prompts: list[str],
    max_tokens: int | None = None,
    temperature: float | None = None,
    **kwargs: Any,
) -> list[ChatCompletionRequest]:
    """
    Create a list of ChatCompletionRequest objects from simple prompts.

    Args:
        model: The model to use for all requests
        prompts: List of user prompts
        max_tokens: Maximum tokens per completion
        temperature: Sampling temperature
        **kwargs: Additional parameters for all requests

    Returns:
        List of ChatCompletionRequest objects

    Example:
        requests = create_chat_completion_requests(
            model="gpt-4o",
            prompts=["Hello!", "How are you?", "What is Python?"],
            max_tokens=50,
            temperature=0.7
        )
    """
    from .types import Message, MessageRole

    requests = []
    for prompt in prompts:
        request = ChatCompletionRequest(
            model=model,
            messages=[Message(role=MessageRole.USER, content=prompt)],
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )
        requests.append(request)

    return requests