L2G Trainer
gentropy.method.l2g.trainer.LocusToGeneTrainer
dataclass
Models which gene is most likely to be causal for a given locus.
Source code in src/gentropy/method/l2g/trainer.py
cross_validate(l2g_model: LocusToGeneModel, data: L2GFeatureMatrix, num_folds: int, param_grid: Optional[list] = None) -> LocusToGeneModel
classmethod
Perform k-fold cross-validation on the model.
Given a parameter grid, this method runs k-fold cross-validation for each combination of parameters and returns the best model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
l2g_model | LocusToGeneModel | Model to fit to the data | required |
data | L2GFeatureMatrix | Data to perform cross-validation on | required |
num_folds | int | Number of folds to use for cross-validation | required |
param_grid | Optional[list] | List of parameter maps to use for cross-validation | None |
Returns:
Name | Type | Description |
---|---|---|
LocusToGeneModel | LocusToGeneModel | Trained model fitted with the best hyperparameters |
Raises:
Type | Description |
---|---|
ValueError | Parameter grid is empty. Cannot perform cross-validation. |
ValueError | Unable to retrieve the best model. |
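
The grid search described above can be illustrated with a small, self-contained sketch. This is not the gentropy implementation: the toy cut-off classifier, the `margin` hyperparameter, and the helper names `fit`, `accuracy`, and `toy_cross_validate` are all hypothetical and use only the standard library. It shows the general pattern of scoring every parameter map over k folds, keeping the best one, and refitting on the full data.

```python
from statistics import mean


def fit(rows, margin):
    """Toy 'training': cut-off halfway between the class means, shifted by margin."""
    pos = [score for score, label in rows if label]
    neg = [score for score, label in rows if not label]
    return (mean(pos) + mean(neg)) / 2 + margin


def accuracy(cutoff, rows):
    """Fraction of rows where (score >= cutoff) agrees with the label."""
    return mean(1.0 if (score >= cutoff) == label else 0.0 for score, label in rows)


def toy_cross_validate(rows, num_folds, param_grid):
    """Return (best_params, cutoff refitted on all rows). Hypothetical helper."""
    if not param_grid:
        raise ValueError("Parameter grid is empty. Cannot perform cross-validation.")
    # Assign rows to folds round-robin.
    folds = [rows[i::num_folds] for i in range(num_folds)]
    best_params, best_score = None, float("-inf")
    for params in param_grid:
        fold_scores = []
        for held_out in range(num_folds):
            # Train on every fold except the held-out one.
            train = [r for f, fold in enumerate(folds) if f != held_out for r in fold]
            cutoff = fit(train, **params)
            fold_scores.append(accuracy(cutoff, folds[held_out]))
        avg = mean(fold_scores)
        if avg > best_score:
            best_params, best_score = params, avg
    # Refit on the full data with the winning parameters.
    return best_params, fit(rows, **best_params)


# Made-up (score, is_causal) pairs, for illustration only.
rows = [
    (0.1, False), (0.9, True), (0.2, False), (0.8, True),
    (0.3, False), (0.7, True), (0.35, False), (0.6, True),
    (0.15, False), (0.95, True), (0.4, False), (0.65, True),
]
param_grid = [{"margin": -0.4}, {"margin": 0.0}, {"margin": 0.4}]
best_params, cutoff = toy_cross_validate(rows, num_folds=3, param_grid=param_grid)
print(best_params, round(cutoff, 3))
```

The sketch raises `ValueError` on an empty grid, mirroring the first error documented above. How the real method selects and returns the best model is not shown here.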
train(gold_standard_data: L2GFeatureMatrix, l2g_model: LocusToGeneModel, evaluate: bool, wandb_run_name: str | None = None, model_path: str | None = None, **hyperparams: dict[str, Any]) -> LocusToGeneModel
classmethod
Train the Locus to Gene model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gold_standard_data | L2GFeatureMatrix | Feature matrix for the associations in the gold standard | required |
l2g_model | LocusToGeneModel | Model to fit to the data | required |
evaluate | bool | Whether to evaluate the model on a test set | required |
wandb_run_name | str \| None | Descriptive name for the run to be tracked with W&B | None |
model_path | str \| None | Path to save the model to | None |
**hyperparams | dict[str, Any] | Hyperparameters to use for the model | {} |
Returns:
Name | Type | Description |
---|---|---|
LocusToGeneModel | LocusToGeneModel | Trained model |
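
The parameter descriptions suggest a flow in which evaluation and saving are optional steps and any extra keyword arguments are collected as hyperparameters. The stand-in below sketches that shape using only the standard library. It is not the gentropy code: `ToyModel`, `toy_train`, the `margin` hyperparameter, the JSON file format, and the 80/20 hold-out split are all assumptions made for illustration, and W&B tracking is left out entirely.

```python
import json
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from statistics import mean
from typing import Optional


@dataclass
class ToyModel:
    """Hypothetical stand-in for a trainable model; not LocusToGeneModel."""

    hyperparams: dict = field(default_factory=dict)
    cutoff: Optional[float] = None
    test_accuracy: Optional[float] = None

    def fit(self, rows):
        # Cut-off halfway between the class means, shifted by the 'margin' hyperparameter.
        pos = [s for s, label in rows if label]
        neg = [s for s, label in rows if not label]
        self.cutoff = (mean(pos) + mean(neg)) / 2 + self.hyperparams.get("margin", 0.0)

    def accuracy(self, rows):
        return mean(1.0 if (s >= self.cutoff) == label else 0.0 for s, label in rows)

    def save(self, path):
        Path(path).write_text(json.dumps({"cutoff": self.cutoff, **self.hyperparams}))


def toy_train(rows, model, evaluate, model_path=None, **hyperparams):
    """Fit, optionally evaluate on held-out rows, optionally save. Hypothetical helper."""
    model.hyperparams = hyperparams       # extra keyword arguments arrive as a dict
    split = int(len(rows) * 0.8)          # assumption: simple 80/20 hold-out split
    train_rows, test_rows = rows[:split], rows[split:]
    model.fit(train_rows)
    if evaluate:
        model.test_accuracy = model.accuracy(test_rows)
    if model_path is not None:
        model.save(model_path)
    return model


# Made-up (score, is_causal) pairs, for illustration only.
rows = [
    (0.1, False), (0.9, True), (0.2, False), (0.8, True), (0.3, False),
    (0.7, True), (0.4, False), (0.6, True), (0.15, False), (0.85, True),
]

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_model.json"
    trained = toy_train(rows, ToyModel(), evaluate=True, model_path=str(path), margin=0.05)
    saved = json.loads(path.read_text())

# Without evaluation or a path, only the fit step runs.
untested = toy_train(rows, ToyModel(), evaluate=False)
```

Passing `evaluate=False` and omitting `model_path` skips both optional steps, so `untested.test_accuracy` stays `None` and nothing is written to disk.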