"""Import RDF graphs into Pydantic models."""
from __future__ import annotations
import warnings
from collections.abc import Iterator
from typing import Any, TypeVar, cast
from pydantic import BaseModel, ValidationError
from pydantic.fields import FieldInfo
from pyoxigraph import NamedNode
from triplemodel.store import RdfGraph as Graph
from triplemodel.store.namespaces import XSD
from pyoxigraph import Literal as RdfLiteral
from triplemodel.store.terms import OxTerm, RdfTerm as Node, is_named, term_str
from triplemodel._typing import (
ModelFieldScalar,
ModelFieldValue,
ModelInitData,
OnDuplicate,
)
from triplemodel.config import (
RDF_TYPE,
EmbedMode,
RdfConfig,
get_rdf_config,
id_from_subject_uri,
)
from triplemodel.embed.strategies import import_nested_value
from triplemodel.fields.metadata import (
id_field_is_iri_id,
inverse_for_field,
predicate_for_field,
transitive_for_field,
)
from triplemodel.io.rdfs import transitive_objects
from triplemodel.namespaces import resolve_predicate
from triplemodel.fields.resolver import default_resolver
from triplemodel.metadata.predicate_map import (
owned_predicates_for_class,
predicate_map_for_class,
)
from triplemodel.metadata.cardinality import (
field_cardinality,
nested_model_type,
ref_collection_element_type,
raise_if_inverse_collection,
raise_if_nested_collection,
scalar_python_type,
union_member_types,
)
from triplemodel.protocols import PredicateResolver as PredicateResolverProtocol
from triplemodel.terms.collection import read_rdf_list
from triplemodel.terms.convert import term_to_python
from triplemodel.terms.lang import (
LangString,
MultiLangString,
_base_direction_name,
normalize_lang_tag,
)
from triplemodel.terms.typed_literal import TypedLiteral
from triplemodel.terms.iri import normalize_iri
from triplemodel.terms.registry import LiteralRegistry, default_registry
# Type variable for the concrete Pydantic model class being hydrated; it ties
# the ``model_cls`` argument of the loaders below to their return type.
T = TypeVar("T", bound=BaseModel)
# Keyword names that ``split_load_kwargs`` routes to the model-import side;
# every other keyword is treated as a document-parse option.
_IMPORT_KWARG_KEYS = frozenset(
    (
        "chunk_size",
        "config",
        "de_skolemize",
        "on_duplicate",
        "registry",
        "resolver",
        "strict_import",
        "type_uri",
        "validate_type",
        "warn_unmapped_fields",
    )
)
def split_load_kwargs(
    kwargs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Partition mixed keyword arguments into ``(parse_kwargs, import_kwargs)``.

    Keys listed in ``_IMPORT_KWARG_KEYS`` land in the import dict; all other
    keys land in the parse dict. The input mapping is not modified and the
    relative order of keys is kept within each result.
    """
    parse_kwargs: dict[str, Any] = {}
    import_kwargs: dict[str, Any] = {}
    for key, value in kwargs.items():
        target = import_kwargs if key in _IMPORT_KWARG_KEYS else parse_kwargs
        target[key] = value
    return parse_kwargs, import_kwargs
def _handle_duplicate(
    field_name: str,
    predicate: str,
    uri: str,
    count: int,
    on_duplicate: OnDuplicate,
    *,
    message: str | None = None,
) -> None:
    """Apply the ``on_duplicate`` policy to a field that has surplus objects.

    ``"error"`` raises :class:`ValueError`, ``"warn"`` emits a warning, and
    any other policy value (including ``"ignore"`` and ``"first"``) returns
    silently. ``message`` replaces the default text when it is non-empty.
    """
    # Only two policies are observable; everything else is a no-op.
    if on_duplicate not in ("error", "warn"):
        return
    text = message or (
        f"Multiple objects ({count}) for field {field_name!r} "
        f"(predicate {predicate!r}, subject {uri!r}); using the first only."
    )
    if on_duplicate == "error":
        raise ValueError(text)
    # stacklevel=3 attributes the warning two frames above this helper.
    warnings.warn(text, stacklevel=3)
def _enforce_subject_predicates(
    graph: Graph,
    subject: Node,
    uri_str: str,
    model_cls: type[BaseModel],
    cfg: RdfConfig,
    *,
    resolver: PredicateResolverProtocol | None = None,
    strict_import: bool | None = None,
    warn_unmapped: bool | None = None,
) -> None:
    """Check every predicate on ``subject`` against the model's owned predicates.

    ``strict_import`` and ``warn_unmapped`` override the corresponding ``cfg``
    settings when they are not ``None``. In strict mode the first predicate
    that is neither owned by ``model_cls`` nor ``RDF_TYPE`` raises
    :class:`ValueError`; in warn-only mode each such predicate emits a
    warning. When both modes are off the graph is not inspected at all.
    """
    strict = strict_import if strict_import is not None else cfg.strict_import
    warn = warn_unmapped if warn_unmapped is not None else cfg.warn_unmapped_fields
    if not (strict or warn):
        return
    permitted = {
        RDF_TYPE,
        *owned_predicates_for_class(model_cls, resolver=resolver, config=cfg),
    }
    for _subj, pred_term, _obj in graph.triples((subject, None, None)):
        pred_str = term_str(pred_term)
        if pred_str in permitted:
            continue
        msg = (
            f"Predicate {pred_str!r} on subject {uri_str!r} is not mapped on "
            f"{model_cls.__name__}."
        )
        if strict:
            raise ValueError(msg)
        # Not strict, so the guard above guarantees warn mode is active.
        warnings.warn(msg, stacklevel=3)
def _handle_forward_inverse_conflict(
    field_name: str,
    forward_predicate: str,
    inverse_predicate: str,
    uri: str,
    on_duplicate: OnDuplicate,
) -> None:
    """Report that a field has values via both its forward and inverse predicate.

    ``"error"`` raises :class:`ValueError`, ``"warn"`` emits a warning, and
    every other policy value is silent.
    """
    if on_duplicate not in ("error", "warn"):
        return
    text = (
        f"Both forward predicate {forward_predicate!r} and inverse "
        f"{inverse_predicate!r} have values for field {field_name!r} "
        f"(subject {uri!r}); using forward objects only."
    )
    if on_duplicate == "error":
        raise ValueError(text)
    warnings.warn(text, stacklevel=3)
def _union_conversion_order(
    term: OxTerm, members: tuple[type, ...]
) -> tuple[type, ...]:
    """Reorder union ``members`` so the one matching the literal datatype is first.

    ``int`` is moved to the front for ``XSD.integer`` literals and ``str`` for
    ``XSD.string`` (or a ``None`` datatype). Non-literal terms, empty
    ``members`` and all other datatypes return ``members`` unchanged.
    """
    if not isinstance(term, RdfLiteral) or not members:
        return members
    preferred: type | None = None
    if term.datatype == XSD.integer and int in members:
        preferred = int
    elif term.datatype in (XSD.string, None) and str in members:
        preferred = str
    if preferred is None:
        return members
    return (preferred, *(m for m in members if m is not preferred))
def _term_to_field(
    term: OxTerm,
    py_type: type | None,
    field_name: str,
    predicate: str,
    uri: str,
    *,
    field_info: FieldInfo | None = None,
    registry: LiteralRegistry = default_registry,
) -> ModelFieldScalar:
    """Convert one RDF term to the Python value expected by a model field.

    When ``field_info`` describes a union, each member type is tried in the
    order given by ``_union_conversion_order``; otherwise ``py_type`` (which
    may be ``None``) is the only candidate. The first successful conversion
    is returned.

    Raises:
        ValueError: if every candidate fails with ``ValueError`` or
            ``TypeError``; the last failure is chained as the cause.
    """
    union_members = () if field_info is None else union_member_types(field_info)
    candidates: tuple[type | None, ...]
    if union_members:
        candidates = _union_conversion_order(term, union_members)
    else:
        # Covers both a concrete ``py_type`` and the untyped ``None`` case.
        candidates = (py_type,)
    failure: Exception | None = None
    for candidate in candidates:
        try:
            converted = term_to_python(term, candidate, registry=registry)
        except (ValueError, TypeError) as exc:
            failure = exc
        else:
            return cast(ModelFieldScalar, converted)
    context = (
        f"Cannot convert object for field {field_name!r} "
        f"(predicate {predicate!r}, subject {uri!r})"
    )
    raise ValueError(f"{context}: {failure}") from failure
def _import_ref_resource(
    graph: Graph,
    term: OxTerm,
    nested_type: type[BaseModel],
    field_name: str,
    predicate: str,
    uri: str,
    *,
    on_duplicate: OnDuplicate,
    registry: LiteralRegistry,
    de_skolemize: bool,
) -> BaseModel:
    """Hydrate the resource that ``term`` points at as a ``nested_type`` instance.

    Raises:
        ValueError: if ``term`` is not a named (URI) resource.
    """
    resource = cast(Node, term)
    if is_named(resource):
        return graph_to_model(
            graph,
            nested_type,
            resource,
            on_duplicate=on_duplicate,
            registry=registry,
            de_skolemize=de_skolemize,
        )
    raise ValueError(
        f"Cannot import ref field {field_name!r} from term {term!r}; "
        "expected a URI resource."
    )
def _import_set_field_values(
    objects: list[OxTerm],
    py_type: type | None,
    field_name: str,
    predicate: str,
    uri: str,
    *,
    field_info: FieldInfo,
    registry: LiteralRegistry,
    on_duplicate: OnDuplicate,
) -> set[ModelFieldScalar]:
    """Convert ``objects`` into the value of a ``set`` field.

    For every element type except ``TypedLiteral`` the converted values are
    simply collected into a set. ``TypedLiteral`` elements are keyed by
    ``(value, datatype)``; a repeated key is reported through
    ``_handle_duplicate`` and only the first occurrence is kept.
    """
    shared = {"field_info": field_info, "registry": registry}
    if py_type is not TypedLiteral:
        return {
            _term_to_field(obj, py_type, field_name, predicate, uri, **shared)
            for obj in objects
        }
    kept: set[TypedLiteral] = set()
    seen_keys: set[tuple[str, str | None]] = set()
    for obj in objects:
        literal = cast(
            TypedLiteral,
            _term_to_field(obj, TypedLiteral, field_name, predicate, uri, **shared),
        )
        key = (literal.value, literal.datatype)
        if key not in seen_keys:
            seen_keys.add(key)
            kept.add(literal)
            continue
        _handle_duplicate(
            field_name,
            predicate,
            uri,
            2,
            on_duplicate,
            message=(
                f"Duplicate TypedLiteral ({literal.value!r}, datatype={literal.datatype!r}) "
                f"for field {field_name!r} (predicate {predicate!r}, subject {uri!r}); "
                "using the first only."
            ),
        )
    return cast(set[ModelFieldScalar], kept)
def import_multi_lang_field(
    objects: list[OxTerm],
    field_name: str,
    predicate: str,
    uri: str,
    *,
    on_duplicate: OnDuplicate,
) -> MultiLangString:
    """Collect language-tagged literals into a :class:`~triplemodel.terms.lang.MultiLangString`.

    Terms that are not literals, or literals without a language tag, are
    skipped. Tags are normalized with ``normalize_lang_tag``; when a tag
    occurs more than once the first literal wins and the repeat is reported
    through ``_handle_duplicate``.
    """
    collected: dict[str, LangString] = {}
    for literal in objects:
        if not isinstance(literal, RdfLiteral):
            continue
        tag_raw = literal.language
        if tag_raw is None:
            continue
        tag = normalize_lang_tag(tag_raw)
        entry = LangString(
            str(literal.value),
            tag,
            _base_direction_name(literal.direction),
        )
        if tag not in collected:
            collected[tag] = entry
            continue
        _handle_duplicate(
            field_name,
            predicate,
            uri,
            2,
            on_duplicate,
            message=(
                f"Conflicting values for language {tag!r} on field "
                f"{field_name!r} (predicate {predicate!r}, subject {uri!r}); "
                "using the first only."
            ),
        )
    return MultiLangString(collected)
[docs]
def _import_ref_collection(
    graph: Graph,
    objects: list[OxTerm],
    ref_cls: type[BaseModel],
    field_name: str,
    predicate: str,
    uri: str,
    *,
    on_duplicate: OnDuplicate,
    registry: LiteralRegistry,
    de_skolemize: bool,
) -> list[ModelFieldScalar]:
    """Hydrate every object of a ref collection field as a ``ref_cls`` instance."""
    return [
        cast(
            ModelFieldScalar,
            _import_ref_resource(
                graph,
                o,
                ref_cls,
                field_name,
                predicate,
                uri,
                on_duplicate=on_duplicate,
                registry=registry,
                de_skolemize=de_skolemize,
            ),
        )
        for o in objects
    ]


def import_field_value(
    graph: Graph,
    objects: list[OxTerm],
    field_info: FieldInfo,
    field_name: str,
    predicate: str,
    uri: str,
    *,
    embed: EmbedMode,
    on_duplicate: OnDuplicate,
    registry: LiteralRegistry = default_registry,
    de_skolemize: bool = False,
) -> ModelFieldValue:
    """Hydrate a single model field from RDF objects (used by :func:`graph_to_model`).

    The result depends on the cardinality reported by ``field_cardinality``:

    * no objects: ``[]`` for ``"list"``, ``set()`` for ``"set"``, else ``None``;
    * ``"nested"`` / ``"ref"`` with a nested model type: the first object is
      hydrated as that model (refs must be named resources);
    * ``"list"``: a list of referenced models when the element type is a
      model, otherwise the RDF list that starts at the first object;
    * ``"set"``: a set of referenced models (a list if they are unhashable),
      the transitive closure for transitive fields, or the converted objects;
    * anything else: a ``MultiLangString`` for that field type, otherwise the
      first object converted to the field's scalar type.

    Wherever only the first object is used, surplus objects are reported
    according to ``on_duplicate``.

    Raises:
        ValueError: if a nested or ref object is not a resource node, if a
            ref object is not a named resource, if a term cannot be
            converted, or if ``on_duplicate`` is ``"error"`` and surplus
            objects are present.
    """
    card = field_cardinality(field_info)
    nested_cls = nested_model_type(field_info)
    if not objects:
        if card == "list":
            return []
        if card == "set":
            return set()
        return None
    if card in ("nested", "ref") and nested_cls is not None:
        # _handle_duplicate is already silent for "ignore", so no pre-check.
        if len(objects) > 1:
            _handle_duplicate(field_name, predicate, uri, len(objects), on_duplicate)
        term = objects[0]
        if not isinstance(term, Node):
            raise ValueError(f"Cannot import nested field {field_name!r} from {term!r}")
        nested_type = cast(type[BaseModel], nested_cls)
        if card == "ref":
            # Same helper the collection branches use, so the named-resource
            # check and its error message live in one place.
            return _import_ref_resource(
                graph,
                term,
                nested_type,
                field_name,
                predicate,
                uri,
                on_duplicate=on_duplicate,
                registry=registry,
                de_skolemize=de_skolemize,
            )
        return import_nested_value(
            graph,
            term,
            nested_type,
            embed=embed,
            on_duplicate=on_duplicate,
            registry=registry,
            de_skolemize=de_skolemize,
        )
    if card == "list":
        ref_cls = ref_collection_element_type(field_info)
        if ref_cls is not None:
            return cast(
                ModelFieldValue,
                _import_ref_collection(
                    graph,
                    objects,
                    cast(type[BaseModel], ref_cls),
                    field_name,
                    predicate,
                    uri,
                    on_duplicate=on_duplicate,
                    registry=registry,
                    de_skolemize=de_skolemize,
                ),
            )
        # A non-ref list is read from the first object only, so any further
        # objects are surplus.
        if len(objects) > 1:
            _handle_duplicate(field_name, predicate, uri, len(objects), on_duplicate)
        py_type = scalar_python_type(field_info)
        return read_rdf_list(graph, objects[0], py_type, registry=registry)
    if card == "set":
        ref_cls = ref_collection_element_type(field_info)
        if ref_cls is not None:
            loaded = _import_ref_collection(
                graph,
                objects,
                cast(type[BaseModel], ref_cls),
                field_name,
                predicate,
                uri,
                on_duplicate=on_duplicate,
                registry=registry,
                de_skolemize=de_skolemize,
            )
            try:
                return set(loaded)
            except TypeError:
                # Unhashable model instances cannot go in a set; hand back
                # the list instead.
                return loaded
        py_type = scalar_python_type(field_info)
        if transitive_for_field(field_info):
            pred_uri = predicate_for_field(field_info) or predicate
            object_uris = transitive_objects(graph, uri, pred_uri)
            return {
                _term_to_field(
                    NamedNode(o),
                    py_type,
                    field_name,
                    pred_uri,
                    uri,
                    field_info=field_info,
                    registry=registry,
                )
                for o in object_uris
            }
        return _import_set_field_values(
            objects,
            py_type,
            field_name,
            predicate,
            uri,
            field_info=field_info,
            registry=registry,
            on_duplicate=on_duplicate,
        )
    py_type = scalar_python_type(field_info)
    if py_type is MultiLangString:
        return import_multi_lang_field(
            objects,
            field_name,
            predicate,
            uri,
            on_duplicate=on_duplicate,
        )
    if len(objects) > 1:
        _handle_duplicate(field_name, predicate, uri, len(objects), on_duplicate)
    return _term_to_field(
        objects[0],
        py_type,
        field_name,
        predicate,
        uri,
        field_info=field_info,
        registry=registry,
    )
[docs]
def graph_to_model(
    graph: Graph,
    model_cls: type[T],
    uri: str | Node,
    *,
    config: RdfConfig | None = None,
    validate_type: bool = True,
    on_duplicate: OnDuplicate = "warn",
    resolver: PredicateResolverProtocol | None = None,
    registry: LiteralRegistry = default_registry,
    de_skolemize: bool | None = None,
    strict_import: bool | None = None,
    warn_unmapped_fields: bool | None = None,
) -> T:
    """Build one ``model_cls`` instance from the triples about ``uri``.

    Steps, in order: optionally de-skolemize the graph, check the subject's
    type (``rdf:type`` or the configured instance-of predicates) when
    ``validate_type`` is set, derive the id field from the subject URI,
    hydrate each mapped field from its forward objects (falling back to the
    inverse predicate's subjects), optionally check for unmapped predicates,
    and finally validate the collected data with Pydantic.

    Raises:
        ValueError: if the type check fails, a field cannot be imported, an
            unmapped predicate is found in strict mode, or Pydantic
            validation fails (the ``ValidationError`` is chained).
    """
    cfg = config or get_rdf_config(model_cls)
    from triplemodel.io.skolem import apply_de_skolemize

    should_de_skolemize = (
        de_skolemize if de_skolemize is not None else cfg.skolemize_import
    )
    graph = apply_de_skolemize(graph, de_skolemize=should_de_skolemize)
    active_resolver = resolver or default_resolver
    prefixes = cfg.prefixes_dict
    if isinstance(uri, Node):
        subject: Node = uri
    else:
        subject = NamedNode(normalize_iri(uri))
    uri_str = term_str(subject)

    if validate_type and cfg.type_uri:
        expected_type = NamedNode(cfg.type_uri)
        if (subject, NamedNode(RDF_TYPE), expected_type) not in graph:
            raise ValueError(
                f"Subject {uri_str!r} does not have rdf:type {cfg.type_uri!r} required by "
                f"{model_cls.__name__}."
            )
    elif validate_type and cfg.instance_of_predicates:
        type_uris = cfg.instance_type_uris
        matched = False
        for pred_raw in cfg.instance_of_predicates:
            instance_pred = NamedNode(resolve_predicate(pred_raw, prefixes))
            if type_uris:
                # The subject must link to one of the configured types.
                matched = any(
                    (
                        subject,
                        instance_pred,
                        NamedNode(resolve_predicate(type_raw, prefixes)),
                    )
                    in graph
                    for type_raw in type_uris
                )
            else:
                # No types configured: any object for the predicate counts.
                matched = bool(list(graph.objects(subject, instance_pred)))
            if matched:
                break
        if not matched:
            raise ValueError(
                f"Subject {uri_str!r} does not match instance_of constraints for "
                f"{model_cls.__name__}."
            )

    data: ModelInitData = {}
    id_field = cfg.id_field
    if id_field:
        extracted = None
        if cfg.namespace:
            extracted = id_from_subject_uri(cfg.namespace, uri_str)
        if extracted is not None:
            data[id_field] = extracted
        elif id_field_is_iri_id(model_cls, id_field):
            data[id_field] = uri_str

    field_predicates = predicate_map_for_class(
        model_cls, resolver=active_resolver, config=cfg
    )
    for name, predicate in field_predicates.items():
        if predicate is None:
            continue
        field_info = model_cls.model_fields[name]
        raise_if_nested_collection(field_info)
        raise_if_inverse_collection(field_info)
        # Sorted so that "use the first object" is deterministic.
        forward_objects = sorted(
            graph.objects(subject, NamedNode(predicate)),
            key=term_str,
        )
        inverse_objects: list[Node] = []
        inv_predicate: str | None = None
        inv_raw = inverse_for_field(field_info)
        if inv_raw is not None:
            inv_predicate = resolve_predicate(inv_raw, prefixes)
            inverse_objects = cast(
                list[Node],
                sorted(
                    graph.subjects(NamedNode(inv_predicate), subject),
                    key=str,
                ),
            )
        both_present = (
            forward_objects and inverse_objects and inv_predicate is not None
        )
        if both_present and on_duplicate != "ignore":
            _handle_forward_inverse_conflict(
                name,
                predicate,
                inv_predicate,
                uri_str,
                on_duplicate,
            )
        # Forward objects take precedence over inverse subjects.
        objects = forward_objects or inverse_objects
        if not objects:
            continue
        data[name] = import_field_value(
            graph,
            cast(list[OxTerm], objects),
            field_info,
            name,
            predicate,
            uri_str,
            embed=cfg.embed,
            on_duplicate=on_duplicate,
            registry=registry,
            # The graph was already de-skolemized above.
            de_skolemize=False,
        )

    _enforce_subject_predicates(
        graph,
        subject,
        uri_str,
        model_cls,
        cfg,
        resolver=active_resolver,
        strict_import=strict_import,
        warn_unmapped=warn_unmapped_fields,
    )
    try:
        return model_cls.model_validate(data)
    except ValidationError as exc:
        raise ValueError(
            f"Cannot validate {model_cls.__name__} from graph for subject {uri_str!r}: {exc}"
        ) from exc
def _subject_uris_for_model(
    graph: Graph,
    model_cls: type[T],
    cfg: RdfConfig,
    *,
    type_uri: str | None = None,
    resolver: PredicateResolverProtocol | None = None,
) -> list[str]:
    """List the subject URIs in ``graph`` that should be loaded as ``model_cls``.

    An explicit ``type_uri`` (or else ``cfg.type_uri``) selects the named-node
    subjects carrying that ``rdf:type``, sorted by URI string. Without a type,
    discovery is delegated to ``discover_subjects_by_instance_of`` when
    instance-of predicates are configured, and to ``discover_subject_uris``
    otherwise.
    """
    from triplemodel.io.discovery import (
        discover_subject_uris,
        discover_subjects_by_instance_of,
    )

    rdf_type = cfg.type_uri if type_uri is None else type_uri
    if rdf_type:
        typed_subjects = [
            term_str(s)
            for s in graph.subjects(NamedNode(RDF_TYPE), NamedNode(rdf_type))
            if isinstance(s, NamedNode)
        ]
        typed_subjects.sort()
        return typed_subjects
    if cfg.instance_of_predicates:
        return discover_subjects_by_instance_of(graph, cfg)
    return discover_subject_uris(graph, model_cls, cfg, resolver=resolver)
[docs]
def iter_graph_to_models(
    graph: Graph,
    model_cls: type[T],
    *,
    chunk_size: int = 500,
    type_uri: str | None = None,
    config: RdfConfig | None = None,
    validate_type: bool = True,
    on_duplicate: OnDuplicate = "warn",
    resolver: PredicateResolverProtocol | None = None,
    registry: LiteralRegistry = default_registry,
    de_skolemize: bool | None = None,
    strict_import: bool | None = None,
    warn_unmapped_fields: bool | None = None,
) -> Iterator[list[T]]:
    """Yield lists of at most ``chunk_size`` models loaded from ``graph``.

    Argument validation, de-skolemization and subject discovery all happen
    when this function is called; only the per-subject hydration is deferred
    to iteration of the returned generator.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size!r}.")
    cfg = config or get_rdf_config(model_cls)
    from triplemodel.io.skolem import apply_de_skolemize

    should_de_skolemize = (
        de_skolemize if de_skolemize is not None else cfg.skolemize_import
    )
    graph = apply_de_skolemize(graph, de_skolemize=should_de_skolemize)
    uris = _subject_uris_for_model(
        graph, model_cls, cfg, type_uri=type_uri, resolver=resolver
    )
    # Options shared by every per-subject load. The graph was de-skolemized
    # above, so the nested calls must not repeat it.
    load_options = {
        "config": cfg,
        "validate_type": validate_type,
        "on_duplicate": on_duplicate,
        "resolver": resolver,
        "registry": registry,
        "de_skolemize": False,
        "strict_import": strict_import,
        "warn_unmapped_fields": warn_unmapped_fields,
    }

    def _gen() -> Iterator[list[T]]:
        for offset in range(0, len(uris), chunk_size):
            batch: list[T] = []
            for subject_uri in uris[offset : offset + chunk_size]:
                batch.append(
                    graph_to_model(graph, model_cls, subject_uri, **load_options)
                )
            yield batch

    return _gen()
[docs]
def graph_to_models(
    graph: Graph,
    model_cls: type[T],
    *,
    chunk_size: int | None = None,
    type_uri: str | None = None,
    config: RdfConfig | None = None,
    validate_type: bool = True,
    on_duplicate: OnDuplicate = "warn",
    resolver: PredicateResolverProtocol | None = None,
    registry: LiteralRegistry = default_registry,
    de_skolemize: bool | None = None,
    strict_import: bool | None = None,
    warn_unmapped_fields: bool | None = None,
) -> list[T]:
    """Load every matching resource in ``graph`` as a ``model_cls`` instance.

    This drains :func:`iter_graph_to_models` into a single list. When a type
    URI is in effect (``type_uri`` or ``cfg.type_uri``) the result is sorted
    by ``cfg.subject_uri``; otherwise discovery order is kept.
    """
    cfg = config or get_rdf_config(model_cls)
    # With no explicit chunk size, use one very large chunk.
    effective_chunk = 2**31 if chunk_size is None else chunk_size
    chunks = iter_graph_to_models(
        graph,
        model_cls,
        chunk_size=effective_chunk,
        type_uri=type_uri,
        config=cfg,
        validate_type=validate_type,
        on_duplicate=on_duplicate,
        resolver=resolver,
        registry=registry,
        de_skolemize=de_skolemize,
        strict_import=strict_import,
        warn_unmapped_fields=warn_unmapped_fields,
    )
    instances: list[T] = [model for chunk in chunks for model in chunk]
    if cfg.type_uri or type_uri:
        instances.sort(key=cfg.subject_uri)
    return instances