Skip to content

Dataset

Bases: ABC

Open Targets Genetics Dataset.

Dataset is a wrapper around a Spark DataFrame with a predefined schema. Schemas for each child dataset are described in the json.schemas module.

Source code in src/otg/dataset/dataset.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@dataclass
class Dataset(ABC):
    """Open Targets Genetics Dataset.

    `Dataset` is a wrapper around a Spark DataFrame with a predefined schema.
    Schemas for each child dataset are described in the `json.schemas` module.
    The wrapped DataFrame is validated against the expected schema both on
    construction and every time it is reassigned through the `df` setter.
    """

    # Wrapped Spark DataFrame; access through the `df` property.
    _df: DataFrame
    # Expected schema the DataFrame must conform to; access through `schema`.
    _schema: StructType

    def __post_init__(self: Dataset) -> None:
        """Validate the DataFrame against the expected schema after construction."""
        self.validate_schema()

    @property
    def df(self: Dataset) -> DataFrame:
        """Dataframe included in the Dataset."""
        return self._df

    @df.setter
    def df(self: Dataset, new_df: DataFrame) -> None:  # noqa: CCE001
        """Dataframe setter; re-validates the schema on every assignment."""
        self._df = new_df
        self.validate_schema()

    @property
    def schema(self: Dataset) -> StructType:
        """Dataframe expected schema."""
        return self._schema

    @classmethod
    @abstractmethod
    def get_schema(cls: type[Dataset]) -> StructType:
        """Abstract method to get the schema. Must be implemented by child classes."""

    @classmethod
    def from_parquet(
        cls: type[Dataset], session: Session, path: str, **kwargs: Any
    ) -> Dataset:
        """Reads a parquet file into a Dataset with a given schema.

        Args:
            session: Session wrapping the Spark session used to read the file.
            path: Path to the parquet file(s).
            **kwargs: Extra reader options forwarded to `session.read_parquet`.

        Returns:
            Dataset: Instance wrapping the loaded DataFrame; schema validation
                runs in `__post_init__`.
        """
        schema = cls.get_schema()
        df = session.read_parquet(path=path, schema=schema, **kwargs)
        return cls(_df=df, _schema=schema)

    def validate_schema(self: Dataset) -> None:  # sourcery skip: invert-any-all
        """Validate DataFrame schema against expected class schema.

        Raises:
            ValueError: If the DataFrame contains unexpected or duplicated
                fields, is missing required (non-nullable) fields, or a field's
                datatype differs from the expected one.
        """
        expected_schema = self._schema
        # flatten_schema (project helper) presumably flattens nested structs
        # into a list of StructFields — only .name/.dataType are used below.
        expected_fields = flatten_schema(expected_schema)
        observed_schema = self._df.schema
        observed_fields = flatten_schema(observed_schema)

        observed_field_names = [field.name for field in observed_fields]
        expected_field_names = [field.name for field in expected_fields]

        # Unexpected fields in dataset
        if unexpected_field_names := [
            name for name in observed_field_names if name not in expected_field_names
        ]:
            raise ValueError(
                f"The {unexpected_field_names} fields are not included in the expected schema: {expected_fields}"
            )

        # Required fields not in dataset.
        # NOTE(review): only top-level fields of the expected schema are
        # inspected for nullability here; nested required fields are not
        # enforced — confirm this is intentional.
        required_fields = [x.name for x in expected_schema if not x.nullable]
        if missing_required_fields := [
            req for req in required_fields if req not in observed_field_names
        ]:
            raise ValueError(
                f"The {missing_required_fields} fields are required but missing: {required_fields}"
            )

        # Fields with duplicated names. Compare names rather than whole
        # StructFields so that two fields sharing a name but differing in
        # type or nullability are still reported as duplicates.
        if duplicated_fields := [
            name
            for name in set(observed_field_names)
            if observed_field_names.count(name) > 1
        ]:
            raise ValueError(
                f"The following fields are duplicated in DataFrame schema: {duplicated_fields}"
            )

        # Fields with different datatype. Compares datatype classes only, so
        # e.g. nullability or nested metadata differences are not flagged here.
        observed_field_types = {
            field.name: type(field.dataType) for field in observed_fields
        }
        expected_field_types = {
            field.name: type(field.dataType) for field in expected_fields
        }
        if fields_with_different_observed_datatype := [
            name
            for name, observed_type in observed_field_types.items()
            if name in expected_field_types
            and observed_type != expected_field_types[name]
        ]:
            raise ValueError(
                f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
            )

    def persist(self: Dataset) -> Dataset:
        """Persist in memory the DataFrame included in the Dataset.

        Returns:
            Dataset: self, to allow method chaining.
        """
        self.df = self._df.persist()
        return self

    def unpersist(self: Dataset) -> Dataset:
        """Remove the persisted DataFrame from memory.

        Returns:
            Dataset: self, to allow method chaining.
        """
        self.df = self._df.unpersist()
        return self

