Skip to content

L2G Trainer

gentropy.method.l2g.trainer.LocusToGeneTrainer dataclass

Modelling of what is the most likely causal gene associated with a given locus.

Source code in src/gentropy/method/l2g/trainer.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
@dataclass
class LocusToGeneTrainer:
    """Modelling of what is the most likely causal gene associated with a given locus.

    Groups the two training entry points for the Locus to Gene (L2G) model: a single fit on a
    train/test split (`train`) and a k-fold search over a parameter grid (`cross_validate`).
    """

    # NOTE: neither classmethod defined in this class reads these fields; both operate
    # exclusively on the objects passed in as arguments.
    _model: LocusToGeneModel
    train_set: L2GFeatureMatrix

    @classmethod
    def train(
        cls: type[LocusToGeneTrainer],
        gold_standard_data: L2GFeatureMatrix,
        l2g_model: LocusToGeneModel,
        evaluate: bool,
        wandb_run_name: str | None = None,
        model_path: str | None = None,
        **hyperparams: dict[str, Any],
    ) -> LocusToGeneModel:
        """Train the Locus to Gene model.

        The gold standard is split with `train_test_split(fraction=0.8)`; the model's estimator is
        appended as a pipeline stage and fitted on the first split. The second split is only used
        when `evaluate` is set.

        Args:
            gold_standard_data (L2GFeatureMatrix): Feature matrix for the associations in the gold standard
            l2g_model (LocusToGeneModel): Model to fit on the data
            evaluate (bool): Whether to evaluate the model on a test set
            wandb_run_name (str | None): Descriptive name for the run to be tracked with W&B
            model_path (str | None): Path to save the model to
            **hyperparams (dict[str, Any]): Hyperparameters to use for the model

        Returns:
            LocusToGeneModel: Trained model
        """
        train, test = gold_standard_data.train_test_split(fraction=0.8)

        model = l2g_model.add_pipeline_stage(l2g_model.estimator).fit(train)

        if evaluate:
            # Predictions are made on the held-out split; the hyperparameters are forwarded
            # as-is to `evaluate`, they are not applied to the estimator here.
            l2g_model.evaluate(
                results=model.predict(test),
                hyperparameters=hyperparams,
                wandb_run_name=wandb_run_name,
                gold_standard_data=gold_standard_data,
            )
        if model_path:
            l2g_model.save(model_path)
        # The object handed in by the caller is returned (not the value returned by `fit`).
        return l2g_model

    @classmethod
    def cross_validate(
        cls: type[LocusToGeneTrainer],
        l2g_model: LocusToGeneModel,
        data: L2GFeatureMatrix,
        num_folds: int,
        param_grid: Optional[list] = None,  # type: ignore
    ) -> LocusToGeneModel:
        """Perform k-fold cross validation on the model.

        By providing a model with a parameter grid, this method will perform k-fold cross validation on the model for each
        combination of parameters and return the best model. When `param_grid` is not given (or is empty), the grid
        returned by `l2g_model.get_param_grid()` is used instead.

        Args:
            l2g_model (LocusToGeneModel): Model to fit on the data
            data (L2GFeatureMatrix): Data to perform cross validation on
            num_folds (int): Number of folds to use for cross validation
            param_grid (Optional[list]): List of parameter maps to use for cross validation

        Returns:
            LocusToGeneModel: Trained model fitted with the best hyperparameters

        Raises:
            ValueError: Parameter grid is empty. Cannot perform cross-validation.
            ValueError: Unable to retrieve the best model.
        """
        # Validate the *resolved* grid: the explicit argument wins, otherwise fall back to the
        # model's own grid. Checking the raw argument would reject the default `None` even
        # when the model supplies a usable grid.
        params_grid = param_grid or l2g_model.get_param_grid()
        if not params_grid:
            raise ValueError(
                "Parameter grid is empty. Cannot perform cross-validation."
            )

        # Spark objects are only built once the inputs are known to be usable.
        evaluator = MulticlassClassificationEvaluator()
        cv = CrossValidator(
            numFolds=num_folds,
            estimator=l2g_model.estimator,
            estimatorParamMaps=params_grid,
            evaluator=evaluator,
            parallelism=2,
            collectSubModels=False,
            seed=42,
        )

        l2g_model.add_pipeline_stage(cv)  # type: ignore[assignment, unused-ignore]

        # Integrate the best model from the last stage of the pipeline
        # (assumes the cross-validator added above is that last stage).
        if (full_pipeline_model := l2g_model.fit(data).model) is None or not hasattr(
            full_pipeline_model, "stages"
        ):
            raise ValueError("Unable to retrieve the best model.")
        l2g_model.model = full_pipeline_model.stages[-1].bestModel  # type: ignore[assignment, unused-ignore]
        return l2g_model

cross_validate(l2g_model: LocusToGeneModel, data: L2GFeatureMatrix, num_folds: int, param_grid: Optional[list] = None) -> LocusToGeneModel classmethod

Perform k-fold cross validation on the model.

By providing a model with a parameter grid, this method will perform k-fold cross validation on the model for each combination of parameters and return the best model.

Parameters:

Name Type Description Default
l2g_model LocusToGeneModel

Model to fit on the data

required
data L2GFeatureMatrix

Data to perform cross validation on

required
num_folds int

Number of folds to use for cross validation

required
param_grid Optional[list]

List of parameter maps to use for cross validation

