from collections.abc import Callable from pathlib import Path from typing import Any, Literal, get_args, get_origin import click from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined from ..core.config.model import OtherConfig TYPE_MAP = { str: click.STRING, int: click.INT, bool: click.BOOL, float: click.FLOAT, Path: click.Path(), } class ConfigOption(click.Option): """ Custom click option that allows for more flexible handling of Pydantic models. This is used to ensure that options can be generated dynamically from Pydantic models. """ model_name: str | None field_name: str | None def __init__(self, *args, **kwargs): self.model_name = kwargs.pop("model_name", None) self.field_name = kwargs.pop("field_name", None) super().__init__(*args, **kwargs) def options_from_model(model: type[BaseModel], parent_name: str = "") -> Callable: """ A decorator factory that generates click.option decorators from a Pydantic model. This function introspects a Pydantic model and creates a stack of decorators that can be applied to a click command function, ensuring the CLI options always match the configuration model. Args: model: The Pydantic BaseModel class to generate options from. Returns: A decorator that applies the generated options to a function. """ decorators = [] is_external_tool = issubclass(model, OtherConfig) model_name = model.__name__.lower().replace("config", "") # Introspect the model's fields for field_name, field_info in model.model_fields.items(): if isinstance(field_info.annotation, type) and issubclass( field_info.annotation, BaseModel ): nested_decorators = options_from_model(field_info.annotation, field_name) nested_decorator_list = getattr(nested_decorators, "decorators", []) decorators.extend(nested_decorator_list) continue if is_external_tool: cli_name = f"--{model_name}-{field_name.replace('_', '-')}" else: cli_name = f"--{field_name.replace('_', '-')}" kwargs = { "type": _get_click_type(field_info), "help": field_info.description or "", } if field_info.annotation is bool: if field_info.default is not PydanticUndefined: kwargs["default"] = field_info.default kwargs["show_default"] = True if is_external_tool: cli_name = ( f"{cli_name}/--no-{model_name}-{field_name.replace('_', '-')}" ) else: cli_name = f"{cli_name}/--no-{field_name.replace('_', '-')}" elif field_info.default is not PydanticUndefined: kwargs["default"] = field_info.default kwargs["show_default"] = True decorators.append( click.option( cli_name, cls=ConfigOption, model_name=model_name, field_name=field_name, **kwargs, ) ) for field_name, computed_field_info in model.model_computed_fields.items(): if is_external_tool: cli_name = f"--{model_name}-{field_name.replace('_', '-')}" else: cli_name = f"--{field_name.replace('_', '-')}" kwargs = { "type": TYPE_MAP[computed_field_info.return_type], "help": computed_field_info.description or "", } decorators.append( click.option( cli_name, cls=ConfigOption, model_name=model_name, field_name=field_name, **kwargs, ) ) def decorator(f: Callable) -> Callable: # Apply the decorators in reverse order to the function for deco in reversed(decorators): f = deco(f) return f # Store the list of decorators as an attribute for nested calls setattr(decorator, "decorators", decorators) return decorator def _get_click_type(field_info: FieldInfo) -> Any: """Maps a Pydantic field's type to a corresponding click type.""" field_type = field_info.annotation # Check if the type is a Literal if ( field_type is not None and hasattr(field_type, "__origin__") and get_origin(field_type) is Literal ): args = get_args(field_type) if args: return click.Choice(args) # Check for examples in field_info - use as choices if hasattr(field_info, "examples") and field_info.examples: return click.Choice(field_info.examples) # Check for numeric constraints and create click.Range if field_type in (int, float): constraints = {} # Extract constraints from field_info.metadata if hasattr(field_info, "metadata") and field_info.metadata: for constraint in field_info.metadata: constraint_type = type(constraint).__name__ if constraint_type == "Ge" and hasattr(constraint, "ge"): constraints["min"] = constraint.ge elif constraint_type == "Le" and hasattr(constraint, "le"): constraints["max"] = constraint.le elif constraint_type == "Gt" and hasattr(constraint, "gt"): # gt means strictly greater than, so min should be gt + 1 for int if field_type is int: constraints["min"] = constraint.gt + 1 else: # For float, we can't easily handle strict inequality in click.Range constraints["min"] = constraint.gt elif constraint_type == "Lt" and hasattr(constraint, "lt"): # lt means strictly less than, so max should be lt - 1 for int if field_type is int: constraints["max"] = constraint.lt - 1 else: # For float, we can't easily handle strict inequality in click.Range constraints["max"] = constraint.lt # Create click.Range if we have constraints if constraints: range_kwargs = {} if "min" in constraints: range_kwargs["min"] = constraints["min"] if "max" in constraints: range_kwargs["max"] = constraints["max"] if range_kwargs: if field_type is int: return click.IntRange(**range_kwargs) else: return click.FloatRange(**range_kwargs) return TYPE_MAP.get( field_type, click.STRING ) # Default to STRING if type is not found