df: DataFrame writable property

Dataframe included in the Dataset.

schema: StructType property

Dataframe expected schema.

__post_init__()

Post init.

Source code in src/otg/dataset/dataset.py
27
28
29
def __post_init__(self: Dataset) -> None:
    """Validate the wrapped DataFrame against the expected schema right after dataclass construction."""
    self.validate_schema()

from_parquet(session, path, **kwargs) classmethod

Reads a parquet file into a Dataset with a given schema.

Source code in src/otg/dataset/dataset.py
53
54
55
56
57
58
59
60
@classmethod
def from_parquet(
    cls: type[Dataset], session: Session, path: str, **kwargs: Dict[str, Any]
) -> Dataset:
    """Reads a parquet file into a Dataset with a given schema.

    The schema comes from the child class's `get_schema`; extra keyword
    arguments are forwarded to `session.read_parquet`. Schema validation
    happens in `__post_init__` when the instance is constructed.
    """
    schema = cls.get_schema()
    df = session.read_parquet(path=path, schema=schema, **kwargs)
    return cls(_df=df, _schema=schema)

get_schema() classmethod abstractmethod

Abstract method to get the schema. Must be implemented by child classes.

Source code in src/otg/dataset/dataset.py
47
48
49
50
51
@classmethod
@abstractmethod
def get_schema(cls: type[Dataset]) -> StructType:
    """Abstract method to get the schema. Must be implemented by child classes."""
    pass

persist()

Persist in memory the DataFrame included in the Dataset.

Source code in src/otg/dataset/dataset.py
119
120
121
122
def persist(self: Dataset) -> Dataset:
    """Persist in memory the DataFrame included in the Dataset.

    Returns self so calls can be chained. Assigning through `self.df`
    triggers schema re-validation in the setter.
    """
    self.df = self._df.persist()
    return self

unpersist()

Remove the persisted DataFrame from memory.

Source code in src/otg/dataset/dataset.py
124
125
126
127
def unpersist(self: Dataset) -> Dataset:
    """Remove the persisted DataFrame from memory.

    Returns self so calls can be chained. Assigning through `self.df`
    triggers schema re-validation in the setter.
    """
    self.df = self._df.unpersist()
    return self

validate_schema()

Validate DataFrame schema against expected class schema.

Raises:

Type Description
ValueError

DataFrame schema is not valid

Source code in src/otg/dataset/dataset.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def validate_schema(self: Dataset) -> None:  # sourcery skip: invert-any-all
    """Validate DataFrame schema against expected class schema.

    Runs four checks in order, raising on the first failure: unexpected
    fields, missing required fields, duplicated fields, and datatype
    mismatches.

    Raises:
        ValueError: DataFrame schema is not valid
    """
    expected_schema = self._schema
    # flatten_schema (project helper) presumably flattens nested structs into
    # a flat list of StructFields — only .name/.dataType are used below.
    expected_fields = flatten_schema(expected_schema)
    observed_schema = self._df.schema
    observed_fields = flatten_schema(observed_schema)

    # Unexpected fields in dataset
    # NOTE(review): the message reads "not included in DataFrame schema" but
    # these names ARE in the DataFrame — they are absent from the expected
    # schema; the wording appears inverted.
    if unexpected_field_names := [
        x.name
        for x in observed_fields
        if x.name not in [y.name for y in expected_fields]
    ]:
        raise ValueError(
            f"The {unexpected_field_names} fields are not included in DataFrame schema: {expected_fields}"
        )

    # Required fields not in dataset
    # Only top-level fields of the expected schema are checked for
    # nullability; nested required fields are not enforced here.
    required_fields = [x.name for x in expected_schema if not x.nullable]
    if missing_required_fields := [
        req
        for req in required_fields
        if not any(field.name == req for field in observed_fields)
    ]:
        raise ValueError(
            f"The {missing_required_fields} fields are required but missing: {required_fields}"
        )

    # Fields with duplicated names
    # NOTE(review): this compares whole StructFields, so two fields sharing a
    # name but differing in type/nullability are NOT reported as duplicates.
    if duplicated_fields := [
        x for x in set(observed_fields) if observed_fields.count(x) > 1
    ]:
        raise ValueError(
            f"The following fields are duplicated in DataFrame schema: {duplicated_fields}"
        )

    # Fields with different datatype
    # Compares datatype classes only; nested metadata differences are not
    # flagged by this check.
    observed_field_types = {
        field.name: type(field.dataType) for field in observed_fields
    }
    expected_field_types = {
        field.name: type(field.dataType) for field in expected_fields
    }
    if fields_with_different_observed_datatype := [
        name
        for name, observed_type in observed_field_types.items()
        if name in expected_field_types
        and observed_type != expected_field_types[name]
    ]:
        raise ValueError(
            f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
        )