None

Returns:

Name Type Description
LocusToGeneModel LocusToGeneModel

Trained model fitted with the best hyperparameters

Raises:

Type Description
ValueError

Parameter grid is empty. Cannot perform cross-validation.

ValueError

Unable to retrieve the best model.

Source code in src/gentropy/method/l2g/trainer.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
@classmethod
def cross_validate(
    cls: type[LocusToGeneTrainer],
    l2g_model: LocusToGeneModel,
    data: L2GFeatureMatrix,
    num_folds: int,
    param_grid: Optional[list] = None,  # type: ignore
) -> LocusToGeneModel:
    """Perform k-fold cross validation on the model.

    By providing a model with a parameter grid, this method will perform k-fold cross validation on the model for each
    combination of parameters and return the best model. When `param_grid` is not given (or is empty), the grid
    returned by `l2g_model.get_param_grid()` is used instead.

    Args:
        l2g_model (LocusToGeneModel): Model to fit on the data
        data (L2GFeatureMatrix): Data to perform cross validation on
        num_folds (int): Number of folds to use for cross validation
        param_grid (Optional[list]): List of parameter maps to use for cross validation

    Returns:
        LocusToGeneModel: Trained model fitted with the best hyperparameters

    Raises:
        ValueError: Parameter grid is empty. Cannot perform cross-validation.
        ValueError: Unable to retrieve the best model.
    """
    # Validate the *resolved* grid: the explicit argument wins, otherwise fall back to the
    # model's own grid. Checking the raw argument would reject the default `None` even
    # when the model supplies a usable grid.
    params_grid = param_grid or l2g_model.get_param_grid()
    if not params_grid:
        raise ValueError(
            "Parameter grid is empty. Cannot perform cross-validation."
        )

    # Spark objects are only built once the inputs are known to be usable.
    evaluator = MulticlassClassificationEvaluator()
    cv = CrossValidator(
        numFolds=num_folds,
        estimator=l2g_model.estimator,
        estimatorParamMaps=params_grid,
        evaluator=evaluator,
        parallelism=2,
        collectSubModels=False,
        seed=42,
    )

    l2g_model.add_pipeline_stage(cv)  # type: ignore[assignment, unused-ignore]

    # Integrate the best model from the last stage of the pipeline
    # (assumes the cross-validator added above is that last stage).
    if (full_pipeline_model := l2g_model.fit(data).model) is None or not hasattr(
        full_pipeline_model, "stages"
    ):
        raise ValueError("Unable to retrieve the best model.")
    l2g_model.model = full_pipeline_model.stages[-1].bestModel  # type: ignore[assignment, unused-ignore]
    return l2g_model

train(gold_standard_data: L2GFeatureMatrix, l2g_model: LocusToGeneModel, evaluate: bool, wandb_run_name: str | None = None, model_path: str | None = None, **hyperparams: dict[str, Any]) -> LocusToGeneModel classmethod

Train the Locus to Gene model.

Parameters:

Name Type Description Default
gold_standard_data L2GFeatureMatrix

Feature matrix for the associations in the gold standard

required
l2g_model LocusToGeneModel

Model to fit on the data

required
evaluate bool

Whether to evaluate the model on a test set

required
wandb_run_name str | None

Descriptive name for the run to be tracked with W&B

None
model_path str | None

Path to save the model to

None
**hyperparams dict[str, Any]

Hyperparameters to use for the model

{}

Returns:

Name Type Description
LocusToGeneModel LocusToGeneModel

Trained model

Source code in src/gentropy/method/l2g/trainer.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@classmethod
def train(
    cls: type[LocusToGeneTrainer],
    gold_standard_data: L2GFeatureMatrix,
    l2g_model: LocusToGeneModel,
    evaluate: bool,
    wandb_run_name: str | None = None,
    model_path: str | None = None,
    **hyperparams: dict[str, Any],
) -> LocusToGeneModel:
    """Fit the Locus to Gene model on a split of the gold standard.

    The gold standard is split with `train_test_split(fraction=0.8)`. The model's estimator is
    appended as a pipeline stage and fitted on the first split; the second split is only
    touched when an evaluation is requested.

    Args:
        gold_standard_data (L2GFeatureMatrix): Gold standard feature matrix that gets split for fitting and evaluation
        l2g_model (LocusToGeneModel): Model whose estimator is fitted
        evaluate (bool): If set, predict on the second split and hand the results to the model's `evaluate`
        wandb_run_name (str | None): Run name forwarded to the evaluation step for W&B tracking
        model_path (str | None): If given, the model is saved to this path
        **hyperparams (dict[str, Any]): Hyperparameters forwarded to the evaluation step

    Returns:
        LocusToGeneModel: The `l2g_model` object that was passed in
    """
    training_split, held_out_split = gold_standard_data.train_test_split(fraction=0.8)

    pipeline = l2g_model.add_pipeline_stage(l2g_model.estimator)
    fitted = pipeline.fit(training_split)

    if evaluate:
        # Predictions come from the fitted object, while the evaluation itself is
        # invoked on the model supplied by the caller.
        predictions = fitted.predict(held_out_split)
        l2g_model.evaluate(
            results=predictions,
            hyperparameters=hyperparams,
            wandb_run_name=wandb_run_name,
            gold_standard_data=gold_standard_data,
        )

    if model_path:
        l2g_model.save(model_path)

    return l2g_model