Skip to content

locus_to_gene

gentropy.l2g.LocusToGeneStep

Locus to gene step.

Source code in src/gentropy/l2g.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class LocusToGeneStep:
    """Locus to gene step.

    Trains an L2G classification model on gold-standard curation data, or
    applies a previously trained model to score credible sets, depending on
    ``run_mode``.
    """

    def __init__(
        self,
        session: Session,
        run_mode: str,
        model_path: str,
        predictions_path: str,
        credible_set_path: str,
        variant_gene_path: str,
        colocalisation_path: str,
        study_index_path: str,
        gold_standard_curation_path: str,
        gene_interactions_path: str,
        features_list: list[str],
        hyperparameters: dict[str, Any],
        wandb_run_name: str | None = None,
        perform_cross_validation: bool = False,
    ) -> None:
        """Run step.

        Args:
            session (Session): Session object.
            run_mode (str): One of "train" or "predict".
            model_path (str): Path to save the model.
            predictions_path (str): Path to save the predictions.
            credible_set_path (str): Path to credible set Parquet files.
            variant_gene_path (str): Path to variant to gene Parquet files.
            colocalisation_path (str): Path to colocalisation Parquet files.
            study_index_path (str): Path to study index Parquet files.
            gold_standard_curation_path (str): Path to gold standard curation JSON files.
            gene_interactions_path (str): Path to gene interactions Parquet files.
            features_list (list[str]): List of features to use.
            hyperparameters (dict[str, Any]): Hyperparameters for the model.
            wandb_run_name (str | None): Name of the run to be tracked on W&B.
            perform_cross_validation (bool): Whether to perform cross validation.

        Raises:
            ValueError: if run_mode is not one of "train" or "predict", if
                predict mode is missing model_path or predictions_path, or if
                train mode is missing gold_standard_curation_path or
                gene_interactions_path.
        """
        print("Sci-kit learn version: ", sklearn.__version__)  # noqa: T201
        if run_mode not in ["train", "predict"]:
            raise ValueError(
                f"run_mode must be one of 'train' or 'predict', got {run_mode}"
            )
        # Load common inputs
        credible_set = StudyLocus.from_parquet(
            session, credible_set_path, recursiveFileLookup=True
        )
        studies = StudyIndex.from_parquet(
            session, study_index_path, recursiveFileLookup=True
        )
        v2g = V2G.from_parquet(session, variant_gene_path)
        coloc = Colocalisation.from_parquet(
            session, colocalisation_path, recursiveFileLookup=True
        )

        if run_mode == "predict":
            if not model_path or not predictions_path:
                raise ValueError(
                    "model_path and predictions_path must be set for predict mode."
                )
            predictions = L2GPrediction.from_credible_set(
                model_path, list(features_list), credible_set, studies, v2g, coloc
            )
            predictions.df.write.mode(session.write_mode).parquet(predictions_path)
            session.logger.info(predictions_path)
        else:
            # run_mode == "train" — previously, missing curation/interaction
            # paths made the step a silent no-op; fail loudly instead, matching
            # the explicit validation done for predict mode above.
            if not gold_standard_curation_path or not gene_interactions_path:
                raise ValueError(
                    "gold_standard_curation_path and gene_interactions_path must be set for train mode."
                )
            # Process gold standard and L2G features
            gs_curation = session.spark.read.json(gold_standard_curation_path)
            interactions = session.spark.read.parquet(gene_interactions_path)
            study_locus_overlap = StudyLocus(
                # We just extract overlaps of associations in the gold standard. This parsing is a duplication of the one in the gold standard curation,
                # but we need to do it here to be able to parse gold standards later
                _df=credible_set.df.join(
                    f.broadcast(
                        gs_curation.select(
                            StudyLocus.assign_study_locus_id(
                                f.col("association_info.otg_id"),  # studyId
                                f.concat_ws(  # variantId
                                    "_",
                                    f.col("sentinel_variant.locus_GRCh38.chromosome"),
                                    f.col("sentinel_variant.locus_GRCh38.position"),
                                    f.col("sentinel_variant.alleles.reference"),
                                    f.col("sentinel_variant.alleles.alternative"),
                                ),
                            ).alias("studyLocusId"),
                        )
                    ),
                    "studyLocusId",
                    "inner",
                ),
                _schema=StudyLocus.get_schema(),
            ).find_overlaps(studies)

            gold_standards = L2GGoldStandard.from_otg_curation(
                gold_standard_curation=gs_curation,
                v2g=v2g,
                study_locus_overlap=study_locus_overlap,
                interactions=interactions,
            )

            fm = L2GFeatureMatrix.generate_features(
                features_list=features_list,
                credible_set=credible_set,
                study_index=studies,
                variant_gene=v2g,
                colocalisation=coloc,
            )

            data = (
                # Annotate gold standards with features
                L2GFeatureMatrix(
                    _df=fm.df.join(
                        f.broadcast(
                            gold_standards.df.drop("variantId", "studyId", "sources")
                        ),
                        on=["studyLocusId", "geneId"],
                        how="inner",
                    ),
                    _schema=L2GFeatureMatrix.get_schema(),
                )
                .fill_na()
                .select_features(list(features_list))
            )

            # Instantiate classifier
            estimator = SparkXGBClassifier(
                eval_metric="logloss",
                features_col="features",
                label_col="label",
                max_depth=5,
            )
            l2g_model = LocusToGeneModel(
                features_list=list(features_list), estimator=estimator
            )
            if perform_cross_validation:
                # Perform cross validation to extract what are the best hyperparameters
                cv_folds = hyperparameters.get("cross_validation_folds", 5)
                LocusToGeneTrainer.cross_validate(
                    l2g_model=l2g_model,
                    data=data,
                    num_folds=cv_folds,
                )
            else:
                # Train model
                LocusToGeneTrainer.train(
                    gold_standard_data=data,
                    l2g_model=l2g_model,
                    model_path=model_path,
                    evaluate=True,
                    wandb_run_name=wandb_run_name,
                    **hyperparameters,
                )
                session.logger.info(model_path)
