Coverage for bioimageio/spec/_internal/validator_annotations.py: 82%
38 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 14:21 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 14:21 +0000
1import sys
2from dataclasses import dataclass
3from typing import Any, Dict, Type
5import annotated_types
6from pydantic import GetCoreSchemaHandler, functional_validators
7from pydantic_core import CoreSchema
8from pydantic_core.core_schema import no_info_after_validator_function
# Extra keyword arguments for ``@dataclass``: ``slots`` and ``kw_only`` are
# only passed on Python 3.10 and newer; older interpreters get empty mappings
# so that ``@dataclass(**SLOTS)`` / ``@dataclass(**KW_ONLY)`` are no-ops there.
if sys.version_info >= (3, 10):
    SLOTS = {"slots": True}
    KW_ONLY = {"kw_only": True}
else:
    SLOTS: Dict[str, bool] = {}
    KW_ONLY: Dict[str, bool] = {}
18# TODO: make sure we use this one everywhere and not the vanilla pydantic one
@dataclass(frozen=True, **SLOTS)
class AfterValidator(functional_validators.AfterValidator):
    """pydantic ``AfterValidator`` with a compact ``str()`` representation.

    ``str(validator)`` yields ``AfterValidator(<name>)`` where ``<name>`` is
    the ``__name__`` of the wrapped callable ``func``.
    """

    def __str__(self) -> str:
        """Return ``AfterValidator(<name of func>)``."""
        try:
            func_name = self.func.__name__
        except AttributeError:
            # Not every callable carries ``__name__`` (e.g. ``functools.partial``
            # objects or instances defining ``__call__``); fall back to repr()
            # rather than failing while merely formatting the validator.
            func_name = repr(self.func)

        return f"AfterValidator({func_name})"
25# TODO: make sure we use this one everywhere and not the vanilla pydantic one
@dataclass(frozen=True, **SLOTS)
class BeforeValidator(functional_validators.BeforeValidator):
    """pydantic ``BeforeValidator`` with a compact ``str()`` representation.

    ``str(validator)`` yields ``BeforeValidator(<name>)`` where ``<name>`` is
    the ``__name__`` of the wrapped callable ``func``.
    """

    def __str__(self) -> str:
        """Return ``BeforeValidator(<name of func>)``."""
        try:
            func_name = self.func.__name__
        except AttributeError:
            # Not every callable carries ``__name__`` (e.g. ``functools.partial``
            # objects or instances defining ``__call__``); fall back to repr()
            # rather than failing while merely formatting the validator.
            func_name = repr(self.func)

        return f"BeforeValidator({func_name})"
32# TODO: make sure we use this one everywhere and not the vanilla pydantic one
@dataclass(frozen=True, **SLOTS)
class Predicate(annotated_types.Predicate):
    """``annotated_types.Predicate`` with a compact ``str()`` representation.

    ``str(predicate)`` yields ``Predicate(<name>)`` where ``<name>`` is the
    ``__name__`` of the wrapped callable ``func``.
    """

    def __str__(self) -> str:
        """Return ``Predicate(<name of func>)``."""
        try:
            func_name = self.func.__name__
        except AttributeError:
            # Not every callable carries ``__name__`` (e.g. ``functools.partial``
            # objects or instances defining ``__call__``); fall back to repr()
            # rather than failing while merely formatting the predicate.
            func_name = repr(self.func)

        return f"Predicate({func_name})"
@dataclass(frozen=True, **SLOTS)
class RestrictCharacters:
    """Annotation that limits a string to the characters in `alphabet`."""

    # Every character of a validated value must occur in this string.
    alphabet: str

    def __get_pydantic_core_schema__(
        self, source: Type[Any], handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Wrap the core schema of `source` with the `validate` check.

        Raises:
            ValueError: if `alphabet` is empty.
            TypeError: if the schema obtained for `source` is neither a
                ``"str"`` schema nor a ``"function-after"`` schema whose
                directly nested schema is ``"str"``.
        """
        if not self.alphabet:
            raise ValueError("Alphabet may not be empty")

        # Schema produced by the annotated type and any inner constraints.
        inner = handler(source)
        accepts_str = inner["type"] == "str" or (
            inner["type"] == "function-after" and inner["schema"]["type"] == "str"
        )
        if not accepts_str:
            raise TypeError("RestrictCharacters can only be applied to strings")

        return no_info_after_validator_function(self.validate, inner)

    def validate(self, value: str) -> str:
        """Return `value` unchanged if all of its characters are in `alphabet`.

        Raises:
            ValueError: if any character of `value` is missing from `alphabet`.
        """
        allowed = self.alphabet
        if all(char in allowed for char in value):
            return value

        raise ValueError(f"{value!r} is not restricted to {self.alphabet!r}")