Creation of torch-geometric data objects from AnnData#

%load_ext autoreload
%autoreload 2
import warnings
import anndata as ad
import squidpy as sq
from geome import transforms
from geome.ann2data import Ann2DataByCategory
from utils.datasets import DatasetHartmann

# Install a blanket "ignore" filter so that no warnings are displayed from here on.
warnings.filterwarnings('ignore')

All NCEM Datasets#

Load Unprocessed Dataset#

# Mibitof: read the Hartmann dataset from the local example-data folder and
# gather the objects stored in `dataset.img_celldata` into a list.
dataset = DatasetHartmann(data_path="./example_data/hartmann/")
adatas = [*dataset.img_celldata.values()]
Loading data from raw files
registering celldata
collecting image-wise celldata
adding graph-level covariates
Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.

Some Preprocessing Done Manually#

These preprocessing steps can also be performed inside the Ann2DataByCategory callable if they are passed as functions in the preprocess list.

# Merge the list of adatas into a single AnnData object. No string-to-category
# conversion happens in this cell; that is left to the `Categorize`
# preprocessing step configured further below.
adata = ad.concat(adatas)
adata
AnnData object with n_obs × n_vars = 63747 × 36
    obs: 'point', 'cell_id', 'cell_size', 'donor', 'Cluster', 'Cluster_preprocessed'
    obsm: 'spatial', 'node_types'

Creating A2D#

# Each key becomes an attribute of the resulting Data objects; each value lists
# the AnnData locations ("<slot>/<key>" paths, or "X") that the attribute is
# built from.
fields = dict(
    features=["obs/Cluster_preprocessed", "obs/donor", "obsm/design_matrix"],
    labels=["X"],
)

Here we list the preprocessing steps that need to be applied to the AnnData object. Each step takes two parameters, the adata and the fields, but we don’t use the fields in this example.

from geome.transforms import Compose, Categorize, AddDesignMatrix, AddAdjMatrix


# Function-style spelling of the same steps, kept for reference (presumably
# superseded by the transform classes used below -- TODO confirm):
# [
#     lambda x,_: transforms.categorize_obs(x,['donor', 'Cluster_preprocessed', 'point']),
#     lambda x,_: transforms.add_design_matrix(x,'obs/Cluster_preprocessed','obs/donor','design_matrix'),
# ]

# AnnData location of the adjacency matrix; the same path is passed to both
# AddAdjMatrix and AddDesignMatrix.
adj_matrix_loc = "obsp/adjacency_matrix_connectivities"

# Steps handed to the converter as `preprocess`.
preprocess = Compose(
    [
        Categorize(["donor", "Cluster_preprocessed", "point"], axis="obs"),
        AddAdjMatrix(adj_matrix_loc),
    ]
)

# Step handed to the converter as `transform`.
transform = AddDesignMatrix(
    "obs/Cluster_preprocessed",
    "obs/donor",
    adj_matrix_loc,
    "design_matrix",
)

# obs column used as the `category` argument of the converter.
category_to_iterate = "point"

# NOTE(review): this cell used to state that the dtype is not categorical, but
# the recorded output below reports dtype `category` -- confirm which is intended.
adata.obs["point"]
59191    scMEP_point_1
59192    scMEP_point_1
59193    scMEP_point_1
59194    scMEP_point_1
59195    scMEP_point_1
             ...      
18510    scMEP_point_9
18511    scMEP_point_9
18512    scMEP_point_9
18513    scMEP_point_9
18514    scMEP_point_9
Name: point, Length: 63747, dtype: category
Categories (58, object): ['scMEP_point_1', 'scMEP_point_2', 'scMEP_point_3', 'scMEP_point_4', ..., 'scMEP_point_55', 'scMEP_point_56', 'scMEP_point_57', 'scMEP_point_58']
# Build the converter from the pieces defined above. `category` names the obs
# column to iterate over; presumably one Data object is produced per value of
# that column (58 categories, 58 Data objects in the output further below).
a2d = Ann2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=preprocess,
    transform=transform,
)

Convert AnnData to Data on call#

# Calling the converter returns a generator (see the repr below), so the Data
# objects are produced as the generator is consumed.
datas = a2d(adata)
datas
<generator object Ann2Data.__call__ at 0x168a64220>

With the update, a2d returns a generator object so that the data can be processed only when it is needed; this is useful for large datasets. If you want to convert the generator to a list, you can use list(a2d(adata)).

