Skip to content

Commit

Permalink
Merge pull request #16 from AI-sandbox/feature/scatter-plot
Browse files Browse the repository at this point in the history
Enhance the scatter plot to support diploid data and update maasMDS and mdPCA to compute and store haplotypes_, samples_, n_haplotypes, and n_samples
  • Loading branch information
salcc authored Jan 3, 2025
2 parents 5462ae3 + 8a5f1a1 commit bbb7aa7
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 53 deletions.
134 changes: 94 additions & 40 deletions demos/mdPCA_maasMDS.ipynb

Large diffs are not rendered by default.

84 changes: 83 additions & 1 deletion snputils/processing/maasmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def __init__(
self.__n_components = n_components
self.__rsid_or_chrompos = rsid_or_chrompos
self.__X_new_ = None # Store transformed SNP data
self.__haplotypes_ = None # Store haplotypes after filtering if min_percent_snps > 0
self.__samples_ = None # Store samples after filtering if min_percent_snps > 0

# Fit and transform if a `snpobj`, `laiobj`, `labels_file`, and `ancestry` are provided
if self.snpobj is not None and self.laiobj is not None and self.labels_file is not None and self.ancestry is not None:
Expand Down Expand Up @@ -417,8 +419,11 @@ def X_new_(self) -> Optional[np.ndarray]:
Retrieve `X_new_`.
Returns:
**array of shape (n_samples, n_components):**
**array of shape (n_haplotypes_, n_components):**
The transformed SNP data projected onto the `n_components` principal components.
n_haplotypes_ is the number of haplotypes, potentially reduced if filtering is applied
(`min_percent_snps > 0`). For diploid individuals without filtering, the shape is
`(n_samples * 2, n_components)`.
"""
return self.__X_new_

Expand All @@ -429,6 +434,82 @@ def X_new_(self, x: np.ndarray) -> None:
"""
self.__X_new_ = x

@property
def haplotypes_(self) -> Optional[List[str]]:
    """
    Retrieve `haplotypes_`.

    Returns:
        list of str:
            A list of unique haplotype identifiers, or None if no haplotypes are set.
    """
    stored = self.__haplotypes_
    if stored is None:
        # Nothing has been stored yet.
        return None
    if isinstance(stored, np.ndarray):
        # Flatten the array and convert it to a plain Python list.
        return stored.ravel().tolist()
    if isinstance(stored, list):
        # A one-element list wrapping a single NumPy array is unwrapped and flattened.
        if len(stored) == 1 and isinstance(stored[0], np.ndarray):
            return stored[0].ravel().tolist()
        # Already a flat list of identifiers.
        return stored
    raise TypeError("`haplotypes_` must be a list or a NumPy array.")

@haplotypes_.setter
def haplotypes_(self, x: Union[np.ndarray, List[str]]) -> None:
    """
    Update `haplotypes_`, normalizing the input to a flat list of identifiers.
    """
    if isinstance(x, np.ndarray):
        # Flatten the array and store it as a plain Python list.
        self.__haplotypes_ = x.ravel().tolist()
        return
    if not isinstance(x, list):
        raise TypeError("`x` must be a list or a NumPy array.")
    # Unwrap a one-element list that contains a single NumPy array.
    wraps_single_array = len(x) == 1 and isinstance(x[0], np.ndarray)
    self.__haplotypes_ = x[0].ravel().tolist() if wraps_single_array else x

@property
def samples_(self) -> Optional[List[str]]:
    """
    Retrieve `samples_`.

    Returns:
        list of str:
            A list of sample identifiers based on `haplotypes_` and `average_strands`.
    """
    haps = self.haplotypes_
    if haps is None:
        return None
    # When strands are averaged, haplotype IDs already serve as sample IDs;
    # otherwise drop the trailing two characters (presumably a strand suffix
    # such as ".0"/".1" — confirm against how haplotype IDs are built).
    return haps if self.__average_strands else [hap[:-2] for hap in haps]

@property
def n_haplotypes(self) -> Optional[int]:
    """
    Retrieve `n_haplotypes`.

    Returns:
        **int:**
            The total number of haplotypes, potentially reduced if filtering is applied
            (`min_percent_snps > 0`). None if haplotypes have not been set.
    """
    # Go through the `haplotypes_` getter so that stored NumPy arrays are
    # flattened consistently (the raw attribute may be a 2-D array whose
    # len() would undercount), and guard the unset case so the declared
    # Optional[int] return type holds instead of raising TypeError.
    # This also matches the mdPCA implementation of the same property.
    haplotypes = self.haplotypes_
    return None if haplotypes is None else len(haplotypes)

@property
def n_samples(self) -> Optional[int]:
    """
    Retrieve `n_samples`.

    Returns:
        **int:**
            The total number of samples, potentially reduced if filtering is applied
            (`min_percent_snps > 0`). None if haplotypes have not been set.
    """
    samples = self.samples_
    if samples is None:
        # Guard the unset case: np.unique(None) would not produce a
        # meaningful sample count, and the return type is Optional[int].
        return None
    # Haplotypes of the same individual share a sample ID, so deduplicate.
    return len(np.unique(samples))

@staticmethod
def _load_masks_file(masks_file):
mask_files = np.load(masks_file, allow_pickle=True)
Expand Down Expand Up @@ -512,3 +593,4 @@ def fit_transform(
distance_list = [[distance_mat(first=masks[0][self.ancestry], dist_func=self.distance_type)]]

self.X_new_ = mds_transform(distance_list, groups, weights, ind_ID_list, self.n_components)
self.haplotypes_ = ind_ID_list
83 changes: 83 additions & 0 deletions snputils/processing/mdpca.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def __init__(
self.__rsid_or_chrompos = rsid_or_chrompos
self.__percent_vals_masked = percent_vals_masked
self.__X_new_ = None # Store transformed SNP data
self.__haplotypes_ = None # Store haplotypes after filtering if min_percent_snps > 0
self.__samples_ = None # Store samples after filtering if min_percent_snps > 0

# Fit and transform if a `snpobj`, `laiobj`, `labels_file`, and `ancestry` are provided
if self.snpobj is not None and self.laiobj is not None and self.labels_file is not None and self.ancestry is not None:
Expand Down Expand Up @@ -493,6 +495,9 @@ def X_new_(self) -> Optional[np.ndarray]:
Returns:
**array of shape (n_haplotypes_, n_components):**
The transformed SNP data projected onto the `n_components` principal components.
n_haplotypes_ is the number of haplotypes, potentially reduced if filtering is applied
(`min_percent_snps > 0`). For diploid individuals without filtering, the shape is
`(n_samples * 2, n_components)`.
"""
return self.__X_new_

Expand All @@ -503,6 +508,82 @@ def X_new_(self, x: np.ndarray) -> None:
"""
self.__X_new_ = x

@property
def haplotypes_(self) -> Optional[List[str]]:
    """
    Retrieve `haplotypes_`.

    Returns:
        list of str:
            A list of unique haplotype identifiers, or None if no haplotypes are set.
    """
    stored = self.__haplotypes_
    if stored is None:
        # Nothing has been stored yet.
        return None
    if isinstance(stored, np.ndarray):
        # Flatten the array and convert it to a plain Python list.
        return stored.ravel().tolist()
    if isinstance(stored, list):
        # A one-element list wrapping a single NumPy array is unwrapped and flattened.
        if len(stored) == 1 and isinstance(stored[0], np.ndarray):
            return stored[0].ravel().tolist()
        # Already a flat list of identifiers.
        return stored
    raise TypeError("`haplotypes_` must be a list or a NumPy array.")

@haplotypes_.setter
def haplotypes_(self, x: Union[np.ndarray, List[str]]) -> None:
    """
    Update `haplotypes_`, normalizing the input to a flat list of identifiers.
    """
    if isinstance(x, np.ndarray):
        # Flatten the array and store it as a plain Python list.
        self.__haplotypes_ = x.ravel().tolist()
        return
    if not isinstance(x, list):
        raise TypeError("`x` must be a list or a NumPy array.")
    # Unwrap a one-element list that contains a single NumPy array.
    wraps_single_array = len(x) == 1 and isinstance(x[0], np.ndarray)
    self.__haplotypes_ = x[0].ravel().tolist() if wraps_single_array else x

@property
def samples_(self) -> Optional[List[str]]:
    """
    Retrieve `samples_`.

    Returns:
        list of str:
            A list of sample identifiers based on `haplotypes_` and `average_strands`.
    """
    haps = self.haplotypes_
    if haps is None:
        return None
    # When strands are averaged, haplotype IDs already serve as sample IDs;
    # otherwise drop the trailing two characters (presumably a strand suffix
    # such as ".0"/".1" — confirm against how haplotype IDs are built).
    return haps if self.__average_strands else [hap[:-2] for hap in haps]

@property
def n_haplotypes(self) -> Optional[int]:
    """
    Retrieve `n_haplotypes`.

    Returns:
        **int:**
            The total number of haplotypes, potentially reduced if filtering is applied
            (`min_percent_snps > 0`). None if haplotypes have not been set.
    """
    # Guard the unset case so the declared Optional[int] return type holds:
    # len(None) would raise TypeError before this fix.
    haplotypes = self.haplotypes_
    return None if haplotypes is None else len(haplotypes)

@property
def n_samples(self) -> Optional[int]:
    """
    Retrieve `n_samples`.

    Returns:
        **int:**
            The total number of samples, potentially reduced if filtering is applied
            (`min_percent_snps > 0`). None if haplotypes have not been set.
    """
    samples = self.samples_
    if samples is None:
        # Guard the unset case: np.unique(None) would not produce a
        # meaningful sample count, and the return type is Optional[int].
        return None
    # Haplotypes of the same individual share a sample ID, so deduplicate.
    return len(np.unique(samples))

def copy(self) -> 'mdPCA':
"""
Create and return a copy of `self`.
Expand Down Expand Up @@ -919,4 +1000,6 @@ def fit_transform(
weights
)

self.haplotypes_ = ind_id_list

return self.X_new_
29 changes: 17 additions & 12 deletions snputils/visualization/scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def scatter(
Args:
dimredobj (np.ndarray):
Reduced dimensionality data; expected to have n_samples x 2 shape.
Reduced dimensionality data; expected to have `(n_haplotypes, 2)` shape.
labels_file (str):
Path to a TSV file with columns 'indID' and 'label', providing labels for coloring and annotating points.
abbreviation_inside_dots (bool):
Expand All @@ -41,29 +41,34 @@ def scatter(
Returns:
None
"""
# Load data from the dimension-reduced object and labels file
data = dimredobj.X_new_ # 2D data points for plotting
labels_df = pd.read_csv(labels_file, sep='\t') # Load labels from TSV
# Load labels from TSV
labels_df = pd.read_csv(labels_file, sep='\t')

# Initialize the plot
fig, ax = plt.subplots(figsize=(10, 8))
# Filter labels based on the indIDs in dimredobj
sample_ids = dimredobj.samples_
filtered_labels_df = labels_df[labels_df['indID'].isin(sample_ids)]

# Define unique colors for each group label, either from color_palette or defaulting to 'tab10'
unique_labels = labels_df['label'].unique()
unique_labels = filtered_labels_df['label'].unique()
colors = color_palette if color_palette else cm.get_cmap('tab10', len(unique_labels))

# Initialize the plot
fig, ax = plt.subplots(figsize=(10, 8))

# Calculate the overall center of the plot (used for positioning arrows)
plot_center = data.mean(axis=0)
plot_center = dimredobj.X_new_.mean(axis=0)

# Dictionary to hold centroid positions for each label
centroids = {}

# Iterate through each unique label to plot points and centroids
# Plot data points and centroids by label
for i, label in enumerate(unique_labels):
# Filter points corresponding to the current label
indices = labels_df[labels_df['label'] == label].index
points = data[indices]
# Get sample IDs corresponding to the current label
sample_ids_for_label = filtered_labels_df[filtered_labels_df['label'] == label]['indID']

# Filter points based on sample IDs
points = dimredobj.X_new_[np.isin(dimredobj.samples_, sample_ids_for_label)]

if dots:
# Plot individual points for the current group
ax.scatter(points[:, 0], points[:, 1], s=30, color=colors(i), alpha=0.6, label=label)
Expand Down

0 comments on commit bbb7aa7

Please sign in to comment.