__init__(session: Session, run_mode: str, model_path: str, predictions_path: str, credible_set_path: str, variant_gene_path: str, colocalisation_path: str, study_index_path: str, gold_standard_curation_path: str, gene_interactions_path: str, features_list: list[str], hyperparameters: dict[str, Any], wandb_run_name: str | None = None, perform_cross_validation: bool = False) -> None

Run step.

Parameters:

Name Type Description Default
session Session

Session object.

required
run_mode str

One of "train" or "predict".

required
model_path str

Path to save the model.

required
predictions_path str

Path to save the predictions.

required
credible_set_path str

Path to credible set Parquet files.

required
variant_gene_path str

Path to variant to gene Parquet files.

required
colocalisation_path str

Path to colocalisation Parquet files.

required
study_index_path str

Path to study index Parquet files.

required
gold_standard_curation_path str

Path to gold standard curation JSON files.

required
gene_interactions_path str

Path to gene interactions Parquet files.

required
features_list list[str]

List of features to use.

required
hyperparameters dict[str, Any]

Hyperparameters for the model.

required
wandb_run_name str | None

Name of the run to be tracked on W&B.

None
perform_cross_validation bool

Whether to perform cross validation.

False

Raises:

Type Description
ValueError

if run_mode is not one of "train" or "predict".