# Exhaust the generator to get all Data objects as a list, then display it.
datas = list(datas)
datas
[Data(features=[1338, 88], labels=[1338, 36]),
 Data(features=[311, 88], labels=[311, 36]),
 Data(features=[768, 88], labels=[768, 36]),
 Data(features=[1020, 88], labels=[1020, 36]),
 Data(features=[2100, 88], labels=[2100, 36]),
 Data(features=[1325, 88], labels=[1325, 36]),
 Data(features=[1091, 88], labels=[1091, 36]),
 Data(features=[1046, 88], labels=[1046, 36]),
 Data(features=[618, 88], labels=[618, 36]),
 Data(features=[61, 88], labels=[61, 36]),
 Data(features=[1316, 88], labels=[1316, 36]),
 Data(features=[1540, 88], labels=[1540, 36]),
 Data(features=[1822, 88], labels=[1822, 36]),
 Data(features=[863, 88], labels=[863, 36]),
 Data(features=[564, 88], labels=[564, 36]),
 Data(features=[1023, 88], labels=[1023, 36]),
 Data(features=[324, 88], labels=[324, 36]),
 Data(features=[287, 88], labels=[287, 36]),
 Data(features=[636, 88], labels=[636, 36]),
 Data(features=[890, 88], labels=[890, 36]),
 Data(features=[1235, 88], labels=[1235, 36]),
 Data(features=[1020, 88], labels=[1020, 36]),
 Data(features=[1241, 88], labels=[1241, 36]),
 Data(features=[1438, 88], labels=[1438, 36]),
 Data(features=[1021, 88], labels=[1021, 36]),
 Data(features=[1632, 88], labels=[1632, 36]),
 Data(features=[780, 88], labels=[780, 36]),
 Data(features=[524, 88], labels=[524, 36]),
 Data(features=[669, 88], labels=[669, 36]),
 Data(features=[241, 88], labels=[241, 36]),
 Data(features=[935, 88], labels=[935, 36]),
 Data(features=[347, 88], labels=[347, 36]),
 Data(features=[1499, 88], labels=[1499, 36]),
 Data(features=[601, 88], labels=[601, 36]),
 Data(features=[2268, 88], labels=[2268, 36]),
 Data(features=[1912, 88], labels=[1912, 36]),
 Data(features=[1678, 88], labels=[1678, 36]),
 Data(features=[1025, 88], labels=[1025, 36]),
 Data(features=[1306, 88], labels=[1306, 36]),
 Data(features=[852, 88], labels=[852, 36]),
 Data(features=[1664, 88], labels=[1664, 36]),
 Data(features=[1698, 88], labels=[1698, 36]),
 Data(features=[1672, 88], labels=[1672, 36]),
 Data(features=[777, 88], labels=[777, 36]),
 Data(features=[556, 88], labels=[556, 36]),
 Data(features=[554, 88], labels=[554, 36]),
 Data(features=[937, 88], labels=[937, 36]),
 Data(features=[1524, 88], labels=[1524, 36]),
 Data(features=[1528, 88], labels=[1528, 36]),
 Data(features=[721, 88], labels=[721, 36]),
 Data(features=[1395, 88], labels=[1395, 36]),
 Data(features=[611, 88], labels=[611, 36]),
 Data(features=[1872, 88], labels=[1872, 36]),
 Data(features=[1217, 88], labels=[1217, 36]),
 Data(features=[1531, 88], labels=[1531, 36]),
 Data(features=[1927, 88], labels=[1927, 36]),
 Data(features=[690, 88], labels=[690, 36]),
 Data(features=[1706, 88], labels=[1706, 36])]
# Show the `features` tensor of the first Data object together with its shape.
datas[0].features, datas[0].features.shape
(tensor([[0., 0., 0.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.]]),
 torch.Size([1338, 88]))

Squidpy Datasets#

# Load the MIBI-TOF dataset provided by squidpy and display its summary.
adata = sq.datasets.mibitof()
adata
AnnData object with n_obs × n_vars = 3309 × 36
    obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'
    var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'
    uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'
    obsm: 'X_scanorama', 'X_umap', 'spatial'
    obsp: 'connectivities', 'distances'
# Variant without an edge index, kept for reference:
# fields = {
#     'features':['obs/Cluster','obs/donor'],
#     'labels':['X']
# }

# Same features/labels as the variant above, plus an `edge_index` attribute
# read from `uns/edge_index`.
fields = dict(
    features=["obs/Cluster", "obs/donor"],
    labels=["X"],
    edge_index=["uns/edge_index"],
)

# Function-style preprocessing list, kept for reference:
# preprocess = [
#     lambda x,_: transforms.add_design_matrix(x,'obs/Cluster','obs/donor','design_matrix')
# ]

# The squidpy AnnData already lists `connectivities` under obsp (see the
# summary printed above), so that entry is used as the adjacency matrix.
adj_matrix_loc = "obsp/connectivities"

from geome.transforms import AddEdgeIndex

# Steps handed to the converter as `transform`; both receive the adjacency
# matrix location.
transform = Compose(
    [
        AddEdgeIndex(adj_matrix_loc, "edge_index", overwrite=True),
        AddDesignMatrix("obs/Cluster", "obs/donor", adj_matrix_loc, "design_matrix"),
    ]
)

# obs column used as the `category` argument of the converter.
category_to_iterate = "library_id"

# No preprocessing is passed this time (`preprocess=None`); only the composed
# `transform` is supplied. Calling the converter yields a generator (see the
# repr below).
a2c = Ann2DataByCategory(
    fields=fields,
    category=category_to_iterate,
    preprocess=None,
    transform=transform,
)
datas = a2c(adata)
datas
<generator object Ann2Data.__call__ at 0x2b6623a60>
# Exhaust the generator to get all Data objects as a list, then display it.
datas = list(datas)
datas
[Data(edge_index=[2, 8878], features=[1023, 10], labels=[1023, 36]),
 Data(edge_index=[2, 17770], features=[1241, 10], labels=[1241, 36]),
 Data(edge_index=[2, 3944], features=[1045, 10], labels=[1045, 36])]
# Show the `labels` tensor of the first Data object together with its shape.
datas[0].labels, datas[0].labels.shape
(tensor([[-0.0146, -0.2531, -0.0700,  ..., -0.1332, -0.0686, -0.1984],
         [-0.2564, -0.0944, -0.0410,  ..., -0.1053, -0.0211, -0.1020],
         [-0.3227, -0.2246, -0.0606,  ..., -0.1715, -0.0644, -0.0406],
         ...,
         [-0.1450, -0.0382,  0.0800,  ..., -0.2933, -0.2550,  0.2536],
         [-0.1106, -0.2884, -0.0969,  ..., -0.3815, -0.1163,  0.0804],
         [-0.0943, -0.1985, -0.0632,  ..., -0.3534, -0.0252, -0.0550]]),
 torch.Size([1023, 36]))