nichecompass.data.prepare_data
- nichecompass.data.prepare_data(adata, cat_covariates_label_encoders, adata_atac=None, counts_key='counts', adj_key='spatial_connectivities', cat_covariates_keys=None, edge_val_ratio=0.1, edge_test_ratio=0.0, node_val_ratio=0.1, node_test_ratio=0.0)
Prepare data for model training, including edge-level and node-level train, validation, and test splits.
- Parameters:
adata (AnnData) – AnnData object with counts stored in ´adata.layers[counts_key]´ or ´adata.X´, depending on ´counts_key´, and a sparse adjacency matrix stored in ´adata.obsp[adj_key]´.
adata_atac (Optional[AnnData] (default: None)) – Additional optional AnnData object with paired spatial ATAC data.
cat_covariates_label_encoders (List[dict]) – List of categorical covariates label encoders from the model (label encoding indices need to be aligned with the ones from the model to get the correct categorical covariates embeddings).
counts_key (Optional[str] (default: 'counts')) – Key under which the counts are stored in ´adata.layers´. If ´None´, uses ´adata.X´ as counts.
adj_key (str (default: 'spatial_connectivities')) – Key under which the sparse adjacency matrix is stored in ´adata.obsp´.
cat_covariates_keys (Optional[List[str]] (default: None)) – Keys under which the categorical covariates are stored in ´adata.obs´.
edge_val_ratio (float (default: 0.1)) – Fraction of the data used as the edge-level validation set.
edge_test_ratio (float (default: 0.0)) – Fraction of the data used as the edge-level test set.
node_val_ratio (float (default: 0.1)) – Fraction of the data used as the node-level validation set.
node_test_ratio (float (default: 0.0)) – Fraction of the data used as the node-level test set.
- Return type:
dict
- Returns:
data_dict: Dictionary containing edge-level training, validation and test PyG Data objects and node-level PyG Data object with split masks under keys ´edge_train_data´, ´edge_val_data´, ´edge_test_data´, and ´node_masked_data´ respectively. The edge-level PyG Data objects contain edges in the ´edge_label_index´ attribute and edge labels in the ´edge_label´ attribute.
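To make the four ratio parameters concrete, here is a minimal, standard-library-only sketch of how a validation and a test fraction partition a set of items (edges or nodes) into train, validation, and test subsets. It is an illustration under assumptions, not NicheCompass's implementation: the helper `split_indices`, its rounding rule, and the item counts are hypothetical, and the real function returns PyG Data objects rather than index lists.

```python
import random


def split_indices(n_items, val_ratio=0.1, test_ratio=0.0, seed=0):
    """Hypothetical helper: shuffle item indices and cut them into splits.

    Only illustrates what the *_val_ratio and *_test_ratio fractions mean;
    the library's actual splitting logic may differ.
    """
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
        raise ValueError("ratios must be non-negative and sum to at most 1")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_val = round(n_items * val_ratio)
    n_test = round(n_items * test_ratio)
    return {
        "val": indices[:n_val],
        "test": indices[n_val:n_val + n_test],
        "train": indices[n_val + n_test:],  # everything not held out
    }


# Mirror the documented defaults: 10% validation, 0% test on both levels.
edge_split = split_indices(200, val_ratio=0.1, test_ratio=0.0)  # 200 edges (assumed)
node_split = split_indices(50, val_ratio=0.1, test_ratio=0.0)   # 50 nodes (assumed)

print({name: len(idx) for name, idx in edge_split.items()})
print({name: len(idx) for name, idx in node_split.items()})
```

With the default `edge_test_ratio=0.0` and `node_test_ratio=0.0`, the test subsets in this sketch are empty, so all held-out items go to validation.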