Source code in src/gentropy/l2g.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def __init__(
    self,
    session: Session,
    run_mode: str,
    model_path: str,
    predictions_path: str,
    credible_set_path: str,
    variant_gene_path: str,
    colocalisation_path: str,
    study_index_path: str,
    gold_standard_curation_path: str,
    gene_interactions_path: str,
    features_list: list[str],
    hyperparameters: dict[str, Any],
    wandb_run_name: str | None = None,
    perform_cross_validation: bool = False,
) -> None:
    """Run step.

    Args:
        session (Session): Session object.
        run_mode (str): One of "train" or "predict".
        model_path (str): Path to save the model.
        predictions_path (str): Path to save the predictions.
        credible_set_path (str): Path to credible set Parquet files.
        variant_gene_path (str): Path to variant to gene Parquet files.
        colocalisation_path (str): Path to colocalisation Parquet files.
        study_index_path (str): Path to study index Parquet files.
        gold_standard_curation_path (str): Path to gold standard curation JSON files.
        gene_interactions_path (str): Path to gene interactions Parquet files.
        features_list (list[str]): List of features to use.
        hyperparameters (dict[str, Any]): Hyperparameters for the model.
        wandb_run_name (str | None): Name of the run to be tracked on W&B.
        perform_cross_validation (bool): Whether to perform cross validation.

    Raises:
        ValueError: if run_mode is not one of "train" or "predict".
    """
    print("Sci-kit learn version: ", sklearn.__version__)  # noqa: T201
    if run_mode not in {"train", "predict"}:
        raise ValueError(
            f"run_mode must be one of 'train' or 'predict', got {run_mode}"
        )

    # Inputs shared by both train and predict modes.
    loci = StudyLocus.from_parquet(
        session, credible_set_path, recursiveFileLookup=True
    )
    study_index = StudyIndex.from_parquet(
        session, study_index_path, recursiveFileLookup=True
    )
    variant_to_gene = V2G.from_parquet(session, variant_gene_path)
    colocs = Colocalisation.from_parquet(
        session, colocalisation_path, recursiveFileLookup=True
    )

    if run_mode == "predict":
        # Predict mode requires an existing model and an output destination.
        if not model_path or not predictions_path:
            raise ValueError(
                "model_path and predictions_path must be set for predict mode."
            )
        scored = L2GPrediction.from_credible_set(
            model_path,
            list(features_list),
            loci,
            study_index,
            variant_to_gene,
            colocs,
        )
        scored.df.write.mode(session.write_mode).parquet(predictions_path)
        session.logger.info(predictions_path)
        return

    if run_mode == "train" and gold_standard_curation_path and gene_interactions_path:
        # Read gold standard curation and gene-gene interaction inputs.
        curation = session.spark.read.json(gold_standard_curation_path)
        interaction_df = session.spark.read.parquet(gene_interactions_path)

        # Build the studyLocusId for each curated sentinel variant so we can
        # restrict the credible sets to gold-standard associations only. This
        # duplicates the parsing done inside the gold standard curation, but
        # it is needed here to be able to parse gold standards later.
        curated_ids = curation.select(
            StudyLocus.assign_study_locus_id(
                f.col("association_info.otg_id"),  # studyId
                f.concat_ws(  # variantId
                    "_",
                    f.col("sentinel_variant.locus_GRCh38.chromosome"),
                    f.col("sentinel_variant.locus_GRCh38.position"),
                    f.col("sentinel_variant.alleles.reference"),
                    f.col("sentinel_variant.alleles.alternative"),
                ),
            ).alias("studyLocusId"),
        )
        curated_loci_df = loci.df.join(
            f.broadcast(curated_ids), "studyLocusId", "inner"
        )
        overlaps = StudyLocus(
            _df=curated_loci_df,
            _schema=StudyLocus.get_schema(),
        ).find_overlaps(study_index)

        gold_standard = L2GGoldStandard.from_otg_curation(
            gold_standard_curation=curation,
            v2g=variant_to_gene,
            study_locus_overlap=overlaps,
            interactions=interaction_df,
        )

        feature_matrix = L2GFeatureMatrix.generate_features(
            features_list=features_list,
            credible_set=loci,
            study_index=study_index,
            variant_gene=variant_to_gene,
            colocalisation=colocs,
        )

        # Attach the computed features to each gold-standard (locus, gene) pair.
        labelled_df = feature_matrix.df.join(
            f.broadcast(gold_standard.df.drop("variantId", "studyId", "sources")),
            on=["studyLocusId", "geneId"],
            how="inner",
        )
        training_data = L2GFeatureMatrix(
            _df=labelled_df,
            _schema=L2GFeatureMatrix.get_schema(),
        )
        training_data = training_data.fill_na().select_features(list(features_list))

        # Classifier the model wraps.
        xgb = SparkXGBClassifier(
            eval_metric="logloss",
            features_col="features",
            label_col="label",
            max_depth=5,
        )
        model = LocusToGeneModel(features_list=list(features_list), estimator=xgb)

        if perform_cross_validation:
            # Cross-validate to find the best hyperparameters.
            n_folds = hyperparameters.get("cross_validation_folds", 5)
            LocusToGeneTrainer.cross_validate(
                l2g_model=model,
                data=training_data,
                num_folds=n_folds,
            )
        else:
            # Fit the model and persist it.
            LocusToGeneTrainer.train(
                gold_standard_data=training_data,
                l2g_model=model,
                model_path=model_path,
                evaluate=True,
                wandb_run_name=wandb_run_name,
                **hyperparameters,
            )
            session.logger.info(model_path)