
Datasets

The Dataset classes define the data model behind Open Targets Gentropy. Every dataset class inherits from the Dataset base class and wraps a Spark DataFrame with a predefined schema; the schema of each dataset is documented in its respective class.

gentropy.dataset.dataset.Dataset dataclass

Bases: ABC

Open Targets Gentropy Dataset.

Dataset is a wrapper around a Spark DataFrame with a predefined schema. Schemas for each child dataset are described in the schemas module.
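
For illustration only, the sketch below defines a minimal hypothetical child dataset and constructs it. The MyStudyIndex class and its studyId field are invented for this example and are not part of the real Gentropy schemas; only the import path of Dataset is taken from the source location shown below.

from dataclasses import dataclass

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

from gentropy.dataset.dataset import Dataset


@dataclass
class MyStudyIndex(Dataset):
    """Hypothetical child dataset with a single required studyId column."""

    @classmethod
    def get_schema(cls) -> StructType:
        # The expected schema every MyStudyIndex instance is validated against
        return StructType([StructField("studyId", StringType(), nullable=False)])


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("GCST000001",)], schema=MyStudyIndex.get_schema())

# __post_init__ calls validate_schema(), so a schema mismatch fails at construction time
studies = MyStudyIndex(_df=df, _schema=MyStudyIndex.get_schema())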

Source code in src/gentropy/dataset/dataset.py
@dataclass
class Dataset(ABC):
    """Open Targets Gentropy Dataset.

    `Dataset` is a wrapper around a Spark DataFrame with a predefined schema. Schemas for each child dataset are described in the `schemas` module.
    """

    _df: DataFrame
    _schema: StructType

    def __post_init__(self: Dataset) -> None:
        """Post init."""
        self.validate_schema()

    @property
    def df(self: Dataset) -> DataFrame:
        """Dataframe included in the Dataset.

        Returns:
            DataFrame: Dataframe included in the Dataset
        """
        return self._df

    @df.setter
    def df(self: Dataset, new_df: DataFrame) -> None:  # noqa: CCE001
        """Dataframe setter.

        Args:
            new_df (DataFrame): New dataframe to be included in the Dataset
        """
        self._df: DataFrame = new_df
        self.validate_schema()

    @property
    def schema(self: Dataset) -> StructType:
        """Dataframe expected schema.

        Returns:
            StructType: Dataframe expected schema
        """
        return self._schema

    @classmethod
    @abstractmethod
    def get_schema(cls: type[Self]) -> StructType:
        """Abstract method to get the schema. Must be implemented by child classes.

        Returns:
            StructType: Schema for the Dataset
        """
        pass

    @classmethod
    def from_parquet(
        cls: type[Self],
        session: Session,
        path: str | list[str],
        **kwargs: bool | float | int | str | None,
    ) -> Self:
        """Reads parquet into a Dataset with a given schema.

        Args:
            session (Session): Spark session
            path (str | list[str]): Path to the parquet dataset
            **kwargs (bool | float | int | str | None): Additional arguments to pass to spark.read.parquet

        Returns:
            Self: Dataset with the parquet file contents

        Raises:
            ValueError: Parquet file is empty
        """
        schema = cls.get_schema()
        df = session.load_data(path, format="parquet", schema=schema, **kwargs)
        if df.isEmpty():
            raise ValueError(f"Parquet file is empty: {path}")
        return cls(_df=df, _schema=schema)

    def filter(self: Self, condition: Column) -> Self:
        """Creates a new instance of a Dataset with the DataFrame filtered by the condition.

        Args:
            condition (Column): Condition to filter the DataFrame

        Returns:
            Self: Filtered Dataset
        """
        df = self._df.filter(condition)
        class_constructor = self.__class__
        return class_constructor(_df=df, _schema=class_constructor.get_schema())

    def validate_schema(self: Dataset) -> None:
        """Validate DataFrame schema against expected class schema.

        Raises:
            ValueError: DataFrame schema is not valid
        """
        expected_schema = self._schema
        expected_fields = flatten_schema(expected_schema)
        observed_schema = self._df.schema
        observed_fields = flatten_schema(observed_schema)

        # Unexpected fields in dataset
        if unexpected_field_names := [
            x.name
            for x in observed_fields
            if x.name not in [y.name for y in expected_fields]
        ]:
            raise ValueError(
                f"The {unexpected_field_names} fields are not included in DataFrame schema: {expected_fields}"
            )

        # Required fields not in dataset
        required_fields = [x.name for x in expected_schema if not x.nullable]
        if missing_required_fields := [
            req
            for req in required_fields
            if not any(field.name == req for field in observed_fields)
        ]:
            raise ValueError(
                f"The {missing_required_fields} fields are required but missing: {required_fields}"
            )

        # Fields with duplicated names
        if duplicated_fields := [
            x for x in set(observed_fields) if observed_fields.count(x) > 1
        ]:
            raise ValueError(
                f"The following fields are duplicated in DataFrame schema: {duplicated_fields}"
            )

        # Fields with different datatype
        observed_field_types = {
            field.name: type(field.dataType) for field in observed_fields
        }
        expected_field_types = {
            field.name: type(field.dataType) for field in expected_fields
        }
        if fields_with_different_observed_datatype := [
            name
            for name, observed_type in observed_field_types.items()
            if name in expected_field_types
            and observed_type != expected_field_types[name]
        ]:
            raise ValueError(
                f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
            )

    def drop_infinity_values(self: Self, *cols: str) -> Self:
        """Drop infinity values from Double typed column.

        Infinity type reference - https://spark.apache.org/docs/latest/sql-ref-datatypes.html#floating-point-special-values
        The implementation comes from https://stackoverflow.com/questions/34432998/how-to-replace-infinity-in-pyspark-dataframe

        Args:
            *cols (str): names of the columns to check for infinite values, these should be of DoubleType only!

        Returns:
            Self: Dataset after removing infinite values
        """
        if len(cols) == 0:
            return self
        inf_strings = ("Inf", "+Inf", "-Inf", "Infinity", "+Infinity", "-Infinity")
        inf_values = [f.lit(v).cast(DoubleType()) for v in inf_strings]
        conditions = [f.col(c).isin(inf_values) for c in cols]
        # reduce individual filter expressions with or statement
        # to col("beta").isin([lit(Inf)]) | col("beta").isin([lit(Inf)])...
        condition = reduce(lambda a, b: a | b, conditions)
        self.df = self._df.filter(~condition)
        return self

    def persist(self: Self) -> Self:
        """Persist in memory the DataFrame included in the Dataset.

        Returns:
            Self: Persisted Dataset
        """
        self.df = self._df.persist()
        return self

    def unpersist(self: Self) -> Self:
        """Remove the persisted DataFrame from memory.

        Returns:
            Self: Unpersisted Dataset
        """
        self.df = self._df.unpersist()
        return self

    def coalesce(self: Self, num_partitions: int, **kwargs: Any) -> Self:
        """Coalesce the DataFrame included in the Dataset.

        Coalescing is efficient for decreasing the number of partitions because it avoids a full shuffle of the data.

        Args:
            num_partitions (int): Number of partitions to coalesce to
            **kwargs (Any): Arguments to pass to the coalesce method

        Returns:
            Self: Coalesced Dataset
        """
        self.df = self._df.coalesce(num_partitions, **kwargs)
        return self

    def repartition(self: Self, num_partitions: int, **kwargs: Any) -> Self:
        """Repartition the DataFrame included in the Dataset.

        Repartitioning creates new partitions with data that is distributed evenly.

        Args:
            num_partitions (int): Number of partitions to repartition to
            **kwargs (Any): Arguments to pass to the repartition method

        Returns:
            Self: Repartitioned Dataset
        """
        self.df = self._df.repartition(num_partitions, **kwargs)
        return self

    @staticmethod
    def update_quality_flag(
        qc: Column, flag_condition: Column, flag_text: Enum
    ) -> Column:
        """Update the provided quality control list with a new flag if condition is met.

        Args:
            qc (Column): Array column with the current list of qc flags.
            flag_condition (Column): This is a column of booleans, signing which row should be flagged
            flag_text (Enum): Text for the new quality control flag

        Returns:
            Column: Array column with the updated list of qc flags.
        """
        qc = f.when(qc.isNull(), f.array()).otherwise(qc)
        return f.when(
            flag_condition,
            f.array_union(qc, f.array(f.lit(flag_text.value))),
        ).otherwise(qc)

df: DataFrame property writable

Dataframe included in the Dataset.

Returns:

    DataFrame: Dataframe included in the Dataset

schema: StructType property

Dataframe expected schema.

Returns:

    StructType: Dataframe expected schema

coalesce(num_partitions: int, **kwargs: Any) -> Self

Coalesce the DataFrame included in the Dataset.

Coalescing is efficient for decreasing the number of partitions because it avoids a full shuffle of the data.

Parameters:

    num_partitions (int): Number of partitions to coalesce to (required)
    **kwargs (Any): Arguments to pass to the coalesce method (default: {})

Returns:

    Self: Coalesced Dataset

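Continuing the hypothetical MyStudyIndex sketch from the class example above, a typical use is collapsing a small result onto fewer partitions before writing it out; note that the method reassigns self.df and returns the same Dataset instance. The output path is illustrative.

# Collapse to a single partition before writing a small dataset
studies.coalesce(1).df.write.mode("overwrite").parquet("/tmp/my_study_index")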

drop_infinity_values(*cols: str) -> Self

Drop infinity values from Double typed column.

Infinity type reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html#floating-point-special-values
The implementation comes from https://stackoverflow.com/questions/34432998/how-to-replace-infinity-in-pyspark-dataframe

Parameters:

    *cols (str): Names of the columns to check for infinite values; these should be of DoubleType only (default: ())

Returns:

    Self: Dataset after removing infinite values

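A self-contained sketch with an invented MyBetas dataset whose beta column is DoubleType, as the method expects; the rows carrying infinite values are dropped in place.

from dataclasses import dataclass

from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType, StructField, StructType

from gentropy.dataset.dataset import Dataset


@dataclass
class MyBetas(Dataset):
    """Hypothetical dataset with a single nullable DoubleType column."""

    @classmethod
    def get_schema(cls) -> StructType:
        return StructType([StructField("beta", DoubleType(), nullable=True)])


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(0.5,), (float("inf"),), (float("-inf"),)], schema=MyBetas.get_schema()
)
betas = MyBetas(_df=df, _schema=MyBetas.get_schema())

# Rows with +/-Infinity in "beta" are removed; the wrapped DataFrame is updated in place
betas = betas.drop_infinity_values("beta")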

filter(condition: Column) -> Self

Creates a new instance of a Dataset with the DataFrame filtered by the condition.

Parameters:

    condition (Column): Condition to filter the DataFrame (required)

Returns:

    Self: Filtered Dataset

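Continuing the MyStudyIndex sketch above: unlike the partitioning helpers, filter() builds a brand-new Dataset instance from the filtered DataFrame and the class schema, leaving the original untouched.

from pyspark.sql import functions as f

# Returns a new MyStudyIndex; studies itself is unchanged
selected = studies.filter(f.col("studyId") == "GCST000001")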

from_parquet(session: Session, path: str | list[str], **kwargs: bool | float | int | str | None) -> Self classmethod

Reads parquet into a Dataset with a given schema.

Parameters:

    session (Session): Spark session (required)
    path (str | list[str]): Path to the parquet dataset (required)
    **kwargs (bool | float | int | str | None): Additional arguments to pass to spark.read.parquet (default: {})

Returns:

    Self: Dataset with the parquet file contents

Raises:

    ValueError: Parquet file is empty

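A usage sketch: the StudyLocus class and the Session import path are taken from the wider Gentropy package, so treat the exact names and the default Session constructor as assumptions, and the bucket path as purely illustrative.

from gentropy.common.session import Session
from gentropy.dataset.study_locus import StudyLocus

session = Session()  # assumed to start a default local Spark session

# Reads the parquet files with StudyLocus.get_schema() and raises ValueError if the source is empty
study_locus = StudyLocus.from_parquet(session, "gs://my-bucket/study_locus/")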

get_schema() -> StructType abstractmethod classmethod

Abstract method to get the schema. Must be implemented by child classes.

Returns:

    StructType: Schema for the Dataset


persist() -> Self

Persist in memory the DataFrame included in the Dataset.

Returns:

    Self: Persisted Dataset

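Continuing the MyStudyIndex sketch above, persisting pays off when the same Dataset feeds several actions; unpersist() (documented further down) frees the cached blocks afterwards.

studies.persist()
n_total = studies.df.count()                # first action materialises the cache
n_distinct = studies.df.distinct().count()  # reuses the cached DataFrame
studies.unpersist()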

repartition(num_partitions: int, **kwargs: Any) -> Self

Repartition the DataFrame included in the Dataset.

Repartitioning creates new partitions with data that is distributed evenly.

Parameters:

    num_partitions (int): Number of partitions to repartition to (required)
    **kwargs (Any): Arguments to pass to the repartition method (default: {})

Returns:

    Self: Repartitioned Dataset

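Continuing the same sketch, repartition() is the counterpart for growing the number of partitions or rebalancing skewed data at the cost of a full shuffle; coalesce() above is preferable when only shrinking the partition count.

# Spread the data evenly over 200 partitions before a wide transformation
studies = studies.repartition(200)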

unpersist() -> Self

Remove the persisted DataFrame from memory.

Returns:

    Self: Unpersisted Dataset


update_quality_flag(qc: Column, flag_condition: Column, flag_text: Enum) -> Column staticmethod

Update the provided quality control list with a new flag if condition is met.

Parameters:

    qc (Column): Array column with the current list of qc flags (required)
    flag_condition (Column): Boolean column indicating which rows should be flagged (required)
    flag_text (Enum): Text for the new quality control flag (required)

Returns:

    Column: Array column with the updated list of qc flags

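A self-contained sketch using plain PySpark; the QCFlag enum and the pValue/qualityControls columns are invented for this example, whereas real datasets define their own quality-control enums and schemas.

from enum import Enum

from pyspark.sql import SparkSession
from pyspark.sql import functions as f

from gentropy.dataset.dataset import Dataset


class QCFlag(Enum):
    """Hypothetical quality-control flag."""

    SUBSIGNIFICANT = "Subsignificant p-value"


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1e-12, None), (1e-3, ["Existing flag"])],
    "pValue double, qualityControls array<string>",
)

# Appends the flag text only where the condition holds; null flag lists become empty arrays
df = df.withColumn(
    "qualityControls",
    Dataset.update_quality_flag(
        qc=f.col("qualityControls"),
        flag_condition=f.col("pValue") > 5e-8,
        flag_text=QCFlag.SUBSIGNIFICANT,
    ),
)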

validate_schema() -> None

Validate DataFrame schema against expected class schema.

Raises:

    ValueError: DataFrame schema is not valid

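Continuing the hypothetical MyStudyIndex sketch from the class example above, feeding a DataFrame with a column that the expected schema does not declare makes construction fail, since __post_init__ calls validate_schema().

bad_df = spark.createDataFrame(
    [("GCST000001", "surprise")], "studyId string, unexpectedColumn string"
)

# Raises ValueError: unexpectedColumn is not part of MyStudyIndex.get_schema()
MyStudyIndex(_df=bad_df, _schema=MyStudyIndex.get_schema())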